# Copyright (c) 2022 Contrast Security, Inc. See https://www.contrastsecurity.com/enduser-terms-0317a for more details.
# frozen_string_literal: true

module Contrast
  module Agent
    module Assess
      module Policy
        module Propagator
          # This class handles tag propagation for the String 'select' policy.
          #
          # The argument shapes handled below (Integer, Integer + length, String,
          # Regexp + optional capture group, Range) mirror those accepted by
          # String#[] / String#slice. The character range of the receiver that
          # produced the return value is reconstructed from those arguments, and
          # the receiver's tags over that range are copied onto the return value.
          #
          # Disclaimer: there may be a better way, but we're in a 'get it
          # working' state. Hopefully, we'll be in a 'get it right' state soon.
          class Select
            class << self
              # Copy the tags covering the selected portion of a tracked source
              # String onto the value returned by the selection.
              #
              # @param patcher [Object] passed through to EventData (presumably the
              #   policy node for the patched method — confirm against callers)
              # @param preshift [Object] responds to #object (the receiver) and
              #   #args (the selection arguments)
              # @param ret [String, nil] the value returned by the selection
              # @param _block [Proc, nil] unused
              # @return [String, nil] ret when a range could be resolved and its
              #   tags were applied; nil on any early exit
              def select_tagger patcher, preshift, ret, _block
                source = preshift.object
                args = preshift.args

                # 'gotcha': an empty string is returned when the starting index of
                # a character range is at the end of the string. Skip that case and
                # only track a result that has length.
                return unless ret && !ret.empty? && Contrast::Agent::Assess::Tracker.tracked?(source)

                return unless (source_properties = Contrast::Agent::Assess::Tracker.properties(source))
                return unless (properties = Contrast::Agent::Assess::Tracker.properties!(ret))

                # The event is recorded even if the range cannot be resolved below.
                event_data = Contrast::Agent::Assess::Events::EventData.new(patcher, ret, source, ret, args)
                properties.build_event(event_data)

                range = determine_select_range(source, args)
                return unless range

                tags = source_properties.tags_at_range(range)
                properties.clear_tags
                tags.each_pair do |key, value|
                  properties.set_tags(key, value)
                end
                ret
              end

              private

              # Range for an Integer index with an optional length (default 1).
              # A negative index counts back from the end of the source. The end
              # is clamped so the range never extends past the source.
              def handle_integer args, arg, source
                length = args[1] || 1
                arg += source.length if arg.negative?
                arg...[arg + length, source.length].min
              end

              # Range of the first occurrence of the String argument in source.
              # Only reached when the selection returned a non-empty result, so the
              # substring is known to be present.
              def handle_string arg, source
                idx = source.index(arg)
                idx...(idx + arg.length)
              end

              # Range of the Regexp match in source. The optional second argument
              # picks a capture group (index or name); nil means the full match (0).
              def handle_regexp args, arg, source
                match_data = arg.match(source)
                group = args[1] || 0
                match_data.begin(group)...match_data.end(group)
              end

              # Normalize a Range argument into an exclusive, non-negative range
              # within source.
              def handle_range arg, source
                length = source.length

                # Beginless ranges (..n) have a nil begin and start at index 0.
                start = arg.begin || 0
                start += length if start.negative?

                finish = arg.end
                if finish.nil?
                  # Endless ranges (n..) have a nil end and run to the end of source.
                  finish = length
                else
                  finish += length if finish.negative?
                  # Convert an inclusive end into an exclusive one.
                  finish += 1 unless arg.exclude_end?
                end

                start...[finish, length].min
              end

              # Dispatch on the first selection argument's type. Returns nil for
              # argument types that are not handled.
              def determine_select_range source, args
                arg = args[0]
                case arg
                when Integer
                  handle_integer(args, arg, source)
                when String
                  handle_string(arg, source)
                when Regexp
                  handle_regexp(args, arg, source)
                when Range
                  handle_range(arg, source)
                end
              end
            end
          end
        end
      end
    end
  end
end