lib/dry/schema/predicate_inferrer.rb in dry-schema-0.3.0 vs lib/dry/schema/predicate_inferrer.rb in dry-schema-0.4.0

- old
+ new

@@ -10,10 +10,11 @@ # @api private class PredicateInferrer extend Dry::Core::Cache TYPE_TO_PREDICATE = { + DateTime => :date_time?, FalseClass => :false?, Integer => :int?, NilClass => :nil?, String => :str?, TrueClass => :true? @@ -25,42 +26,72 @@ # Compiler reduces type AST into a list of predicates # # @api private class Compiler + # @!attribute [r] registry + # @return [PredicateRegistry] + # @api private + attr_reader :registry + + # @api private + def initialize(registry) + @registry = registry + end + + # @api private + def infer_predicate(type) + TYPE_TO_PREDICATE.fetch(type) { :"#{type.name.split('::').last.downcase}?" } + end + + # @api private def visit(node) meth, rest = node public_send(:"visit_#{meth}", rest) end - def visit_definition(node) + # @api private + def visit_nominal(node) type = node[0] + predicate = infer_predicate(type) - TYPE_TO_PREDICATE.fetch(type) { - :"#{type.name.split('::').last.downcase}?" - } + if registry.key?(predicate) + predicate + else + { type?: type } + end end - def visit_array(*) + # @api private + def visit_hash(_) + :hash? + end + + # @api private + def visit_array(_) :array? end + # @api private def visit_safe(node) other, * = node visit(other) end + # @api private def visit_constructor(node) other, * = node visit(other) end + # @api private def visit_enum(node) other, * = node visit(other) end + # @api private def visit_sum(node) left, right = node predicates = [visit(left), visit(right)] @@ -69,30 +100,41 @@ else predicates end end + # @api private def visit_constrained(node) other, * = node visit(other) end end + # @!attribute [r] compiler + # @return [Compiler] + # @api private + attr_reader :compiler + + # @api private + def initialize(registry) + @compiler = Compiler.new(registry) + end + # Infer predicate identifier from the provided type # # @return [Symbol] # # @api private - def self.[](type) - fetch_or_store(type.hash) { - predicates = Array(compiler.visit(type.to_ast)).flatten - Array(REDUCED_TYPES[predicates] || predicates).flatten - } - end + def [](type) + self.class.fetch_or_store(type.hash) do + predicates = compiler.visit(type.to_ast) - # @api private - def self.compiler - @compiler ||= Compiler.new + if predicates.is_a?(Hash) + predicates + else + Array(REDUCED_TYPES[predicates] || predicates).flatten + end + end end end end end