# Copyright (c) 2023 Contrast Security, Inc. See https://www.contrastsecurity.com/enduser-terms-0317a for more details.
# frozen_string_literal: true

require 'set'

require 'contrast/agent/assess/policy/propagator'
require 'contrast/components/logger'
require 'contrast/utils/object_share'
require 'contrast/utils/sha256_builder'
require 'contrast/utils/assess/propagation_method_utils'
require 'contrast/utils/assess/event_limit_utils'
require 'contrast/agent/assess/events/event_data'
require 'contrast/utils/assess/object_store'

module Contrast
  module Agent
    module Assess
      module Policy
        # This module is responsible for the continuation of traces. A Propagator is any method that transforms an
        # untrusted value. In general, these methods work on the String class or a holder of Strings.
        # rubocop:disable Metrics/ModuleLength
        module PropagationMethod
          extend Contrast::Components::Logger::InstanceMethods
          extend Contrast::Utils::Assess::PropagationMethodUtils
          extend Contrast::Utils::Assess::EventLimitUtils

          # Cache of Properties for the receivers of bang methods that return their original object, keyed by the
          # receiver's __id__. Populated by use_original_object_properties.
          @properties = Contrast::Utils::Assess::ObjectStore.new

          class << self
            # Entry point for a patched propagation method: resolve the target of the propagation and hand it off to
            # apply_propagator.
            #
            # @param method_policy [Contrast::Agent::Patching::Policy::MethodPolicy] the policy that governs the
            #   patches to this method
            # @param preshift [Contrast::Agent::Assess::PreShift] The capture of the state of the code just prior to
            #   the invocation of the patched method.
            # @param object [Object] the Object on which the method was invoked
            # @param ret [Object] the Return of the invoked method
            # @param args [Array<Object>] the Arguments with which the method was invoked
            # @param block [Block] the Block passed to the original method
            # @return [Object, nil] the tracked Return or nil if no changes were made; will replace the return of the
            #   original function if not nil
            def apply_propagation method_policy, preshift, object, ret, args, block
              return unless (propagation_node = method_policy.propagation_node)
              # The preshift is the propagation source unless the node propagates from the original object itself,
              # so without either there is nothing to propagate from.
              return unless propagation_node.use_original_object? || preshift
              return if event_limit?(method_policy)

              target = determine_target(propagation_node, ret, object, args)
              propagation_data = Contrast::Agent::Assess::Events::EventData.new(nil, nil, object, ret, args)
              PropagationMethod.apply_propagator(propagation_node, preshift, target, propagation_data, block)
            end

            # Apply the propagation to a target already resolved by apply_propagation. Database-write, custom, and
            # split actions go to their dedicated propagators; Hash targets have each key and value visited,
            # other iterable targets have each element visited, and anything else has its properties propagated
            # directly.
            #
            # The return of this method becomes the return of apply_propagation, which replaces the original return
            # of the patched function unless it is nil, so be sure each branch returns what you intend.
            #
            # @param propagation_node [Contrast::Agent::Assess::Policy::PropagationNode] the node that governs this
            #   propagation event.
            # @param preshift [Contrast::Agent::Assess::PreShift] The capture of the state of the code just prior to
            #   the invocation of the patched method.
            # @param target [Object] the Target to which to propagate.
            # @param propagation_data [Contrast::Agent::Assess::Events::EventData] this will hold the
            #                         object [Object] the Object on which the method was invoked
            #                         ret [Object] the Return of the invoked method
            #                         args [Array<Object>] the Arguments with which the method was invoked
            # @param block [Block] the Block passed to the original method
            # @return [Object, nil] the result of the selected handler; nil if propagation was not possible or a
            #   StandardError was raised (the error is logged, not re-raised).
            def apply_propagator propagation_node, preshift, target, propagation_data, block
              return unless propagation_possible?(propagation_node, target)

              action = propagation_node.action
              if action == DB_WRITE_ACTION
                Contrast::Agent::Assess::Policy::Propagator::DatabaseWrite.propagate(propagation_node,
                                                                                     preshift,
                                                                                     propagation_data.ret)
              elsif action == CUSTOM_ACTION
                Contrast::Agent::Assess::Policy::Propagator::Custom.propagate(propagation_node,
                                                                              preshift,
                                                                              propagation_data.ret,
                                                                              block)
              elsif action == SPLIT_ACTION
                Contrast::Agent::Assess::Policy::Propagator::Split.propagate(propagation_node, preshift, target)
              elsif Contrast::Utils::DuckUtils.iterable_hash?(target)
                handle_hash_propagation(propagation_node, preshift, target, propagation_data, block)
              elsif Contrast::Utils::DuckUtils.iterable_enumerable?(target)
                handle_enumerable_propagation(propagation_node, preshift, target, propagation_data, block)
              else
                handle_cs_properties_propagation(propagation_node, preshift, target, propagation_data, block)
              end
            rescue StandardError => e
              logger.warn('Unable to apply propagation', e, node_id: propagation_node.id)
              nil
            end

            # If this patcher has tags, apply them across the full length of the target.
            #
            # @param propagation_node [Contrast::Agent::Assess::Policy::PropagationNode] the node that governs this
            #   propagation event.
            # @param target [Object] the Target to which to propagate.
            def apply_tags propagation_node, target
              return unless (propagation_tags = propagation_node.tags)
              return unless (properties = Contrast::Agent::Assess::Tracker.properties(target))

              # Every tag covers the same span, so build the range once.
              full_range = 0...Contrast::Utils::StringUtils.ret_length(target)
              propagation_tags.each { |tag| properties.add_tag(tag, full_range) }
            end

            # @return [Boolean] true when a request is currently being tracked.
            def context_available?
              !!Contrast::Agent::REQUEST_TRACKER.current
            end

            # If this patcher has untags, remove them from the target.
            #
            # @param propagation_node [Contrast::Agent::Assess::Policy::PropagationNode] the node that governs this
            #   propagation event.
            # @param target [Object] the Target to which to propagate.
            def apply_untags propagation_node, target
              return unless (untags = propagation_node.untags)
              return unless (properties = Contrast::Agent::Assess::Tracker.properties(target))

              untags.each { |tag| properties.delete_tags(tag) }
            end

            private

            # This is checked right before actual propagation.
            #
            # @param propagation_node [Contrast::Agent::Assess::Policy::PropagationNode] the node that governs this
            #   propagation event.
            # @param target [Object] the Target to which to propagate.
            # @return [Boolean]
            def propagation_possible? propagation_node, target
              return false unless propagation_node && valid_target?(target, propagation_node)
              return false unless context_available? || Contrast::ASSESS.non_request_tracking?
              return false unless valid_length?(target, propagation_node.action)

              true
            end

            # Iterate over each key and value in a hash to allow for propagation to each.
            #
            # @param propagation_node [Contrast::Agent::Assess::Policy::PropagationNode] the node that governs this
            #   propagation event.
            # @param preshift [Contrast::Agent::Assess::PreShift] The capture of the state of the code just prior to
            #   the invocation of the patched method.
            # @param target [Object] the Target to which to propagate.
            # @param propagation_data [Contrast::Agent::Assess::Events::EventData] this will hold the
            #                         object [Object] the Object on which the method was invoked
            #                         ret [Object] the Return of the invoked method
            #                         args [Array<Object>] the Arguments with which the method was invoked
            # @param block [Block] the Block passed to the original method
            def handle_hash_propagation propagation_node, preshift, target, propagation_data, block
              target.each_pair do |key, value|
                apply_propagator(propagation_node, preshift, key, propagation_data, block)
                apply_propagator(propagation_node, preshift, value, propagation_data, block)
              end
            end

            # Iterate over each value in an enumerable to allow for propagation to each.
            #
            # @param propagation_node [Contrast::Agent::Assess::Policy::PropagationNode] the node that governs this
            #   propagation event.
            # @param preshift [Contrast::Agent::Assess::PreShift] The capture of the state of the code just prior to
            #   the invocation of the patched method.
            # @param target [Object] the Target to which to propagate.
            # @param propagation_data [Contrast::Agent::Assess::Events::EventData] this will hold the
            #                         object [Object] the Object on which the method was invoked
            #                         ret [Object] the Return of the invoked method
            #                         args [Array<Object>] the Arguments with which the method was invoked
            # @param block [Block] the Block passed to the original method
            def handle_enumerable_propagation propagation_node, preshift, target, propagation_data, block
              target.each do |value|
                # Skip an element equal to the collection itself (e.g. a collection that contains itself), which
                # would otherwise be walked again.
                # NOTE(review): this is an equality (==) check, not identity; it is intentionally kept as ==.
                # A reference cycle that runs through an intermediate collection may not be caught here.
                next if target == value

                apply_propagator(propagation_node, preshift, value, propagation_data, block)
              end
            end

            # Move the properties from the source(s) to the target of the propagation.
            #
            # @param propagation_node [Contrast::Agent::Assess::Policy::PropagationNode] the node that governs this
            #   propagation event.
            # @param preshift [Contrast::Agent::Assess::PreShift] The capture of the state of the code just prior to
            #   the invocation of the patched method.
            # @param target [Object] the Target to which to propagate.
            # @param propagation_data [Contrast::Agent::Assess::Events::EventData] this will hold the
            #                         object [Object] the Object on which the method was invoked
            #                         ret [Object] the Return of the invoked method
            #                         args [Array<Object>] the Arguments with which the method was invoked
            # @param _block [Block] the Block passed to the original method (unused)
            def handle_cs_properties_propagation propagation_node, preshift, target, propagation_data, _block
              return if propagation_node.action == NOOP_ACTION
              return unless can_propagate?(propagation_node, preshift, target, propagation_data)
              return unless (propagation_class = find_propagation_class(propagation_node))

              # If we are using the original object tracking, the preshift object is not created.
              # Instead identify the source as the original object itself and propagate with it.
              source = propagation_node.use_original_object? ? propagation_data.object : preshift
              handle_propagation(propagation_class, propagation_node, source, target)
              update_properties(propagation_node, target, propagation_data)
              increment_event_count(propagation_node)
            end

            # Invoke the propagator: use the node's patch_method when one is configured, otherwise the class's
            # default propagate.
            #
            # @param propagation_class [Contrast::Agent::Assess::Policy::Propagator] the propagator to invoke.
            # @param propagation_node [Contrast::Agent::Assess::Policy::PropagationNode] the governing node.
            # @param source [Object] the preshift, or the original object when use_original_object? is set.
            # @param target [Object] the Target to which to propagate.
            def handle_propagation propagation_class, propagation_node, source, target
              if propagation_node.patch_method
                propagation_class.send(propagation_node.patch_method, propagation_node, source, target)
              else
                propagation_class.propagate(propagation_node, source, target)
              end
            end

            # Record the propagation event on the target's properties. For bang methods that return their original
            # object, the cached properties of that object are used and tags/untags are not applied. Otherwise the
            # node's tags and untags are applied to the target and the node's properties are merged in.
            #
            # @param propagation_node [Contrast::Agent::Assess::Policy::PropagationNode] the governing node.
            # @param target [Object] the Target to which to propagate.
            # @param propagation_data [Contrast::Agent::Assess::Events::EventData] holds object, ret, and args.
            def update_properties propagation_node, target, propagation_data
              if propagation_node.use_original_on_bang_method?
                properties = use_original_object_properties(propagation_data)

                return unless properties
              else
                # Once we've propagated, attempt to tag the target if there is a tag(s) to be applied
                apply_tags(propagation_node, target)
                # Even though we skipped propagating tags from the source if they were included in untags, the target
                # may have already had some on it. Let's go ahead and remove them. In this order, untags takes
                # precedence over tags; but we control both and there should never be a propagator that has a
                # tag in its untag.
                apply_untags(propagation_node, target)
                return unless (properties = Contrast::Agent::Assess::Tracker.properties!(target))

                properties.add_properties(propagation_node.properties)
              end
              event_data = Contrast::Agent::Assess::Events::EventData.new(propagation_node,
                                                                          target,
                                                                          propagation_data.object,
                                                                          propagation_data.ret,
                                                                          propagation_data.args)
              properties.build_event(event_data)
              logger.trace('Propagation detected', node_id: propagation_node.id, target_id: target.__id__)
            end

            # Find the propagation class from the given node, if one exists. Logs a warning when the node's action
            # has no registered propagator.
            #
            # @param propagation_node [Contrast::Agent::Assess::Policy::PropagationNode] the node that governs a
            #   propagation event.
            # @return [Contrast::Agent::Assess::Policy::Propagator, nil]
            def find_propagation_class propagation_node
              unless (propagation_class = PROPAGATION_ACTIONS.fetch(propagation_node.action, nil))
                logger.warn('Unknown propagation action received. Unable to propagate.',
                            node_id: propagation_node.id,
                            action: propagation_node.action)
              end
              propagation_class
            end

            # For certain bang methods we return the same object, so there is no need to create expensive new
            # properties. The receiver's properties are looked up once and cached in @properties by the receiver's
            # __id__.
            #
            # @param propagation_data [Contrast::Agent::Assess::Events::EventData] used to hold object, args
            #   and ret.
            # @return [Contrast::Agent::Assess::Properties, nil] the original object's properties, which are
            #   transferred to the target.
            def use_original_object_properties propagation_data
              original = propagation_data.object
              # A nil result is not cached, so the lookup is retried on the next call.
              @properties[original.__id__] ||= Contrast::Agent::Assess::Tracker.properties!(original)
            end
          end
        end
        # rubocop:enable Metrics/ModuleLength
      end
    end
  end
end