module Steep
  module TypeInference
    # Type environment used during type inference. It keeps one table per kind
    # of binding (local variables, constants, global variables, instance
    # variables) mapping a name to its type. Constants not found in the local
    # table fall back to `const_env.lookup`.
    class TypeEnv
      attr_reader :subtyping
      attr_reader :lvar_types
      attr_reader :const_types
      attr_reader :gvar_types
      attr_reader :ivar_types
      attr_reader :const_env

      # @param subtyping used for subtype checks in #assign / #with_annotations
      # @param const_env fallback source for constant types (must respond to #lookup)
      def initialize(subtyping:, const_env:)
        @subtyping = subtyping
        @lvar_types = {}
        @const_types = {}
        @gvar_types = {}
        @ivar_types = {}
        @const_env = const_env
      end

      # Hook run by #dup / #clone. Each type table is shallow-copied so the copy
      # can be extended independently of the original; subtyping and const_env
      # are shared between the two.
      def initialize_copy(other)
        @subtyping = other.subtyping
        @lvar_types = other.lvar_types.dup
        @const_types = other.const_types.dup
        @gvar_types = other.gvar_types.dup
        @ivar_types = other.ivar_types.dup
        @const_env = other.const_env
      end

      # Builds an environment pre-populated from annotations and signatures.
      # Annotated local-variable, instance-variable and constant types are made
      # absolute relative to `const_env.current_namespace`; global types from
      # `signatures.globals` are made absolute with no current namespace.
      def self.build(annotations:, signatures:, subtyping:, const_env:)
        namespace = const_env.current_namespace

        new(subtyping: subtyping, const_env: const_env).tap do |env|
          annotations.var_types.each do |name, annot|
            env.set(lvar: name, type: subtyping.builder.absolute_type(annot.type, current: namespace))
          end

          annotations.ivar_types.each do |name, type|
            env.set(ivar: name, type: subtyping.builder.absolute_type(type, current: namespace))
          end

          annotations.const_types.each do |name, type|
            env.set(const: name, type: subtyping.builder.absolute_type(type, current: namespace))
          end

          signatures.globals.each do |name, annot|
            env.set(gvar: name, type: subtyping.builder.absolute_type(annot.type, current: nil))
          end
        end
      end

      # Returns a copy of this environment with the given annotated types
      # installed. When an annotated name already has a type (for constants,
      # either in const_types or via const_env.lookup), the annotated type is
      # checked to be a subtype of the existing one; the block receives
      # `(name, relation, result)` for a failing check (see #assert_annotation).
      # The receiver itself is not modified.
      def with_annotations(lvar_types: {}, ivar_types: {}, const_types: {}, gvar_types: {}, &block)
        dup.tap do |env|
          # Inside this block the bare names lvar_types/ivar_types/gvar_types
          # refer to the keyword arguments, not the attribute readers.
          merge!(original_env: env.lvar_types, override_env: lvar_types, &block)
          merge!(original_env: env.ivar_types, override_env: ivar_types, &block)
          merge!(original_env: env.gvar_types, override_env: gvar_types, &block)

          const_types.each do |name, annotated_type|
            original_type = self.const_types[name] || const_env.lookup(name)
            if original_type
              assert_annotation name, original_type: original_type, annotated_type: annotated_type, &block
            end
            env.const_types[name] = annotated_type
          end
        end
      end

      # Joins the local-variable bindings of the branch environments `envs`
      # into this environment (mutates self). Locals already bound in self keep
      # their type. Every other local gets the union of its types across the
      # branches; if it is not bound in every branch, `::NilClass` is added to
      # that union (presumably because the variable may be unassigned on some
      # path).
      def join!(envs)
        lvars = {}
        # nil when envs is empty, but then lvars stays empty and is never consulted.
        common_vars = envs.map { |env| Set.new(env.lvar_types.keys) }.reduce(:&)

        envs.each do |env|
          env.lvar_types.each do |name, type|
            next if lvar_types.key?(name)

            (lvars[name] ||= []) << type
          end
        end

        # Every name in lvars is absent from self (filtered above), so only
        # membership in all branches decides whether nil must be added.
        lvars.each do |name, types|
          if common_vars.member?(name)
            set(lvar: name, type: AST::Types::Union.build(types: types))
          else
            set(lvar: name, type: AST::Types::Union.build(types: types + [AST::Types::Name.new_instance(name: "::NilClass")]))
          end
        end
      end

      # Looks up the type of exactly one binding (pass one keyword).
      # - lvar: returns the stored type; otherwise yields, stores and returns the
      #   block's result, or Any when the block returns nil.
      # - const: returns the stored type, else const_env.lookup's result; if
      #   both are missing, yields and returns Any (nothing is stored).
      # - ivar / gvar: returns the stored type, else yields and returns Any.
      #
      # @type method get: (const: ModuleName) { () -> void } -> AST::Type
      #                 | (gvar: Symbol) { () -> void } -> AST::Type
      #                 | (ivar: Symbol) { () -> void } -> AST::Type
      #                 | (lvar: Symbol) { () -> AST::Type | nil } -> AST::Type
      def get(lvar: nil, const: nil, gvar: nil, ivar: nil)
        if lvar
          name = lvar_name(lvar)
          if lvar_types.key?(name)
            lvar_types[name]
          else
            lvar_types[name] = yield || AST::Types::Any.new
          end
        elsif const
          if const_types.key?(const)
            const_types[const]
          else
            type = const_env.lookup(const)
            if type
              type
            else
              yield
              AST::Types::Any.new
            end
          end
        else
          lookup_dictionary(ivar: ivar, gvar: gvar) do |var_name, dictionary|
            if dictionary.key?(var_name)
              dictionary[var_name]
            else
              yield
              AST::Types::Any.new
            end
          end
        end
      end

      # Unconditionally records `type` for exactly one binding (lvar, const,
      # gvar or ivar) and returns `type`. No subtype check is performed.
      def set(lvar: nil, const: nil, gvar: nil, ivar: nil, type:)
        if lvar
          lvar_types[lvar_name(lvar)] = type
        elsif const
          const_types[const] = type
        else
          lookup_dictionary(ivar: ivar, gvar: gvar) do |var_name, dictionary|
            dictionary[var_name] = type
          end
        end
      end

      # Records an assignment of a value of type `type` to one binding.
      # - If the binding already has a type, checks `type` is a subtype of it
      #   (see #assert_assign) and returns the existing type.
      # - An unbound lvar is simply bound to `type`, which is returned.
      # - An unbound const / ivar / gvar yields nil to the block and returns Any.
      #
      # @type method assign: (const: ModuleName, type: AST::Type) { (Subtyping::Result::Failure | nil) -> void } -> AST::Type
      #                    | (gvar: Symbol, type: AST::Type) { (Subtyping::Result::Failure | nil) -> void } -> AST::Type
      #                    | (ivar: Symbol, type: AST::Type) { (Subtyping::Result::Failure | nil) -> void } -> AST::Type
      #                    | (lvar: Symbol | LabeledName, type: AST::Type) { (Subtyping::Result::Failure) -> void } -> AST::Type
      def assign(lvar: nil, const: nil, gvar: nil, ivar: nil, type:, &block)
        if lvar
          name = lvar_name(lvar)
          var_type = lvar_types[name]
          if var_type
            assert_assign(var_type: var_type, lhs_type: type, &block)
          else
            lvar_types[name] = type
          end
        elsif const
          const_type = const_types[const] || const_env.lookup(const)
          if const_type
            assert_assign(var_type: const_type, lhs_type: type, &block)
          else
            yield nil
            AST::Types::Any.new
          end
        else
          lookup_dictionary(ivar: ivar, gvar: gvar) do |var_name, dictionary|
            if dictionary.key?(var_name)
              assert_assign(var_type: dictionary[var_name], lhs_type: type, &block)
            else
              yield nil
              AST::Types::Any.new
            end
          end
        end
      end

      # Yields `(name, table)` for the ivar or gvar table (ivar wins if both are
      # given) and returns the block's value. Returns nil without yielding when
      # neither keyword is given.
      def lookup_dictionary(ivar:, gvar:)
        if ivar
          yield ivar, ivar_types
        elsif gvar
          yield gvar, gvar_types
        end
      end

      # Normalizes a local-variable key: Symbols are used as-is, LabeledName
      # values are reduced to their #name. Any other value yields nil.
      def lvar_name(lvar)
        case lvar
        when Symbol
          lvar
        when ASTUtils::Labeling::LabeledName
          lvar.name
        end
      end

      # Checks `lhs_type <: var_type`. The block receives the result in the
      # `else` branch of the check (presumably only on failure). Always returns
      # `var_type`.
      def assert_assign(var_type:, lhs_type:)
        relation = Subtyping::Relation.new(sub_type: lhs_type, super_type: var_type)
        constraints = Subtyping::Constraints.new(unknowns: Set.new)

        subtyping.check(relation, constraints: constraints).else do |result|
          yield result
        end

        var_type
      end

      # Merges `override_env` into `original_env` in place. For keys present in
      # both, the stored value is the override type returned by
      # #assert_annotation, after it is checked against the original type.
      def merge!(original_env:, override_env:, &block)
        original_env.merge!(override_env) do |name, original_type, override_type|
          assert_annotation name, annotated_type: override_type, original_type: original_type, &block
        end
      end

      # Checks `annotated_type <: original_type`. The block receives
      # `(name, relation, result)` in the `else` branch of the check (presumably
      # only on failure). Always returns `annotated_type`.
      def assert_annotation(name, annotated_type:, original_type:)
        relation = Subtyping::Relation.new(sub_type: annotated_type, super_type: original_type)
        constraints = Subtyping::Constraints.new(unknowns: Set.new)

        subtyping.check(relation, constraints: constraints).else do |result|
          yield name, relation, result
        end

        annotated_type
      end
    end
  end
end