module Steep
  module Interface
    # A function type: a parameter list (`params`) and a `return_type`.
    # `location` is carried along but is ignored by `==` and `hash`.
    class Function
      # Parameter list of a function type: positional parameters (a linked
      # list, or nil when there are none) plus keyword parameters.
      class Params
        # Type-construction helpers shared by PositionalParams (`extend`) and
        # KeywordParams (`include`).
        module Utils
          # Builds a union of +types+. With `null: true` the nil type is added,
          # which is how a parameter that may be omitted is widened.
          def union(*types, null: false)
            types << AST::Builtin.nil_type if null
            AST::Types::Union.build(types: types)
          end

          # Builds an intersection of +types+.
          def intersection(*types)
            AST::Types::Intersection.build(types: types)
          end
        end

        # Positional parameters as an immutable singly-linked list: `head` is a
        # Required, Optional or Rest parameter and `tail` is the next node
        # (or nil at the end of the list).
        class PositionalParams
          # Shared behavior of the three positional parameter kinds; each kind
          # wraps exactly one type.
          class Base
            attr_reader :type

            def initialize(type)
              @type = type
            end

            def ==(other)
              other.is_a?(self.class) && other.type == type
            end

            alias eql? ==

            # Order-sensitive hash consistent with `eql?`.
            def hash
              [self.class, type].hash
            end

            # Applies substitution +s+ to the wrapped type; returns self when
            # the type does not change.
            def subst(s)
              ty = type.subst(s)
              if ty == type
                self
              else
                _ = self.class.new(ty)
              end
            end

            def var_type
              type
            end

            # Returns a parameter of the same kind wrapping the block's result,
            # or an Enumerator when no block is given.
            def map_type(&block)
              if block_given?
                _ = self.class.new(yield type)
              else
                enum_for(:map_type)
              end
            end
          end

          class Required < Base; end
          class Optional < Base; end
          class Rest < Base; end

          attr_reader :head
          attr_reader :tail

          def initialize(head:, tail:)
            @head = head
            @tail = tail
          end

          # Prepends a required parameter of +type+ to +tail+.
          def self.required(type, tail = nil)
            PositionalParams.new(head: Required.new(type), tail: tail)
          end

          # Prepends an optional parameter of +type+ to +tail+.
          def self.optional(type, tail = nil)
            PositionalParams.new(head: Optional.new(type), tail: tail)
          end

          # Prepends a rest parameter of +type+ to +tail+.
          def self.rest(type, tail = nil)
            PositionalParams.new(head: Rest.new(type), tail: tail)
          end

          # Allows `head, tail = params` destructuring.
          def to_ary
            [head, tail]
          end

          # Maps every parameter node with the block; returns self when no
          # node changed, so unchanged lists keep their identity.
          def map(&block)
            hd = yield(head)
            tl = tail&.map(&block)

            if head == hd && tail == tl
              self
            else
              PositionalParams.new(head: hd, tail: tl)
            end
          end

          # Maps the type of every parameter; Enumerator without a block.
          def map_type(&block)
            if block
              map {|param| param.map_type(&block) }
            else
              enum_for :map_type
            end
          end

          def subst(s)
            map_type do |type|
              ty = type.subst(s)
              if ty == type
                type
              else
                ty
              end
            end
          end

          def ==(other)
            other.is_a?(PositionalParams) &&
              other.head == head &&
              other.tail == tail
          end

          alias eql? ==

          # Order-sensitive hash consistent with `eql?`. (XOR-combining made
          # adjacent identical parameters cancel out, e.g. every `(A, A)`
          # hashed to `nil.hash`.)
          def hash
            [self.class, head, tail].hash
          end

          # Yields each parameter from head to the end of the list.
          def each(&block)
            if block
              yield head
              tail&.each(&block)
            else
              enum_for(:each)
            end
          end

          # Yields the type of each parameter.
          def each_type
            if block_given?
              each do |param|
                yield param.type
              end
            else
              enum_for :each_type
            end
          end

          # Number of parameters in the list.
          def size
            1 + (tail&.size || 0)
          end

          # Builds the list `required..., optional..., *rest`. Returns nil when
          # all three inputs are empty.
          def self.build(required:, optional:, rest:)
            params = rest ? self.rest(rest) : nil
            params = optional.reverse_each.inject(params) {|acc, type| self.optional(type, acc) }
            params = required.reverse_each.inject(params) {|acc, type| self.required(type, acc) }

            params
          end

          extend Utils

          # Calculates xs + ys (merging two overloads into one parameter list).
          # Never fails. A parameter present in only one side becomes optional;
          # `null: true` adds nil where one side may omit the argument.
          def self.merge_for_overload(xs, ys)
            x = xs&.head
            y = ys&.head

            case
            when x.is_a?(Required) && y.is_a?(Required)
              required(
                union(x.type, y.type),
                merge_for_overload(xs.tail, ys.tail)
              )
            when x.is_a?(Required) && y.is_a?(Optional)
              optional(
                union(x.type, y.type, null: true),
                merge_for_overload(xs.tail, ys.tail)
              )
            when x.is_a?(Required) && y.is_a?(Rest)
              optional(
                union(x.type, y.type, null: true),
                merge_for_overload(xs.tail, ys)
              )
            when x.is_a?(Required) && !y
              optional(
                union(x.type, null: true),
                merge_for_overload(xs.tail, nil)
              )
            when x.is_a?(Optional) && y.is_a?(Required)
              optional(
                union(x.type, y.type, null: true),
                merge_for_overload(xs.tail, ys.tail)
              )
            when x.is_a?(Optional) && y.is_a?(Optional)
              optional(
                union(x.type, y.type),
                merge_for_overload(xs.tail, ys.tail)
              )
            when x.is_a?(Optional) && y.is_a?(Rest)
              optional(
                union(x.type, y.type),
                merge_for_overload(xs.tail, ys)
              )
            when x.is_a?(Optional) && !y
              optional(
                x.type,
                merge_for_overload(xs.tail, nil)
              ) # == xs
            when x.is_a?(Rest) && y.is_a?(Required)
              optional(
                union(x.type, y.type, null: true),
                merge_for_overload(xs, ys.tail)
              )
            when x.is_a?(Rest) && y.is_a?(Optional)
              optional(
                union(x.type, y.type),
                merge_for_overload(xs, ys.tail)
              )
            when x.is_a?(Rest) && y.is_a?(Rest)
              rest(union(x.type, y.type))
            when x.is_a?(Rest) && !y
              xs
            when !x && y.is_a?(Required)
              optional(
                union(y.type, null: true),
                merge_for_overload(nil, ys.tail)
              )
            when !x && y.is_a?(Optional)
              optional(
                y.type,
                merge_for_overload(nil, ys.tail)
              ) # == ys
            when !x && y.is_a?(Rest)
              ys
            when !x && !y
              nil
            end
          end

          # Calculates xs | ys: a list accepting any arguments accepted by
          # either side. A rest parameter on either side must survive into the
          # result, so cases pairing a single parameter with a Rest keep
          # recursing with the Rest list itself rather than its tail.
          def self.merge_for_union(xs, ys)
            x = xs&.head
            y = ys&.head

            case
            when x.is_a?(Required) && y.is_a?(Required)
              required(
                union(x.type, y.type),
                merge_for_union(xs.tail, ys.tail)
              )
            when x.is_a?(Required) && !y
              optional(
                x.type,
                merge_for_union(xs.tail, nil)
              )
            when x.is_a?(Required) && y.is_a?(Optional)
              optional(
                union(x.type, y.type),
                merge_for_union(xs.tail, ys.tail)
              )
            when x.is_a?(Required) && y.is_a?(Rest)
              optional(
                union(x.type, y.type),
                merge_for_union(xs.tail, ys)
              )
            when !x && y.is_a?(Required)
              optional(
                y.type,
                merge_for_union(nil, ys.tail)
              )
            when !x && !y
              nil
            when !x && y.is_a?(Optional)
              PositionalParams.new(head: y, tail: merge_for_union(nil, ys.tail))
            when !x && y.is_a?(Rest)
              ys
            when x.is_a?(Optional) && y.is_a?(Required)
              optional(
                union(x.type, y.type),
                merge_for_union(xs.tail, ys.tail)
              )
            when x.is_a?(Optional) && !y
              PositionalParams.new(head: x, tail: merge_for_union(xs.tail, nil)) # == xs
            when x.is_a?(Optional) && y.is_a?(Optional)
              optional(
                union(x.type, y.type),
                merge_for_union(xs.tail, ys.tail)
              )
            when x.is_a?(Optional) && y.is_a?(Rest)
              # Recurse with `ys` (not `ys.tail`) so the rest parameter is
              # kept; otherwise `(?A) | (*B)` would lose `*B` and reject
              # calls with two or more arguments.
              optional(
                union(x.type, y.type),
                merge_for_union(xs.tail, ys)
              )
            when x.is_a?(Rest) && y.is_a?(Required)
              optional(
                union(x.type, y.type),
                merge_for_union(xs, ys.tail)
              )
            when x.is_a?(Rest) && !y
              xs
            when x.is_a?(Rest) && y.is_a?(Optional)
              optional(
                union(x.type, y.type),
                merge_for_union(xs, ys.tail)
              )
            when x.is_a?(Rest) && y.is_a?(Rest)
              rest(
                union(x.type, y.type)
              )
            end
          end

          # Calculates xs & ys: a list accepting only arguments accepted by
          # both sides.
          # Raises (RuntimeError via bare `raise`) when one side requires a
          # parameter the other cannot accept at all.
          #
          def self.merge_for_intersection(xs, ys)
            x = xs&.head
            y = ys&.head

            case
            when x.is_a?(Required) && y.is_a?(Required)
              required(
                intersection(x.type, y.type),
                merge_for_intersection(xs.tail, ys.tail)
              )
            when x.is_a?(Required) && !y
              raise
            when x.is_a?(Required) && y.is_a?(Optional)
              required(
                intersection(x.type, y.type),
                merge_for_intersection(xs.tail, ys.tail)
              )
            when x.is_a?(Required) && y.is_a?(Rest)
              required(
                intersection(x.type, y.type),
                merge_for_intersection(xs.tail, ys)
              )
            when !x && y.is_a?(Required)
              raise
            when !x && !y
              nil
            when !x && y.is_a?(Optional)
              nil
            when !x && y.is_a?(Rest)
              nil
            when x.is_a?(Optional) && y.is_a?(Required)
              required(
                intersection(x.type, y.type),
                merge_for_intersection(xs.tail, ys.tail)
              )
            when x.is_a?(Optional) && !y
              nil
            when x.is_a?(Optional) && y.is_a?(Optional)
              optional(
                intersection(x.type, y.type),
                merge_for_intersection(xs.tail, ys.tail)
              )
            when x.is_a?(Optional) && y.is_a?(Rest)
              optional(
                intersection(x.type, y.type),
                merge_for_intersection(xs.tail, ys)
              )
            when x.is_a?(Rest) && y.is_a?(Required)
              required(
                intersection(x.type, y.type),
                merge_for_intersection(xs, ys.tail)
              )
            when x.is_a?(Rest) && !y
              nil
            when x.is_a?(Rest) && y.is_a?(Optional)
              optional(
                intersection(x.type, y.type),
                merge_for_intersection(xs, ys.tail)
              )
            when x.is_a?(Rest) && y.is_a?(Rest)
              rest(intersection(x.type, y.type))
            end
          end
        end

        # Keyword parameters: required and optional keywords (Symbol => type)
        # plus an optional `**rest` type.
        class KeywordParams
          attr_reader :requireds
          attr_reader :optionals
          attr_reader :rest

          def initialize(requireds: {}, optionals: {}, rest: nil)
            @requireds = requireds
            @optionals = optionals
            @rest = rest
          end

          def ==(other)
            other.is_a?(KeywordParams) &&
              other.requireds == requireds &&
              other.optionals == optionals &&
              other.rest == rest
          end

          alias eql? ==

          # Order-sensitive hash consistent with `eql?`. (XOR-combining was
          # symmetric in requireds/optionals, so `{a: T}` required and
          # `{a: T}` optional collided, and two empty hashes cancelled out.)
          def hash
            [self.class, requireds, optionals, rest].hash
          end

          # Returns a copy with the given components replaced.
          def update(requireds: self.requireds, optionals: self.optionals, rest: self.rest)
            KeywordParams.new(
              requireds: requireds,
              optionals: optionals,
              rest: rest
            )
          end

          def empty?
            requireds.empty? && optionals.empty? && rest.nil?
          end

          # Yields `[name, type]` pairs: requireds, then optionals, then
          # `[nil, rest]` when a rest keyword type exists.
          def each(&block)
            if block
              requireds.each(&block)
              optionals.each(&block)
              if rest
                yield [nil, rest]
              end
            else
              enum_for :each
            end
          end

          # Yields every keyword type, including the rest type.
          def each_type
            if block_given?
              each do |_, type|
                yield type
              end
            else
              enum_for :each_type
            end
          end

          # Maps every keyword type; returns self when nothing changed.
          def map_type(&block)
            if block
              rs = requireds.transform_values(&block)
              os = optionals.transform_values(&block)
              r = rest&.yield_self(&block)

              if requireds == rs && optionals == os && rest == r
                self
              else
                update(requireds: rs, optionals: os, rest: r)
              end
            else
              enum_for(:map_type)
            end
          end

          def subst(s)
            map_type do |type|
              ty = type.subst(s)
              if ty == type
                type
              else
                ty
              end
            end
          end

          # Number of keyword entries, counting `**rest` as one.
          def size
            requireds.size + optionals.size + (rest ? 1 : 0)
          end

          # Set of explicitly named keywords (rest not included).
          def keywords
            Set[] + requireds.keys + optionals.keys
          end

          include Utils

          # For overloading: merges two keyword lists into one accepting the
          # calls of both. Keywords not required by both sides become optional.
          def +(other)
            requireds = {} #: Hash[Symbol, AST::Types::t]
            optionals = {} #: Hash[Symbol, AST::Types::t]

            all_keys = Set[] + self.requireds.keys + self.optionals.keys + other.requireds.keys + other.optionals.keys
            all_keys.each do |key|
              case
              when t = self.requireds[key]
                case
                when s = other.requireds[key]
                  requireds[key] = union(t, s)
                when s = other.optionals[key]
                  optionals[key] = union(t, s, null: true)
                when s = other.rest
                  optionals[key] = union(t, s, null: true)
                else
                  optionals[key] = union(t, null: true)
                end
              when t = self.optionals[key]
                case
                when s = other.requireds[key]
                  optionals[key] = union(t, s, null: true)
                when s = other.optionals[key]
                  optionals[key] = union(t, s)
                when s = other.rest
                  optionals[key] = union(t, s)
                else
                  optionals[key] = t
                end
              when t = self.rest
                case
                when s = other.requireds[key]
                  optionals[key] = union(t, s, null: true)
                when s = other.optionals[key]
                  optionals[key] = union(t, s)
                when s = other.rest
                  # cannot happen
                else
                  # nop
                end
              else
                case
                when s = other.requireds[key]
                  optionals[key] = union(s, null: true)
                when s = other.optionals[key]
                  optionals[key] = s
                when s = other.rest
                  # nop
                else
                  # cannot happen
                end
              end
            end

            if self.rest && other.rest
              rest = union(self.rest, other.rest)
            else
              rest = self.rest || other.rest
            end

            KeywordParams.new(requireds: requireds, optionals: optionals, rest: rest)
          end

          # For union: accepts keyword calls accepted by either side.
          def |(other)
            requireds = {} #: Hash[Symbol, AST::Types::t]
            optionals = {} #: Hash[Symbol, AST::Types::t]

            all_keys = Set[] + self.requireds.keys + self.optionals.keys + other.requireds.keys + other.optionals.keys
            all_keys.each do |key|
              case
              when t = self.requireds[key]
                case
                when s = other.requireds[key]
                  requireds[key] = union(t, s)
                when s = other.optionals[key]
                  optionals[key] = union(t, s)
                when s = other.rest
                  optionals[key] = union(t, s)
                else
                  optionals[key] = t
                end
              when t = self.optionals[key]
                case
                when s = other.requireds[key]
                  optionals[key] = union(t, s)
                when s = other.optionals[key]
                  optionals[key] = union(t, s)
                when s = other.rest
                  optionals[key] = union(t, s)
                else
                  optionals[key] = t
                end
              when t = self.rest
                case
                when s = other.requireds[key]
                  optionals[key] = union(t, s)
                when s = other.optionals[key]
                  optionals[key] = union(t, s)
                when s = other.rest
                  # cannot happen
                else
                  # nop
                end
              else
                case
                when s = other.requireds[key]
                  optionals[key] = s
                when s = other.optionals[key]
                  optionals[key] = s
                when s = other.rest
                  # nop
                else
                  # cannot happen
                end
              end
            end

            rest =
              if self.rest && other.rest
                union(self.rest, other.rest)
              else
                self.rest || other.rest
              end

            KeywordParams.new(requireds: requireds, optionals: optionals, rest: rest)
          end

          # For intersection: accepts keyword calls accepted by both sides.
          # Returns nil when a keyword required by one side cannot be accepted
          # by the other.
          def &(other)
            requireds = {} #: Hash[Symbol, AST::Types::t]
            optionals = {} #: Hash[Symbol, AST::Types::t]

            all_keys = Set[] + self.requireds.keys + self.optionals.keys + other.requireds.keys + other.optionals.keys
            all_keys.each do |key|
              case
              when t = self.requireds[key]
                case
                when s = other.requireds[key]
                  requireds[key] = intersection(t, s)
                when s = other.optionals[key]
                  requireds[key] = intersection(t, s)
                when s = other.rest
                  requireds[key] = intersection(t, s)
                else
                  return nil
                end
              when t = self.optionals[key]
                case
                when s = other.requireds[key]
                  requireds[key] = intersection(t, s)
                when s = other.optionals[key]
                  optionals[key] = intersection(t, s)
                when s = other.rest
                  optionals[key] = intersection(t, s)
                else
                  # nop
                end
              when t = self.rest
                case
                when s = other.requireds[key]
                  requireds[key] = intersection(t, s)
                when s = other.optionals[key]
                  optionals[key] = intersection(t, s)
                when s = other.rest
                  # cannot happen
                else
                  # nop
                end
              else
                case
                when s = other.requireds[key]
                  return nil
                when s = other.optionals[key]
                  # nop
                when s = other.rest
                  # nop
                else
                  # cannot happen
                end
              end
            end

            rest =
              if self.rest && other.rest
                intersection(self.rest, other.rest)
              else
                nil
              end

            KeywordParams.new(requireds: requireds, optionals: optionals, rest: rest)
          end
        end

        # Types of the leading Required positional parameters.
        def required
          array = [] #: Array[AST::Types::t]

          positional_params&.each do |param|
            case param
            when PositionalParams::Required
              array << param.type
            else
              break
            end
          end

          array
        end

        # Types of the Optional positional parameters that follow the leading
        # Required ones (stops at the first Rest).
        def optional
          array = [] #: Array[AST::Types::t]

          positional_params&.each do |param|
            case param
            when PositionalParams::Required
              # skip
            when PositionalParams::Optional
              array << param.type
            else
              break
            end
          end

          array
        end

        # Type of the positional rest parameter, or nil when there is none.
        def rest
          positional_params&.each do |param|
            case param
            when PositionalParams::Required, PositionalParams::Optional
              # skip
            when PositionalParams::Rest
              return param.type
            end
          end
        end

        attr_reader :positional_params
        attr_reader :keyword_params

        # Convenience constructor from flat arrays/hashes of types.
        def self.build(required: [], optional: [], rest: nil, required_keywords: {}, optional_keywords: {}, rest_keywords: nil)
          positional_params = PositionalParams.build(required: required, optional: optional, rest: rest)
          keyword_params = KeywordParams.new(requireds: required_keywords, optionals: optional_keywords, rest: rest_keywords)
          new(positional_params: positional_params, keyword_params: keyword_params)
        end

        def initialize(positional_params:, keyword_params:)
          @positional_params = positional_params
          @keyword_params = keyword_params
        end

        # Returns a copy with the given components replaced.
        def update(positional_params: self.positional_params, keyword_params: self.keyword_params)
          self.class.new(positional_params: positional_params, keyword_params: keyword_params)
        end

        def first_param
          positional_params&.head
        end

        # Returns a copy with +param+ prepended to the positional parameters.
        def with_first_param(param)
          update(
            positional_params: PositionalParams.new(
              head: param,
              tail: positional_params
            )
          )
        end

        def has_positional?
          positional_params ? true : false
        end

        # Params accepting no arguments.
        def self.empty
          self.new(positional_params: nil, keyword_params: KeywordParams.new)
        end

        def ==(other)
          other.is_a?(self.class) &&
            other.positional_params == positional_params &&
            other.keyword_params == keyword_params
        end

        alias eql? ==

        # Order-sensitive hash consistent with `eql?`.
        def hash
          [self.class, positional_params, keyword_params].hash
        end

        # `[kind, type]` pairs for Required (`:required`) and Optional
        # (`:optional`) positional parameters; Rest parameters are skipped.
        def flat_unnamed_params
          if positional_params
            positional_params.each.with_object([]) do |param, types|
              case param
              when PositionalParams::Required
                types << [:required, param.type]
              when PositionalParams::Optional
                types << [:optional, param.type]
              end
            end
          else
            []
          end
        end

        def flat_keywords
          required_keywords.merge(optional_keywords)
        end

        def required_keywords
          keyword_params.requireds
        end

        def optional_keywords
          keyword_params.optionals
        end

        def rest_keywords
          keyword_params.rest
        end

        def has_keywords?
          !keyword_params.empty?
        end

        def each_positional_param(&block)
          if block_given?
            if positional_params
              positional_params.each(&block)
            end
          else
            enum_for :each_positional_param
          end
        end

        def without_keywords
          update(keyword_params: KeywordParams.new)
        end

        # Drops the first positional parameter, or all keywords when there are
        # no positional parameters. Raises on empty params.
        def drop_first
          case
          when positional_params
            update(positional_params: positional_params.tail)
          when has_keywords?
            without_keywords()
          else
            raise "Cannot drop from empty params"
          end
        end

        # Yields every positional type, then every keyword type.
        def each_type(&block)
          if block
            positional_params&.each_type(&block)
            keyword_params.each_type(&block)
          else
            enum_for :each_type
          end
        end

        # Free type variables of all parameter types (memoized).
        def free_variables()
          @fvs ||= Set.new.tap do |set|
            each_type do |type|
              set.merge(type.free_variables)
            end
          end
        end

        def closed?
          each_type.all?(&:closed?)
        end

        # Applies substitution +s+; returns self when no type is affected.
        def subst(s)
          return self if s.empty?
          return self if empty?
          return self if each_type.none? {|t| s.apply?(t) }

          pp = positional_params
          kp = keyword_params

          if positional_params && positional_params.each_type.any? {|t| s.apply?(t) }
            pp = positional_params.subst(s)
          end
          if keyword_params && keyword_params.each_type.any? {|t| s.apply?(t) }
            kp = keyword_params.subst(s)
          end

          self.class.new(positional_params: pp, keyword_params: kp)
        end

        def size
          (positional_params&.size || 0) + keyword_params.size
        end

        # RBS-like rendering, e.g. `(A, ?B, *C, k: D, ?l: E, **F)`.
        def to_s
          required = self.required.map {|ty| ty.to_s }
          optional = self.optional.map {|ty| "?#{ty}" }
          rest = self.rest ? ["*#{self.rest}"] : []
          required_keywords = keyword_params.requireds.map {|name, type| "#{name}: #{type}" }
          optional_keywords = keyword_params.optionals.map {|name, type| "?#{name}: #{type}"}
          rest_keywords = keyword_params.rest ? ["**#{keyword_params.rest}"] : []
          "(#{(required + optional + rest + required_keywords + optional_keywords + rest_keywords).join(", ")})"
        end

        def map_type(&block)
          self.class.new(
            positional_params: positional_params&.map_type(&block),
            keyword_params: keyword_params.map_type(&block)
          )
        end

        def empty?
          !has_positional? && !has_keywords?
        end

        # Returns true if all arguments are non-required.
        def optional?
          required.empty? && required_keywords.empty?
        end

        # self + params returns a new params for overloading.
        #
        def +(other)
          pp = PositionalParams.merge_for_overload(positional_params, other.positional_params)
          kp = keyword_params + other.keyword_params
          Params.new(positional_params: pp, keyword_params: kp)
        end

        # Returns the intersection between self and other.
        # Returns nil if the intersection cannot be computed.
        #
        #   (self & other) <: self
        #   (self & other) <: other
        #
        # `self & other` accept `arg` if `arg` is acceptable for both of `self` and `other`.
        #
        def &(other)
          pp = PositionalParams.merge_for_intersection(positional_params, other.positional_params) rescue return
          kp = keyword_params & other.keyword_params or return
          Params.new(positional_params: pp, keyword_params: kp)
        end

        # Returns the union between self and other.
        #
        #    self <: (self | other)
        #   other <: (self | other)
        #
        # `self | other` accept `arg` if `self` accepts `arg` or `other` accepts `arg`.
        #
        def |(other)
          pp = PositionalParams.merge_for_union(positional_params, other.positional_params) rescue return
          kp = keyword_params | other.keyword_params or return
          Params.new(positional_params: pp, keyword_params: kp)
        end
      end

      attr_reader :params
      attr_reader :return_type
      attr_reader :location

      def initialize(params:, return_type:, location:)
        @params = params
        @return_type = return_type
        @location = location
      end

      # Structural equality; `location` is deliberately not compared.
      def ==(other)
        other.is_a?(Function) && other.params == params && other.return_type == return_type
      end

      alias eql? ==

      # Order-sensitive hash consistent with `eql?` (location excluded).
      def hash
        [self.class, params, return_type].hash
      end

      # Free type variables of params and return type (memoized).
      def free_variables
        @fvs ||= Set[].tap do |fvs|
          fvs.merge(params.free_variables)
          fvs.merge(return_type.free_variables)
        end
      end

      # Applies substitution +s+; returns self when nothing changes.
      def subst(s)
        return self if s.empty?

        ps = params.subst(s)
        ret = return_type.subst(s)

        if ps == params && ret == return_type
          self
        else
          Function.new(
            params: ps,
            return_type: ret,
            location: location
          )
        end
      end

      # Yields every parameter type, then the return type.
      def each_type(&block)
        if block
          params.each_type(&block)
          yield return_type
        else
          enum_for :each_type
        end
      end

      alias each_child each_type

      def map_type(&block)
        Function.new(
          params: params.map_type(&block),
          return_type: yield(return_type),
          location: location
        )
      end

      # Returns a copy with params and/or return type replaced (same location).
      def with(params: self.params, return_type: self.return_type)
        Function.new(
          params: params,
          return_type: return_type,
          location: location
        )
      end

      def to_s
        "#{params} -> #{return_type}"
      end

      def closed?
        params.closed? && return_type.closed?
      end
    end
  end
end