# Copyright (c) 2021 Contrast Security, Inc. See https://www.contrastsecurity.com/enduser-terms-0317a for more details.
# frozen_string_literal: true

require 'set'
require 'contrast/agent/assess/policy/source_validation/source_validation'
require 'contrast/components/interface'
require 'contrast/utils/object_share'
require 'contrast/utils/sha256_builder'

module Contrast
  module Agent
    module Assess
      module Policy
        # This module controls the actions we take on Sources, as determined by our Assess policy. It indicates what
        # actions we should take in order to mark data as User Input and treat it as untrusted, starting the dataflows
        # used in Assess vulnerability detection.
        module SourceMethod
          include Contrast::Components::Interface
          access_component :analysis, :logging

          # Source types for values, and the types used when the source is instead the key of a Hash holding such
          # values. See #key_type.
          PARAMETER_TYPE     = 'PARAMETER'
          PARAMETER_KEY_TYPE = 'PARAMETER_KEY'
          HEADER_TYPE        = 'HEADER'
          HEADER_KEY_TYPE    = 'HEADER_KEY'
          COOKIE_TYPE        = 'COOKIE'
          COOKIE_KEY_TYPE    = 'COOKIE_KEY'

          class << self
            # This is called from within our woven proc. It will be called as if it were inline in the Rack
            # application.
            #
            # When the target of the source is frozen and the Tracker cannot track it as-is, the Return may be
            # replaced: if tracking of frozen sources is enabled and the node targets the Return, a duplicate of the
            # Return is prepared for tracking, frozen again, tagged, and handed back to the caller.
            #
            # @param method_policy [Contrast::Agent::Patching::Policy::MethodPolicy] the policy that applies to the
            #   method being called
            # @param object [Object] the Object on which the method was invoked
            # @param ret [Object] the Return of the invoked method
            # @param args [Array<Object>] the Arguments with which the method was invoked
            # @return [Object, nil] the tracked replacement Return, or nil if no replacement was made
            def source_patchers method_policy, object, ret, args
              return unless analyze?(method_policy, object, ret, args)

              source_node = method_policy.source_node
              target = determine_target(source_node, object, ret, args)
              restore_frozen_state = false
              if target.cs__frozen? && !Contrast::Agent::Assess::Tracker.trackable?(target)
                return unless ASSESS.track_frozen_sources?
                # Only the Return can be swapped for a tracked duplicate, so other frozen targets are skipped.
                return unless source_node.targets[0] == Contrast::Utils::ObjectShare::RETURN_KEY

                dup = safe_dup(ret)
                return unless dup

                restore_frozen_state = true
                ret = dup
                target = ret
                # Let the Tracker prepare the duplicate while it is still mutable, then freeze it so it matches the
                # frozen state of the original Return.
                Contrast::Agent::Assess::Tracker.pre_freeze(ret)
                ret.cs__freeze
                # double check that we were able to finalize the replaced return
                return unless Contrast::Agent::Assess::Tracker.trackable?(target)
              end
              context = Contrast::Agent::REQUEST_TRACKER.current
              apply_source(context, source_node, target, object, ret, source_node.type, nil, *args)
              restore_frozen_state ? ret : nil
            end

            private

            # This is our method that actually taints the object our source_node targets. Trackable targets are
            # tagged directly; Hashes and other Enumerables are walked so that the things they hold are tagged
            # instead. Any StandardError raised while doing so is logged and swallowed.
            #
            # @param context [Object] the current request context, as returned by
            #   Contrast::Agent::REQUEST_TRACKER.current. NOTE(review): this was previously documented as
            #   Contrast::Utils::ThreadTracker; confirm the concrete type.
            # @param source_node [Contrast::Agent::Assess::Policy::SourceNode] the node to direct applying this source
            #   event
            # @param target [Object] the target of the Source Event
            # @param object [Object] the Object on which the method was invoked
            # @param ret [Object] the Return of the invoked method
            # @param source_type [String] the type of this source, from the source_node, or a KEY_TYPE if invoked for a
            #   map
            # @param source_name [Object, nil] the name of this source, i.e. the key used to access it from a map, or
            #   nil to have it determined from the source_node (see #determine_source_name)
            # @param args [Array<Object>] the Arguments with which the method was invoked
            def apply_source context, source_node, target, object, ret, source_type, source_name = nil, *args
              return unless context && source_node && target

              source_name ||= determine_source_name(source_node, object, ret, *args)
              # We know we only work on certain things.
              # Skip if this isn't one of them
              if Contrast::Agent::Assess::Tracker.trackable?(target)
                apply_tags(source_node, target, object, ret, source_type, source_name, *args)
              elsif Contrast::Utils::DuckUtils.iterable_hash?(target)
                apply_hash_tags(context, source_node, target, object, ret, source_type, *args)
              elsif Contrast::Utils::DuckUtils.iterable_enumerable?(target)
                # While we don't taint arrays themselves, we may taint the things they hold. Let's pass their
                # values back to ourselves and try again
                target.each { |value| apply_source(context, source_node, value, object, ret, source_type, source_name, *args) }
              end
            rescue StandardError => e
              logger.warn('Unable to apply source', e, node_id: source_node.id)
            end

            # While we don't taint hashes themselves, we may taint the things they hold. Let's pass their keys and
            # values back to ourselves and try again. Keys are reported with the KEY form of the source type, and
            # each key is used as the source name of both itself and its value.
            #
            # @param context [Object] the current request context, as returned by
            #   Contrast::Agent::REQUEST_TRACKER.current
            # @param source_node [Contrast::Agent::Assess::Policy::SourceNode] the node to direct applying this source
            #   event
            # @param target [Hash] the target of the Source Event
            # @param object [Object] the Object on which the method was invoked
            # @param ret [Object] the Return of the invoked method
            # @param source_type [String] the type of this source, from the source_node, or a KEY_TYPE if invoked for a
            #   map
            # @param args [Array<Object>] the Arguments with which the method was invoked
            def apply_hash_tags context, source_node, target, object, ret, source_type, *args
              to_replace = []
              target.each_pair do |key, value|
                # We only do this for Strings b/c of the way Hash lookup works. To replace another object would break
                # hash lookup and, therefore, the application
                if replace_hash_key?(key, target)
                  key = key.dup
                  to_replace << key
                end
                apply_source(context, source_node, key, object, ret, key_type(source_type), key, *args)
                apply_source(context, source_node, value, object, ret, source_type, key, *args)
              end
              # Swap the keys only after iteration finishes; the hash is not modified while it is being walked.
              handle_hash_key(target, to_replace)
            end

            # Given an unfrozen hash, if the key is a String, we should replace it with one that we can finalize,
            # allowing us to track that key. This method handles checking if that replace can and should occur.
            #
            # @param key [Object] the key in the hash that may need replacing.
            # @param hash [Hash] the hash to which the key belongs.
            # @return [Boolean] whether to replace the key in the hash or not.
            def replace_hash_key? key, hash
              ASSESS.track_frozen_sources? &&
                  !hash.cs__frozen? &&
                  key.is_a?(String) &&
                  !Contrast::Agent::Assess::Tracker.trackable?(key)
            end

            # Safely duplicate the target, or return nil if duplication fails.
            #
            # @param target [Object] the thing to duplicate
            # @return [Object, nil] the duplicate, or nil if #dup raised a StandardError (or returned nil itself)
            def safe_dup target
              target.dup
            rescue StandardError
              nil
            end

            # Hash is designed to keep one instance of the string key in it. We need to remove the existing one and
            # replace it with our new tracked one.
            #
            # @param target [Hash] the hash whose keys are being replaced
            # @param to_replace [Array<String>] duplicates of existing keys of target, to be swapped in for them
            def handle_hash_key target, to_replace
              to_replace.each do |key|
                # Prepare the key for tracking, then freeze it: Hash#[]= stores a frozen String key as-is, whereas an
                # unfrozen one would be replaced by an untracked frozen copy.
                Contrast::Agent::Assess::Tracker.pre_freeze(key)
                key.cs__freeze
                # Assigning to an existing equal key keeps the old key instance, so the entry must be removed first.
                # Note that re-inserting moves the entry to the end of the hash's insertion order.
                value = target.delete(key)
                target[key] = value
              end
            end

            # Tag the target with each of the source_node's tags that is valid for this source type and name, then
            # record the Source event on the target's properties.
            #
            # @param source_node [Contrast::Agent::Assess::Policy::SourceNode] the node to direct applying this source
            #   event
            # @param target [Object] the object to tag; skipped unless the Tracker can track it and it is not already
            #   tracked
            # @param object [Object] the Object on which the method was invoked
            # @param ret [Object] the Return of the invoked method
            # @param source_type [String] the type of this source
            # @param source_name [Object, nil] the name of this source
            # @param args [Array<Object>] the Arguments with which the method was invoked
            def apply_tags source_node, target, object, ret, source_type, source_name, *args
              # don't apply tags if we can't track the thing
              return unless Contrast::Agent::Assess::Tracker.trackable?(target)
              # don't apply second source -- probably needs tuning later if we use more than 'UNTRUSTED' in our sources
              return if Contrast::Agent::Assess::Tracker.tracked?(target)
              return unless (properties = Contrast::Agent::Assess::Tracker.properties!(target))

              # Every tag spans the whole target, which does not change inside the loop, so measure it once.
              length = Contrast::Utils::StringUtils.ret_length(target)
              # otherwise for each tag this source_node applies, create a tag range on the target object. I realize
              # this looping is counter-intuitive from the above message, that's why we're revisiting.
              source_node.tags.each do |tag|
                next unless Contrast::Agent::Assess::Policy::SourceValidation.valid?(tag, source_type, source_name)

                properties.add_tag(tag, 0...length)
                properties.add_properties(source_node.properties)
                logger.trace('Source detected',
                             node_id: source_node.id,
                             target_id: target.__id__,
                             tag: tag)
              end
              # make a representation of this method that TeamServer can render
              properties.build_event(source_node, target, object, ret, args, source_type, source_name)
            end

            # Find the name of the source
            #
            # @param source_node [Contrast::Agent::Assess::Policy::SourceNode] the node to direct applying this source
            #   event
            # @param object [Object] the Object on which the method was invoked
            # @param ret [Object] the Return of the invoked method
            # @param args [Array<Object>] the Arguments with which the method was invoked
            # @return [Object, nil] the node's 'dynamic_source_name' property for UNTRUSTED_DATABASE sources;
            #   otherwise the Return, Object, or Argument that the node's first source points at, or nil if the node
            #   has no source
            def determine_source_name source_node, object, ret, *args
              return source_node.get_property('dynamic_source_name') if source_node.type == 'UNTRUSTED_DATABASE'

              source_node_source = source_node.sources[0]
              case source_node_source
              when nil
                nil
              when Contrast::Utils::ObjectShare::RETURN_KEY
                ret
              when Contrast::Utils::ObjectShare::OBJECT_KEY
                object
              else
                args[source_node_source]
              end
            end

            # Determine if we should analyze this method invocation for a Source or not. We should if we have enough
            # information to build the context of this invocation, we're not disabled, and we can't immediately
            # determine the invocation was done safely.
            #
            # @param method_policy [Contrast::Agent::Patching::Policy::MethodPolicy] the policy that applies to the
            #   method being called
            # @param object [Object] the Object on which the method was invoked
            # @param ret [Object] the Return of the invoked method
            # @param args [Array<Object>] the Arguments with which the method was invoked
            # @return [boolean] if the invocation of this method should be analyzed
            def analyze? method_policy, object, ret, args
              return false unless method_policy&.source_node
              return false unless ASSESS.enabled?
              return false unless Contrast::Agent::REQUEST_TRACKER.current&.analyze_request?

              !safe_invocation?(method_policy.source_node, object, ret, args)
            end

            # Determine if the method was invoked safely.
            #
            # @param source_node [Contrast::Agent::Assess::Policy::SourceNode] the node to direct applying this source
            #   event
            # @param _object [Object] the Object on which the method was invoked
            # @param _ret [Object] the Return of the invoked method
            # @param args [Array<Object>] the Arguments with which the method was invoked
            # @return [boolean] if the invocation of this method was safe
            def safe_invocation? source_node, _object, _ret, args
              # According to the Rack Specification https://github.com/rack/rack/blob/master/SPEC.rdoc, any header
              # from the Request will start with HTTP_. As such, only Headers with that key should be considered for
              # tracking, as the others have come from the Framework or Middleware stashing in the ENV. Rails, for
              # instance, uses action_dispatch. to store several values. Technically, you can't call
              # Rack::Request#get_header without a parameter, and that parameter should be a String, but trust no one.
              source_node.id == 'Assess:Source:Rack::Request::Env#get_header' &&
                  args&.any? &&
                  !args[0].to_s.start_with?('HTTP_')
            end

            # Find the literal target of the propagation
            #
            # @param source_node [Contrast::Agent::Assess::Policy::SourceNode] the node to direct applying this source
            #   event
            # @param object [Object] the Object on which the method was invoked
            # @param ret [Object] the Return of the invoked method
            # @param args [Array<Object>] the Arguments with which the method was invoked
            # @return [Object] the target to which this source event applies
            def determine_target source_node, object, ret, args
              source_target = source_node.targets[0]
              case source_target
              when Contrast::Utils::ObjectShare::RETURN_KEY
                ret
              when Contrast::Utils::ObjectShare::OBJECT_KEY
                object
              else
                args[source_target]
              end
            end

            # Simple helper method to flip the type from value to key when the source is the key of a Hash
            #
            # @param source_type [String] the original value source type
            # @return [String] the key form of the source type, if one exists, else the original source type
            def key_type source_type
              case source_type
              when PARAMETER_TYPE
                PARAMETER_KEY_TYPE
              when HEADER_TYPE
                HEADER_KEY_TYPE
              when COOKIE_TYPE
                COOKIE_KEY_TYPE
              else
                source_type
              end
            end
          end
        end
      end
    end
  end
end