# Copyright (c) 2021 Contrast Security, Inc. See https://www.contrastsecurity.com/enduser-terms-0317a for more details.
# frozen_string_literal: true

require 'set'

require 'contrast/agent/assess/policy/propagator'
require 'contrast/components/logger'
require 'contrast/utils/object_share'
require 'contrast/utils/sha256_builder'

module Contrast
  module Agent
    module Assess
      module Policy
        # This module is responsible for the continuation of traces. A Propagator is any method that transforms an
        # untrusted value. In general, these methods work on the String class or a holder of Strings.
        module PropagationMethod
          extend Contrast::Components::Logger::InstanceMethods

          # Actions whose Propagator class is looked up in the PROPAGATION_ACTIONS table. NOOP_ACTION is listed
          # there too, but handle_cs_properties_propagation returns before any lookup for it.
          APPEND_ACTION = 'APPEND'
          CENTER_ACTION = 'CENTER'
          INSERT_ACTION = 'INSERT'
          KEEP_ACTION = 'KEEP'
          NEXT_ACTION = 'NEXT'
          NOOP_ACTION = 'NOOP'
          PREPEND_ACTION = 'PREPEND'
          REMOVE_ACTION = 'REMOVE'
          REPLACE_ACTION = 'REPLACE'
          REVERSE_ACTION = 'REVERSE'
          SPLAT_ACTION = 'SPLAT'

          # Actions that apply_propagator hands to a dedicated Propagator before it inspects the shape of the target.
          CUSTOM_ACTION = 'CUSTOM'
          DB_WRITE_ACTION = 'DB_WRITE'
          SPLIT_ACTION = 'SPLIT'

          class << self
            # Resolve the concrete object a propagation should write to, using the first target key of the node.
            # OBJECT_KEY selects the receiver and RETURN_KEY selects the return value. Any other key is used to
            # index into the invocation arguments.
            #
            # @param propagation_node [Contrast::Agent::Assess::Policy::PropagationNode] the node that governs this
            #   propagation event.
            # @param ret [Object] the Return of the invoked method
            # @param object [Object] the Object on which the method was invoked
            # @param args [Array<Object>] the Arguments with which the method was invoked
            # @return [Object, nil] the resolved target
            def determine_target propagation_node, ret, object, args
              target_key = propagation_node.targets[0]
              return object if Contrast::Utils::ObjectShare::OBJECT_KEY == target_key
              return ret if Contrast::Utils::ObjectShare::RETURN_KEY == target_key

              args[target_key]
            end

            # Entry point for propagation. Resolve the target named by the method's propagation node, then delegate
            # the actual work to apply_propagator. Nothing happens when the policy has no propagation node or when
            # no preshift was captured.
            #
            # @param method_policy [Contrast::Agent::Patching::Policy::MethodPolicy] the policy that governs the
            #   patches to this method
            # @param preshift [Contrast::Agent::Assess::PreShift] The capture of the state of the code just prior to
            #   the invocation of the patched method.
            # @param object [Object] the Object on which the method was invoked
            # @param ret [Object] the Return of the invoked method
            # @param args [Array<Object>] the Arguments with which the method was invoked
            # @param block [Block] the Block passed to the original method
            # @return [Object, nil] the tracked Return or nil if no changes were made; will replace the return of the
            #   original function if not nil
            def apply_propagation method_policy, preshift, object, ret, args, block
              node = method_policy.propagation_node
              return unless node && preshift

              resolved_target = determine_target(node, ret, object, args)
              PropagationMethod.apply_propagator(node, preshift, resolved_target, object, ret, args, block)
            end

            # Maps each table-driven action to the Propagator class that implements it; consulted by
            # find_propagation_class. NOOP_ACTION maps to nil and is filtered out earlier by
            # handle_cs_properties_propagation. apply_propagator dispatches SPLIT_ACTION directly, so its entry here
            # appears not to be reached through that path.
            PROPAGATION_ACTIONS = {
              APPEND_ACTION => Contrast::Agent::Assess::Policy::Propagator::Append,
              CENTER_ACTION => Contrast::Agent::Assess::Policy::Propagator::Center,
              INSERT_ACTION => Contrast::Agent::Assess::Policy::Propagator::Insert,
              KEEP_ACTION => Contrast::Agent::Assess::Policy::Propagator::Keep,
              NEXT_ACTION => Contrast::Agent::Assess::Policy::Propagator::Next,
              NOOP_ACTION => nil,
              PREPEND_ACTION => Contrast::Agent::Assess::Policy::Propagator::Prepend,
              REPLACE_ACTION => Contrast::Agent::Assess::Policy::Propagator::Replace,
              REMOVE_ACTION => Contrast::Agent::Assess::Policy::Propagator::Remove,
              REVERSE_ACTION => Contrast::Agent::Assess::Policy::Propagator::Reverse,
              SPLAT_ACTION => Contrast::Agent::Assess::Policy::Propagator::Splat,
              SPLIT_ACTION => Contrast::Agent::Assess::Policy::Propagator::Split
            }.cs__freeze

            # Perform the propagation to an already-resolved target. apply_propagation chooses the target; this
            # method decides how to propagate to it. Actions with dedicated Propagators (DB_WRITE, CUSTOM, SPLIT) are
            # handed off directly. Otherwise, Hash-like and Enumerable targets (as judged by DuckUtils) are walked so
            # that each element is propagated to individually, and any other target takes the property-based path.
            #
            # A non-nil return from this method replaces the original return of the patched function, so be sure
            # you're returning what you intend.
            #
            # @param propagation_node [Contrast::Agent::Assess::Policy::PropagationNode] the node that governs this
            #   propagation event.
            # @param preshift [Contrast::Agent::Assess::PreShift] The capture of the state of the code just prior to
            #   the invocation of the patched method.
            # @param target [Object] the Target to which to propagate.
            # @param object [Object] the Object on which the method was invoked
            # @param ret [Object] the Return of the invoked method
            # @param args [Array<Object>] the Arguments with which the method was invoked
            # @param block [Block] the Block passed to the original method
            # @return [Object, nil] the tracked Return or nil if no changes were made; will replace the return of the
            #   original function if not nil
            def apply_propagator propagation_node, preshift, target, object, ret, args, block
              return unless propagation_possible?(propagation_node, target)

              case propagation_node.action
              when DB_WRITE_ACTION
                Contrast::Agent::Assess::Policy::Propagator::DatabaseWrite.propagate(propagation_node, preshift, ret)
              when CUSTOM_ACTION
                Contrast::Agent::Assess::Policy::Propagator::Custom.propagate(propagation_node, preshift, ret, block)
              when SPLIT_ACTION
                Contrast::Agent::Assess::Policy::Propagator::Split.propagate(propagation_node, preshift, target)
              else
                # No dedicated handler for this action: pick a strategy based on the shape of the target.
                if Contrast::Utils::DuckUtils.iterable_hash?(target)
                  handle_hash_propagation(propagation_node, preshift, target, object, ret, args, block)
                elsif Contrast::Utils::DuckUtils.iterable_enumerable?(target)
                  handle_enumerable_propagation(propagation_node, preshift, target, object, ret, args, block)
                else
                  handle_cs_properties_propagation(propagation_node, preshift, target, object, ret, args, block)
                end
              end
            rescue StandardError => e
              # Failures are logged and swallowed; returning nil leaves the patched method's return unchanged.
              logger.warn('Unable to apply propagation', e, node_id: propagation_node.id)
              nil
            end

            # Whether the given target may be propagated to at all.
            #
            # CUSTOM propagators decide target validity themselves, because they often need the calling context to
            # judge it, so they are always accepted here. For every other action, the target is valid exactly when it
            # is truthy (neither nil nor false).
            #
            # @param target [Object] the thing to which to propagate
            # @param propagation_node [Contrast::Agent::Assess::Policy::PropagationNode] the node that governs this
            #   propagation event.
            # @return [Boolean]
            def valid_target? target, propagation_node
              deferred_to_custom = propagation_node.action == CUSTOM_ACTION
              deferred_to_custom || (target ? true : false)
            end

            # Actions for which valid_length? accepts a target regardless of its length.
            ZERO_LENGTH_ACTIONS = [
              DB_WRITE_ACTION,
              CUSTOM_ACTION,
              KEEP_ACTION,
              REPLACE_ACTION,
              SPLAT_ACTION
            ].cs__freeze
            # Whether the target is long enough for the given action. Actions in ZERO_LENGTH_ACTIONS accept any
            # target. Otherwise, a target that responds to #length must report a non-zero length, and any other
            # target must have a non-empty #to_s.
            #
            # @param target [Object] the thing to which to propagate
            # @param action [String] the name of the action taken during this propagation
            # @return [Boolean]
            def valid_length? target, action
              return true if ZERO_LENGTH_ACTIONS.include?(action)
              return !target.to_s.empty? unless Contrast::Utils::DuckUtils.quacks_to?(target, :length)

              target.length != 0 # rubocop:disable Style/ZeroLengthPredicate
            end

            # Decide whether propagation is worth doing at all. There is nothing to copy unless the target, or one of
            # the node's declared sources captured in the preshift, is already tracked. A copy of nothing is still
            # nothing.
            #
            # @param propagation_node [Contrast::Agent::Assess::Policy::PropagationNode] the node that governs this
            #   propagation event; each of its sources is either OBJECT_KEY (the receiver) or a key into the
            #   preshift's args.
            # @param preshift [Contrast::Agent::Assess::PreShift] The capture of the state of the code just prior to
            #   the invocation of the patched method.
            # @param target [Object] the thing to which to propagate
            # @return [Boolean]
            def can_propagate? propagation_node, preshift, target
              return false unless appropriate_target?(propagation_node, target)

              tracking = Contrast::Utils::Assess::TrackingUtil
              return true if tracking.tracked?(target)
              return false unless preshift

              propagation_node.sources.any? do |source|
                if source == Contrast::Utils::ObjectShare::OBJECT_KEY
                  tracking.tracked?(preshift.object)
                else
                  # Anything other than the receiver is an argument source; there is no return-value source.
                  preshift.args && tracking.tracked?(preshift.args[source])
                end
              end
            end

            # Whether the target is a legal destination for propagation. A target that is the return value is always
            # accepted, because handle_cs_properties_propagation can swap a frozen return for a duplicate. Any other
            # target must be trackable by the Tracker. Propagating to frozen objects is questionable since they are
            # meant to be immutable, but third parties do odd things, so it is allowed wherever it is safe.
            #
            # @param propagation_node [Contrast::Agent::Assess::Policy::PropagationNode] the node that governs this
            #   propagation event.
            # @param target [Object] the Target to which to propagate.
            # @return [Boolean] if the target can be propagated to
            def appropriate_target? propagation_node, target
              if propagation_node.targets[0] == Contrast::Utils::ObjectShare::RETURN_KEY
                true
              else
                Contrast::Agent::Assess::Tracker.trackable?(target)
              end
            end

            # Stamp each of the node's tags across the whole target, from 0 up to its length as reported by
            # StringUtils.ret_length. Does nothing when the node has no tags or the target has no tracked properties.
            #
            # @param propagation_node [Contrast::Agent::Assess::Policy::PropagationNode] the node that governs this
            #   propagation event.
            # @param target [Object] the Target to which to propagate.
            def apply_tags propagation_node, target
              tags = propagation_node.tags
              return unless tags

              properties = Contrast::Agent::Assess::Tracker.properties(target)
              return unless properties

              length = Contrast::Utils::StringUtils.ret_length(target)
              tags.each { |tag| properties.add_tag(tag, 0...length) }
            end

            # Strip each of the node's untags from the target's tracked properties. Does nothing when the node has no
            # untags or the target has no tracked properties.
            #
            # @param propagation_node [Contrast::Agent::Assess::Policy::PropagationNode] the node that governs this
            #   propagation event.
            # @param target [Object] the Target to which to propagate.
            def apply_untags propagation_node, target
              untags = propagation_node.untags
              return unless untags

              properties = Contrast::Agent::Assess::Tracker.properties(target)
              return unless properties

              untags.each { |tag| properties.delete_tags(tag) }
            end

            private

            # Final gate checked right before propagation: there must be a node, the target must pass
            # valid_target?, and it must satisfy valid_length? for the node's action.
            #
            # @param propagation_node [Contrast::Agent::Assess::Policy::PropagationNode] the node that governs this
            #   propagation event.
            # @param target [Object] the Target to which to propagate.
            # @return [Boolean]
            def propagation_possible? propagation_node, target
              return false unless propagation_node
              return false unless valid_target?(target, propagation_node)

              !!valid_length?(target, propagation_node.action)
            end

            # Return target.dup, or nil when duplicating raises a StandardError.
            #
            # @param target [Object] the thing to check for duplication
            # @return [Object, nil]
            def safe_dup target
              target.dup
            rescue StandardError
              nil
            end

            # Iterate over each key and value in a hash and propagate to each one independently.
            #
            # Always returns nil. apply_propagator passes its result upward, and a non-nil value replaces the
            # patched method's return. Hash#each_pair returns the receiver, so without an explicit nil the target
            # hash itself would leak out as the method's return. That is wrong whenever the target is the receiver
            # or an argument rather than the return value.
            #
            # @param propagation_node [Contrast::Agent::Assess::Policy::PropagationNode] the node that governs this
            #   propagation event.
            # @param preshift [Contrast::Agent::Assess::PreShift] The capture of the state of the code just prior to
            #   the invocation of the patched method.
            # @param target [Object] the Target to which to propagate.
            # @param object [Object] the Object on which the method was invoked
            # @param ret [Object] the Return of the invoked method
            # @param args [Array<Object>] the Arguments with which the method was invoked
            # @param block [Block] the Block passed to the original method
            # @return [nil]
            def handle_hash_propagation propagation_node, preshift, target, object, ret, args, block
              target.each_pair do |key, value|
                apply_propagator(propagation_node, preshift, key, object, ret, args, block)
                apply_propagator(propagation_node, preshift, value, object, ret, args, block)
              end
              nil
            end

            # Iterate over each value in an enumerable and propagate to each one independently. An element that is
            # == to the target itself (for example, an array containing itself) is skipped to avoid recursing into
            # it.
            #
            # Always returns nil. apply_propagator passes its result upward, and a non-nil value replaces the
            # patched method's return. Enumerable#each returns the receiver, so without an explicit nil the target
            # collection would leak out as the method's return whenever it is the receiver or an argument.
            #
            # @param propagation_node [Contrast::Agent::Assess::Policy::PropagationNode] the node that governs this
            #   propagation event.
            # @param preshift [Contrast::Agent::Assess::PreShift] The capture of the state of the code just prior to
            #   the invocation of the patched method.
            # @param target [Object] the Target to which to propagate.
            # @param object [Object] the Object on which the method was invoked
            # @param ret [Object] the Return of the invoked method
            # @param args [Array<Object>] the Arguments with which the method was invoked
            # @param block [Block] the Block passed to the original method
            # @return [nil]
            def handle_enumerable_propagation propagation_node, preshift, target, object, ret, args, block
              target.each do |value|
                next if target == value

                apply_propagator(propagation_node, preshift, value, object, ret, args, block)
              end
              nil
            end

            # Move the properties from the source(s) to the target of the propagation.
            #
            # This is the fallback path of apply_propagator. It handles targets that are neither Hash-like nor
            # Enumerable (per DuckUtils) and whose action is not DB_WRITE, CUSTOM, or SPLIT. It looks up the
            # Propagator class for the action and runs it, applies the node's tags and then its untags, merges the
            # node's properties into the target's properties, and builds an event.
            #
            # @param propagation_node [Contrast::Agent::Assess::Policy::PropagationNode] the node that governs this
            #   propagation event.
            # @param preshift [Contrast::Agent::Assess::PreShift] The capture of the state of the code just prior to
            #   the invocation of the patched method.
            # @param target [Object] the Target to which to propagate.
            # @param object [Object] the Object on which the method was invoked
            # @param ret [Object] the Return of the invoked method
            # @param args [Array<Object>] the Arguments with which the method was invoked
            # @param _block [Block] the Block passed to the original method (unused here)
            # @return [Object, nil] the duplicated return when a frozen, untrackable return had to be replaced;
            #   otherwise nil, so the original return of the patched method is kept.
            def handle_cs_properties_propagation propagation_node, preshift, target, object, ret, args, _block
              # NOOP actions make no changes to the target.
              return if propagation_node.action == NOOP_ACTION
              # Skip the work when neither the target nor any declared source is tracked.
              return unless can_propagate?(propagation_node, preshift, target)
              # Unknown actions are logged by find_propagation_class and skipped.
              return unless (propagation_class = find_propagation_class(propagation_node))

              restore_frozen_state = false
              # A frozen target that the Tracker cannot track is only handled when it is the return value (see
              # can_handle_frozen?). In that case, replace ret with a duplicate, run Tracker.pre_freeze on the
              # duplicate, and freeze it again with cs__freeze. The duplicate becomes the new target and is later
              # returned so it replaces the original return.
              if target.cs__frozen? && !Contrast::Agent::Assess::Tracker.trackable?(target)
                return unless can_handle_frozen?(propagation_node)
                return unless (dup = safe_dup(ret))

                restore_frozen_state = true
                ret = dup
                target = ret
                Contrast::Agent::Assess::Tracker.pre_freeze(ret)
                ret.cs__freeze
                # double check that we were able to finalize the replaced return
                return unless Contrast::Agent::Assess::Tracker.trackable?(target)
              end
              propagation_class.propagate(propagation_node, preshift, target)
              # Once we've propagated, attempt to tag the target if there is a tag(s) to be applied
              apply_tags(propagation_node, target)
              # Even though we skipped propagating tags from the source if they were included in untags, the target may
              # have already had some on it. Let's go ahead and remove them. In this order, untags take precedence over
              # tags; but we control both and there should never be a propagator that has a tag in its untag.
              apply_untags(propagation_node, target)
              return unless (properties = Contrast::Agent::Assess::Tracker.properties!(target))

              properties.add_properties(propagation_node.properties)
              properties.build_event(propagation_node, target, object, ret, args)
              logger.trace('Propagation detected', node_id: propagation_node.id, target_id: target.__id__)
              # Only hand back ret when it was replaced above; nil leaves the patched method's return untouched.
              restore_frozen_state ? ret : nil
            end

            # Look up the Propagator class registered in PROPAGATION_ACTIONS for the node's action. When no class is
            # registered (including NOOP_ACTION, which maps to nil), log a warning and return that missing value.
            #
            # @param propagation_node [Contrast::Agent::Assess::Policy::PropagationNode] the node that governs a
            #   propagation event.
            # @return [Contrast::Agent::Assess::Policy::Propagator, nil]
            def find_propagation_class propagation_node
              klass = PROPAGATION_ACTIONS.fetch(propagation_node.action, nil)
              return klass if klass

              logger.warn('Unknown propagation action received. Unable to propagate.',
                          node_id: propagation_node.id,
                          action: propagation_node.action)
              klass
            end

            # Frozen propagation is possible only when configuration allows it (ASSESS.track_frozen_sources?) and the
            # node targets the return value, since that is the one value we can replace.
            #
            # @param propagation_node [Contrast::Agent::Assess::Policy::PropagationNode] the node that governs a
            #   propagation event.
            # @return [Boolean]
            def can_handle_frozen? propagation_node
              frozen_tracking_enabled = ::Contrast::ASSESS.track_frozen_sources?
              frozen_tracking_enabled && propagation_node.targets[0] == Contrast::Utils::ObjectShare::RETURN_KEY
            end
          end
        end
      end
    end
  end
end