module Steep
  module Subtyping
    # Tracks, per type-variable name, whether that name was seen in positions
    # classified as covariant, contravariant, or both.
    #
    # A name present only in +covariants+ is covariant, one present only in
    # +contravariants+ is contravariant, and one present in both is invariant.
    # A name present in neither set answers +false+ to all three predicates.
    class VariableVariance
      attr_reader :covariants, :contravariants

      # @param covariants [#member?] names seen in covariant positions
      #   (Set when built by .from_method_type)
      # @param contravariants [#member?] names seen in contravariant positions
      def initialize(covariants:, contravariants:)
        @covariants = covariants
        @contravariants = contravariants
      end

      # True when +var+ was recorded as covariant and never as contravariant.
      def covariant?(var)
        return false unless covariants.member?(var)

        !contravariants.member?(var)
      end

      # True when +var+ was recorded as contravariant and never as covariant.
      def contravariant?(var)
        return false unless contravariants.member?(var)

        !covariants.member?(var)
      end

      # True when +var+ was recorded in both sets.
      def invariant?(var)
        return false unless covariants.member?(var)

        contravariants.member?(var)
      end

      # Builds a VariableVariance by walking +method_type+.
      #
      # Recording rules:
      # - The method's parameters and its return type are recorded as covariant.
      # - If +method_type.block+ is non-nil, the block's parameters and its
      #   return type are recorded as contravariant.
      #
      # NOTE(review): classifying method parameters as covariant differs from the
      # textbook function-subtyping rule; presumably this matches how callers
      # consume these sets — confirm before changing.
      #
      # @param method_type [#params, #return_type, #block]
      # @return [VariableVariance]
      def self.from_method_type(method_type)
        sets = { covariants: Set.new, contravariants: Set.new }

        add_params(method_type.params, block: false, **sets)
        add_type(method_type.return_type, variance: :covariant, **sets)

        block_type = method_type.block
        unless block_type.nil?
          add_params(block_type.params, block: true, **sets)
          add_type(block_type.return_type, variance: :contravariant, **sets)
        end

        new(**sets)
      end

      # Records every type yielded by +params.each_type+.
      #
      # Each type is recorded as :contravariant when +block+ is true, and as
      # :covariant otherwise. The +covariants+ and +contravariants+ collections
      # are mutated in place.
      def self.add_params(params, block:, covariants:, contravariants:)
        param_variance = block ? :contravariant : :covariant

        params.each_type do |param_type|
          add_type(param_type, variance: param_variance, covariants: covariants, contravariants: contravariants)
        end
      end

      # Recursively records the type variables found in +type+ at +variance+.
      #
      # Rules by type form:
      # - Var: the name goes into +covariants+ (:covariant), +contravariants+
      #   (:contravariant), or both (:invariant). Any other variance records
      #   nothing.
      # - Union / Intersection: each member inherits the enclosing variance.
      # - Name: every type argument is recorded as :invariant, whatever the
      #   enclosing variance is.
      # - Any other type form is ignored.
      #
      # Both collections are mutated in place.
      def self.add_type(type, variance:, covariants:, contravariants:)
        sets = { covariants: covariants, contravariants: contravariants }

        case type
        when AST::Types::Var
          var_name = type.name
          case variance
          when :covariant then covariants << var_name
          when :contravariant then contravariants << var_name
          when :invariant
            covariants << var_name
            contravariants << var_name
          end
        when AST::Types::Union, AST::Types::Intersection
          type.types.each { |member| add_type(member, variance: variance, **sets) }
        when AST::Types::Name
          # Generic arguments are pinned in both directions.
          type.args.each { |arg| add_type(arg, variance: :invariant, **sets) }
        end
      end
    end
  end
end