lib/dry/schema/predicate_inferrer.rb in dry-schema-0.3.0 vs lib/dry/schema/predicate_inferrer.rb in dry-schema-0.4.0
- old
+ new
@@ -10,10 +10,11 @@
# @api private
class PredicateInferrer
extend Dry::Core::Cache
TYPE_TO_PREDICATE = {
+ DateTime => :date_time?,
FalseClass => :false?,
Integer => :int?,
NilClass => :nil?,
String => :str?,
TrueClass => :true?
@@ -25,42 +26,72 @@
# Compiler reduces type AST into a list of predicates
#
# @api private
class Compiler
+ # @!attribute [r] registry
+ # @return [PredicateRegistry]
+ # @api private
+ attr_reader :registry
+
+ # @api private
+ def initialize(registry)
+ @registry = registry
+ end
+
+ # @api private
+ def infer_predicate(type)
+ TYPE_TO_PREDICATE.fetch(type) { :"#{type.name.split('::').last.downcase}?" }
+ end
+
+ # @api private
def visit(node)
meth, rest = node
public_send(:"visit_#{meth}", rest)
end
- def visit_definition(node)
+ # @api private
+ def visit_nominal(node)
type = node[0]
+ predicate = infer_predicate(type)
- TYPE_TO_PREDICATE.fetch(type) {
- :"#{type.name.split('::').last.downcase}?"
- }
+ if registry.key?(predicate)
+ predicate
+ else
+ { type?: type }
+ end
end
- def visit_array(*)
+ # @api private
+ def visit_hash(_)
+ :hash?
+ end
+
+ # @api private
+ def visit_array(_)
:array?
end
+ # @api private
def visit_safe(node)
other, * = node
visit(other)
end
+ # @api private
def visit_constructor(node)
other, * = node
visit(other)
end
+ # @api private
def visit_enum(node)
other, * = node
visit(other)
end
+ # @api private
def visit_sum(node)
left, right = node
predicates = [visit(left), visit(right)]
@@ -69,30 +100,41 @@
else
predicates
end
end
+ # @api private
def visit_constrained(node)
other, * = node
visit(other)
end
end
+ # @!attribute [r] compiler
+ # @return [Compiler]
+ # @api private
+ attr_reader :compiler
+
+ # @api private
+ def initialize(registry)
+ @compiler = Compiler.new(registry)
+ end
+
# Infer predicate identifier from the provided type
#
# @return [Symbol]
#
# @api private
- def self.[](type)
- fetch_or_store(type.hash) {
- predicates = Array(compiler.visit(type.to_ast)).flatten
- Array(REDUCED_TYPES[predicates] || predicates).flatten
- }
- end
+ def [](type)
+ self.class.fetch_or_store(type.hash) do
+ predicates = compiler.visit(type.to_ast)
- # @api private
- def self.compiler
- @compiler ||= Compiler.new
+ if predicates.is_a?(Hash)
+ predicates
+ else
+ Array(REDUCED_TYPES[predicates] || predicates).flatten
+ end
+ end
end
end
end
end