module Prawn::SVG::CSS
  # Resolves the CSS rules that apply to an SVG document into per-element
  # declaration lists.
  #
  # Rules come from a built-in user-agent stylesheet plus the text of every
  # <style> element found under the document root. Each selector is translated
  # into an XPath expression, evaluated with REXML against the root, and every
  # matching element is associated with the declarations of the rule.
  class Stylesheets
    # Default rules that are loaded ahead of the document's own <style> blocks.
    USER_AGENT_STYLESHEET = 'svg, symbol, image, marker, pattern, foreignObject { overflow: hidden }'.freeze

    attr_reader :css_parser, :root, :media

    # @param css_parser [#add_block!, #each_rule_set] parser that accumulates
    #   CSS text and enumerates the resulting rule sets
    # @param root [Object] REXML node the XPath queries are evaluated against
    # @param media [Symbol] media type handed to +css_parser.each_rule_set+
    def initialize(css_parser, root, media = :all)
      @css_parser = css_parser
      @root = root
      @media = media
    end

    # Loads all stylesheets and matches their rules against the document.
    #
    # @return [Hash] maps each matched element to a flat array of declarations,
    #   ordered from lowest to highest precedence (specificity first, then
    #   order of appearance), so later entries override earlier ones
    def load
      load_user_agent_stylesheet
      load_style_elements
      xpath_styles = gather_xpath_styles
      associate_xpath_styles_with_elements(xpath_styles)
    end

    private

    # Feeds the built-in default rules to the parser.
    def load_user_agent_stylesheet
      css_parser.add_block!(USER_AGENT_STYLESHEET)
    end

    # Feeds the concatenated text content of every <style> element to the parser.
    def load_style_elements
      REXML::XPath.match(root, '//style').each do |source|
        data = source.texts.map(&:value).join
        css_parser.add_block!(data)
      end
    end

    # Converts every parseable selector of every rule set into an XPath query.
    #
    # Selectors the selector parser cannot handle (falsy parse result) are
    # skipped.
    #
    # @return [Array<Array>] +[xpath, declarations, specificity]+ triples sorted
    #   ascending by specificity; specificity is
    #   +[ids, classes/attributes/pseudo-classes, element names, order]+ where
    #   +order+ is the position of the selector across all rule sets
    def gather_xpath_styles
      xpath_styles = []
      order = 0

      css_parser.each_rule_set(media) do |rule_set, _|
        declarations = []
        rule_set.each_declaration { |*data| declarations << data }

        rule_set.selectors.each do |selector_text|
          next unless (selector = Prawn::SVG::CSS::SelectorParser.parse(selector_text))

          xpath = css_selector_to_xpath(selector)
          # The running counter breaks specificity ties in source order.
          specificity = calculate_specificity(selector) << order
          order += 1

          xpath_styles << [xpath, declarations, specificity]
        end
      end

      xpath_styles.sort_by(&:last)
    end

    # Runs each XPath query and appends the rule's declarations to every
    # element it selects.
    #
    # @param xpath_styles [Array<Array>] output of {#gather_xpath_styles}
    # @return [Hash] element => accumulated declarations
    def associate_xpath_styles_with_elements(xpath_styles)
      xpath_styles.each_with_object({}) do |(xpath, declarations, _), element_styles|
        REXML::XPath.match(root, xpath).each do |element|
          (element_styles[element] ||= []).concat declarations
        end
      end
    end

    # Renders a string as an XPath 1.0 string literal.
    #
    # XPath 1.0 literals have no escape syntax, so the delimiter is chosen to
    # avoid the quote character present in the value; a value containing both
    # kinds of quote is assembled with concat().
    #
    # @param value [String, nil]
    # @return [String, nil] the XPath expression, or nil when value is nil
    def xpath_quote(value)
      return unless value

      return %("#{value}") unless value.include?('"')
      return "'#{value}'" unless value.include?("'")

      # Both quote kinds present: split on the double quotes and glue the
      # pieces back together with a single-quoted '"' literal.
      pieces = value.split('"', -1).map { |piece| %("#{piece}") }
      "concat(#{pieces.join(%(,'"',))})"
    end

    # Translates a parsed CSS selector into an XPath expression.
    #
    # @param selector [Array<Hash>] selector steps; the keys read here are
    #   :combinator, :name, :pseudo_class, :class, :id and :attribute
    # @return [String] XPath expression selecting the same elements
    def css_selector_to_xpath(selector)
      selector.map { |element| element_to_xpath(element) }.join
    end

    # Builds the XPath step (axis, node test and predicates) for one selector step.
    def element_to_xpath(element)
      pseudo_classes = Set.new(element[:pseudo_class])
      # An adjacent sibling is the first node on the following-sibling axis.
      pseudo_classes << 'first-child' if element[:combinator] == :adjacent

      positions = pseudo_classes.map { |pseudo_class| child_position(pseudo_class) }.compact

      [
        combinator_axis(element[:combinator]),
        node_test(element[:name], positions),
        (element[:class] || []).map { |name| "[contains(concat(' ',@class,' '), ' #{name} ')]" }.join,
        (element[:id] || []).map { |name| "[@id='#{name}']" }.join,
        (element[:attribute] || []).map { |key, operator, value| attribute_predicate(key, operator, value) }.join
      ].join
    end

    # Maps a CSS combinator to the XPath axis prefix; anything unrecognised
    # (including no combinator) is treated as the descendant combinator.
    def combinator_axis(combinator)
      case combinator
      when :child then '/'
      when :adjacent, :siblings then '/following-sibling::'
      else '//'
      end
    end

    # Maps a positional pseudo-class to an XPath position; nil for any other
    # pseudo-class.
    def child_position(pseudo_class)
      case pseudo_class
      when 'first-child' then '1'
      when 'last-child' then 'last()'
      when /\Anth-child\((\d+)\)\z/ then Regexp.last_match(1)
      end
    end

    # Builds the node test for a step.
    #
    # Without positions the element name (or the wildcard) is used directly.
    # With positions the position must be evaluated against all siblings, so
    # the test is "*[position]" and the name becomes a trailing predicate.
    def node_test(name, positions)
      return name || '*' if positions.empty?

      logic = if positions.length == 1
                positions.first
              else
                positions.map { |position| "position()=#{position}" }.join(' and ')
              end

      # The universal selector places no constraint on the element name.
      name_predicate = name && name != '*' ? "[name()=#{xpath_quote name}]" : ''
      "*[#{logic}]#{name_predicate}"
    end

    # Builds the XPath predicate for one CSS attribute selector; unknown
    # operators contribute nothing.
    def attribute_predicate(key, operator, value)
      case operator
      when nil
        "[@#{key}]"
      when '='
        "[@#{key}=#{xpath_quote value}]"
      when '^='
        "[starts-with(@#{key}, #{xpath_quote value})]"
      when '$='
        # XPath 1.0 has no ends-with(); compare the trailing substring instead.
        start = "string-length(@#{key}) - #{value.length - 1}"
        "[substring(@#{key}, #{start}) = #{xpath_quote value}]"
      when '*='
        "[contains(@#{key}, #{xpath_quote value})]"
      when '~='
        "[contains(concat(' ',@#{key},' '), #{xpath_quote " #{value} "})]"
      when '|='
        # NOTE(review): this matches the value between any two hyphens, which is
        # broader than CSS's "equals value or starts with value-" - confirm intent.
        "[contains(concat('-',@#{key},'-'), #{xpath_quote "-#{value}-"})]"
      else
        ''
      end
    end

    # Computes the CSS specificity of a parsed selector.
    #
    # @param selector [Array<Hash>] selector steps
    # @return [Array<Integer>] +[id count, class + attribute + pseudo-class
    #   count, element name count]+; the universal selector '*' is not counted
    def calculate_specificity(selector)
      selector.reduce([0, 0, 0]) do |(a, b, c), element|
        [
          a + (element[:id] || []).length,
          b + (element[:class] || []).length + (element[:attribute] || []).length + (element[:pseudo_class] || []).length,
          c + (element[:name] && element[:name] != '*' ? 1 : 0)
        ]
      end
    end
  end
end