lib/cel/checker.rb in cel-0.1.2 vs lib/cel/checker.rb in cel-0.2.0
- old
+ new
@@ -29,14 +29,13 @@
def merge(declarations)
Checker.new(@declarations ? @declarations.merge(declarations) : declarations)
end
- # TODO: add protobuf timestamp and duration
- LOGICAL_EXPECTED_TYPES = %i[bool int uint double string bytes].freeze
- ADD_EXPECTED_TYPES = %i[int uint double string bytes list].freeze
- SUB_EXPECTED_TYPES = %i[int uint double].freeze
+ LOGICAL_EXPECTED_TYPES = %i[bool int uint double string bytes timestamp duration].freeze
+ ADD_EXPECTED_TYPES = %i[int uint double string bytes list duration].freeze
+ SUB_EXPECTED_TYPES = %i[int uint double duration].freeze
MULTIDIV_EXPECTED_TYPES = %i[int uint double].freeze
REMAINDER_EXPECTED_TYPES = %i[int uint].freeze
def check_operation(operation)
type = infer_operation_type(operation)
@@ -71,52 +70,78 @@
unsupported_type(operation)
end
else
case op
- when "&&", "||", "==", "!=", "<", "<=", ">=", ">"
- return TYPES[:bool]
+ when "&&", "||", "<", "<=", ">=", ">"
+ return TYPES[:bool] if find_match_all_types(LOGICAL_EXPECTED_TYPES, values) || values.include?(:any)
+ when "!=", "=="
+ return TYPES[:bool] if values.uniq.size == 1 ||
+ values.all? { |v| v == :list } ||
+ values.all? { |v| v == :map } ||
+ values.include?(:any)
when "in"
- return TYPES[:bool] if find_match_all_types(%i[list map], values.last)
+ return TYPES[:bool] if find_match_all_types(%i[list map any], values.last)
when "+"
- if (type = find_match_all_types(ADD_EXPECTED_TYPES, values))
- return type
- end
+ return type if (type = find_match_all_types(ADD_EXPECTED_TYPES, values))
+
+ return TYPES[:timestamp] if %i[timestamp duration].any? { |typ| values.first == typ }
+
+ return values.last if values.first == :any
+
when "-"
- if (type = find_match_all_types(SUB_EXPECTED_TYPES, values))
- return type
+ return type if (type = find_match_all_types(SUB_EXPECTED_TYPES, values))
+
+ case values.first
+ when TYPES[:timestamp]
+ return TYPES[:duration] if values.last == :timestamp
+
+ return TYPES[:timestamp] if values.last == :duration
+
+ return TYPES[:any] if values.last == :any
+
+ when TYPES[:any]
+ return values.last
end
when "*", "/"
- if (type = find_match_all_types(MULTIDIV_EXPECTED_TYPES, values))
- return type
- end
+ return type if (type = find_match_all_types(MULTIDIV_EXPECTED_TYPES, values))
+
+ values.include?(:any)
+ values.find { |typ| typ != :any } || TYPES[:any]
+
when "%"
- if (type = find_match_all_types(REMAINDER_EXPECTED_TYPES, values))
- return type
- end
+ return type if (type = find_match_all_types(REMAINDER_EXPECTED_TYPES, values))
+
+ values.include?(:any)
+ values.find { |typ| typ != :any } || TYPES[:any]
+
else
unsupported_type(operation)
end
end
unsupported_type(operation)
end
+ def infer_variable_type(var)
+ case var
+ when Identifier
+ check_identifier(var)
+ when Invoke
+ check_invoke(var)
+ else
+ var.type
+ end
+ end
+
def check_invoke(funcall, var_type = nil)
var = funcall.var
func = funcall.func
args = funcall.args
return check_standard_func(funcall) unless var
- var_type ||= case var
- when Identifier
- check_identifier(var)
- when Invoke
- check_invoke(var)
- else
- var.type
- end
+ var_type ||= infer_variable_type(var)
case var_type
when MapType
# A field selection expression, e.f, can be applied both to messages and
# to maps. For maps, selection is interpreted as the field being a string key.
@@ -151,10 +176,12 @@
check_arity(funcall, args, 2)
identifier, predicate = args
unsupported_type(funcall) unless identifier.is_a?(Identifier)
+ identifier.type = var_type.element_type
+
element_checker = merge(identifier.to_sym => var_type.element_type)
unsupported_type(funcall) if element_checker.check(predicate) != :bool
TYPES[:bool]
@@ -193,21 +220,42 @@
return TYPES[:bool] if find_match_all_types(%i[string], call(args.first))
else
unsupported_type(funcall)
end
unsupported_operation(funcall)
+ when TYPES[:timestamp]
+ case func
+ when :getDate, :getDayOfMonth, :getDayOfWeek, :getDayOfYear, :getFullYear, :getHours,
+ :getMilliseconds, :getMinutes, :getMonth, :getSeconds
+ check_arity(func, args, 0..1)
+ return TYPES[:int] if args.empty? || (args.size.positive? && args[0] == :string)
+ else
+ unsupported_type(funcall)
+ end
+ unsupported_operation(funcall)
+ when TYPES[:duration]
+ case func
+ when :getMilliseconds, :getMinutes, :getHours, :getSeconds
+ check_arity(func, args, 0)
+ return TYPES[:int]
+ else
+ unsupported_type(funcall)
+ end
+ unsupported_operation(funcall)
else
TYPES[:any]
end
end
CAST_ALLOWED_TYPES = {
- int: %i[uint double string], # TODO: enum, timestamp
+ int: %i[uint double string timestamp], # TODO: enum
uint: %i[int double string],
- string: %i[int uint double bytes], # TODO: timestamp, duration
+ string: %i[int uint double bytes timestamp duration],
double: %i[int uint string],
bytes: %i[string],
+ duration: %i[string],
+ timestamp: %i[string],
}.freeze
def check_standard_func(funcall)
func = funcall.func
args = funcall.args
@@ -221,36 +269,62 @@
unsupported_type(funcall) unless args.first.is_a?(Invoke)
return TYPES[:bool]
when :size
check_arity(func, args, 1)
- return TYPES[:int] if find_match_all_types(%i[string bytes list map], call(args.first))
- when :int, :uint, :string, :double, :bytes # :duration, :timestamp
+
+ arg = call(args.first)
+ return TYPES[:int] if find_match_all_types(%i[string bytes list map], arg)
+ when *CAST_ALLOWED_TYPES.keys
check_arity(func, args, 1)
allowed_types = CAST_ALLOWED_TYPES[func]
- return TYPES[func] if find_match_all_types(allowed_types, call(args.first))
+ arg = call(args.first)
+ return TYPES[func] if find_match_all_types(allowed_types, arg)
when :matches
check_arity(func, args, 2)
- return TYPES[:bool] if find_match_all_types(%i[string], args.map { |arg| call(arg) })
+ return TYPES[:bool] if find_match_all_types(%i[string], args.map(&method(:call)))
when :dyn
check_arity(func, args, 1)
arg_type = call(args.first)
case arg_type
when ListType, MapType
arg_type.element_type = TYPES[:any]
end
return arg_type
else
+ return check_custom_func(@declarations[func], funcall) if @declarations.key?(func)
+
unsupported_type(funcall)
end
unsupported_operation(funcall)
end
+ def check_custom_func(func, funcall)
+ args = funcall.args
+
+ unless func.is_a?(Cel::Function)
+ raise CheckError, "#{func} must respond to #call" unless func.respond_to?(:call)
+
+ func = Cel::Function(&func)
+ end
+
+ unless func.types.empty?
+ unsupported_type(funcall) unless func.types.zip(args.map(&method(:call)))
+ .all? do |expected_type, type|
+ expected_type == :any || expected_type == type
+ end
+
+ return func.type
+ end
+
+ unsupported_operation(funcall)
+ end
+
def check_identifier(identifier)
- return unless identifier.type == :any
+ return identifier.type unless identifier.type == :any
return TYPES[:type] if Cel::PRIMITIVE_TYPES.include?(identifier.to_sym)
id_type = infer_dec_type(identifier.id)
@@ -260,10 +334,14 @@
id_type
end
def check_condition(condition)
+ if_type = call(condition.if)
+
+ raise CheckError, "`#{condition.if}` must evaluate to a bool" unless if_type == :bool
+
then_type = call(condition.then)
else_type = call(condition.else)
return then_type if then_type == else_type
@@ -284,11 +362,11 @@
def convert(typ)
case typ
when Symbol
TYPES[typ] or
- raise CheckError, "#{typ} is not aa valid type"
+ raise CheckError, "#{typ} is not a valid type"
else
typ
end
end
@@ -305,10 +383,10 @@
TYPES[type]
end
def check_arity(func, args, arity)
- return if args.size == arity
+ return if arity === args.size # rubocop:disable Style/CaseEquality
raise CheckError, "`#{func}` invoked with wrong number of arguments (should be #{arity})"
end
def unsupported_type(op)