# Copyright (c) 2021 Contrast Security, Inc. See https://www.contrastsecurity.com/enduser-terms-0317a for more details.
# frozen_string_literal: true

require 'contrast/components/interface'
require 'contrast/extension/module'
require 'contrast/utils/class_util'

module Contrast
  module Agent
    module Patching
      module Policy
        # Tracks a patch that needs special instrumentation applied once a given module has been loaded.
        #
        # An instance records which module it is waiting on, the file that must be required to instrument it,
        # and whether that instrumentation has been applied yet.
        class AfterLoadPatch
          include Contrast::Components::Interface
          access_component :scope
          attr_reader :applied, :module_name, :instrumentation_file_path, :method_to_instrument, :instrumenting_module

          # @param module_name name of the module this patch waits on; it is compared against the names passed
          #   to #applies? and resolved to a constant when looking for the target method
          # @param instrumentation_file_path path handed to `require` by #instrument!
          # @param method_to_instrument [Symbol, String, nil] optional instance method that must exist on the
          #   module before the patch is considered applicable
          # @param instrumenting_module optional name of a constant whose `instrument` method is invoked by
          #   #instrument! after the file has been required
          def initialize module_name, instrumentation_file_path, method_to_instrument: nil, instrumenting_module: nil
            @applied = false
            @module_name = module_name
            @method_to_instrument = method_to_instrument
            @instrumentation_file_path = instrumentation_file_path
            @instrumenting_module = instrumenting_module
          end

          # @return [Boolean] whether #instrument! has completed for this patch
          def applied?
            applied
          end

          # Modules can be re-opened, so the first load may not
          # necessarily define the method we're looking for:
          #
          # patching MyMod#instrumentable:
          #
          # file1:
          #   module MyMod
          #     def unrelated        <-- false lead
          #     end
          #   end
          #
          # file2:
          #   module MyMod
          #     def instrumentable   <-- actual target
          #     end
          #   end
          #
          # @return [Boolean] true while the target module, or the specific method we are waiting for on it,
          #   is not yet available; false once the patch can go ahead
          def blocked_by_method?
            return true  unless target_defined? # bc no methods are loaded
            return false unless method_to_instrument

            target = module_lookup
            # The constant lookup is allowed to fail (it yields nil). Without a resolved module we cannot
            # confirm the method exists, so report the patch as still blocked instead of raising on nil.
            return true unless target

            # Module#method_defined? matches public and protected instance methods, the same set that
            # Module#instance_methods lists, without allocating the full array of method names.
            !target.method_defined?(method_to_instrument)
          end

          # @param loaded_module_name name of the module that has just been loaded
          # @return [Boolean] true when the loaded module is the one this patch targets and nothing blocks it
          def applies? loaded_module_name
            (loaded_module_name == module_name) && !blocked_by_method?
          end

          # @return whatever Contrast::Utils::ClassUtil.truly_defined? reports for the target module name
          def target_defined?
            Contrast::Utils::ClassUtil.truly_defined?(module_name)
          end

          # Require the instrumentation file and, when an instrumenting module was configured, call its
          # `instrument` method inside `with_contrast_scope`. Marks the patch as applied afterwards.
          #
          # @return [Boolean] true
          def instrument!
            require instrumentation_file_path
            if instrumenting_module
              mod = Module.cs__const_get(instrumenting_module)
              with_contrast_scope { mod.instrument } if mod
            end
            @applied = true
          end

          private

          # Resolve the target module constant, memoizing the result once it has been found. A failed lookup
          # is not cached (nil is falsy for `||=`), so it is retried on the next call.
          #
          # @return [Module, nil] the resolved constant, or nil if the lookup raised a StandardError
          def module_lookup
            @_module_lookup ||= begin
              Module.cs__const_get module_name
            rescue StandardError
              nil
            end
          end
        end
      end
    end
  end
end