lib/cel/checker.rb in cel-0.1.2 vs lib/cel/checker.rb in cel-0.2.0

- old
+ new

@@ -29,14 +29,13 @@ def merge(declarations) Checker.new(@declarations ? @declarations.merge(declarations) : declarations) end - # TODO: add protobuf timestamp and duration - LOGICAL_EXPECTED_TYPES = %i[bool int uint double string bytes].freeze - ADD_EXPECTED_TYPES = %i[int uint double string bytes list].freeze - SUB_EXPECTED_TYPES = %i[int uint double].freeze + LOGICAL_EXPECTED_TYPES = %i[bool int uint double string bytes timestamp duration].freeze + ADD_EXPECTED_TYPES = %i[int uint double string bytes list duration].freeze + SUB_EXPECTED_TYPES = %i[int uint double duration].freeze MULTIDIV_EXPECTED_TYPES = %i[int uint double].freeze REMAINDER_EXPECTED_TYPES = %i[int uint].freeze def check_operation(operation) type = infer_operation_type(operation) @@ -71,52 +70,78 @@ unsupported_type(operation) end else case op - when "&&", "||", "==", "!=", "<", "<=", ">=", ">" - return TYPES[:bool] + when "&&", "||", "<", "<=", ">=", ">" + return TYPES[:bool] if find_match_all_types(LOGICAL_EXPECTED_TYPES, values) || values.include?(:any) + when "!=", "==" + return TYPES[:bool] if values.uniq.size == 1 || + values.all? { |v| v == :list } || + values.all? { |v| v == :map } || + values.include?(:any) when "in" - return TYPES[:bool] if find_match_all_types(%i[list map], values.last) + return TYPES[:bool] if find_match_all_types(%i[list map any], values.last) when "+" - if (type = find_match_all_types(ADD_EXPECTED_TYPES, values)) - return type - end + return type if (type = find_match_all_types(ADD_EXPECTED_TYPES, values)) + + return TYPES[:timestamp] if %i[timestamp duration].any? { |typ| values.first == typ } + + return values.last if values.first == :any + when "-" - if (type = find_match_all_types(SUB_EXPECTED_TYPES, values)) - return type + return type if (type = find_match_all_types(SUB_EXPECTED_TYPES, values)) + + case values.first + when TYPES[:timestamp] + return TYPES[:duration] if values.last == :timestamp + + return TYPES[:timestamp] if values.last == :duration + + return TYPES[:any] if values.last == :any + + when TYPES[:any] + return values.last end when "*", "/" - if (type = find_match_all_types(MULTIDIV_EXPECTED_TYPES, values)) - return type - end + return type if (type = find_match_all_types(MULTIDIV_EXPECTED_TYPES, values)) + + values.include?(:any) + values.find { |typ| typ != :any } || TYPES[:any] + when "%" - if (type = find_match_all_types(REMAINDER_EXPECTED_TYPES, values)) - return type - end + return type if (type = find_match_all_types(REMAINDER_EXPECTED_TYPES, values)) + + values.include?(:any) + values.find { |typ| typ != :any } || TYPES[:any] + else unsupported_type(operation) end end unsupported_type(operation) end + def infer_variable_type(var) + case var + when Identifier + check_identifier(var) + when Invoke + check_invoke(var) + else + var.type + end + end + def check_invoke(funcall, var_type = nil) var = funcall.var func = funcall.func args = funcall.args return check_standard_func(funcall) unless var - var_type ||= case var - when Identifier - check_identifier(var) - when Invoke - check_invoke(var) - else - var.type - end + var_type ||= infer_variable_type(var) case var_type when MapType # A field selection expression, e.f, can be applied both to messages and # to maps. For maps, selection is interpreted as the field being a string key. @@ -151,10 +176,12 @@ check_arity(funcall, args, 2) identifier, predicate = args unsupported_type(funcall) unless identifier.is_a?(Identifier) + identifier.type = var_type.element_type + element_checker = merge(identifier.to_sym => var_type.element_type) unsupported_type(funcall) if element_checker.check(predicate) != :bool TYPES[:bool] @@ -193,21 +220,42 @@ return TYPES[:bool] if find_match_all_types(%i[string], call(args.first)) else unsupported_type(funcall) end unsupported_operation(funcall) + when TYPES[:timestamp] + case func + when :getDate, :getDayOfMonth, :getDayOfWeek, :getDayOfYear, :getFullYear, :getHours, + :getMilliseconds, :getMinutes, :getMonth, :getSeconds + check_arity(func, args, 0..1) + return TYPES[:int] if args.empty? || (args.size.positive? && args[0] == :string) + else + unsupported_type(funcall) + end + unsupported_operation(funcall) + when TYPES[:duration] + case func + when :getMilliseconds, :getMinutes, :getHours, :getSeconds + check_arity(func, args, 0) + return TYPES[:int] + else + unsupported_type(funcall) + end + unsupported_operation(funcall) else TYPES[:any] end end CAST_ALLOWED_TYPES = { - int: %i[uint double string], # TODO: enum, timestamp + int: %i[uint double string timestamp], # TODO: enum uint: %i[int double string], - string: %i[int uint double bytes], # TODO: timestamp, duration + string: %i[int uint double bytes timestamp duration], double: %i[int uint string], bytes: %i[string], + duration: %i[string], + timestamp: %i[string], }.freeze def check_standard_func(funcall) func = funcall.func args = funcall.args @@ -221,36 +269,62 @@ unsupported_type(funcall) unless args.first.is_a?(Invoke) return TYPES[:bool] when :size check_arity(func, args, 1) - return TYPES[:int] if find_match_all_types(%i[string bytes list map], call(args.first)) - when :int, :uint, :string, :double, :bytes # :duration, :timestamp + + arg = call(args.first) + return TYPES[:int] if find_match_all_types(%i[string bytes list map], arg) + when *CAST_ALLOWED_TYPES.keys check_arity(func, args, 1) allowed_types = CAST_ALLOWED_TYPES[func] - return TYPES[func] if find_match_all_types(allowed_types, call(args.first)) + arg = call(args.first) + return TYPES[func] if find_match_all_types(allowed_types, arg) when :matches check_arity(func, args, 2) - return TYPES[:bool] if find_match_all_types(%i[string], args.map { |arg| call(arg) }) + return TYPES[:bool] if find_match_all_types(%i[string], args.map(&method(:call))) when :dyn check_arity(func, args, 1) arg_type = call(args.first) case arg_type when ListType, MapType arg_type.element_type = TYPES[:any] end return arg_type else + return check_custom_func(@declarations[func], funcall) if @declarations.key?(func) + unsupported_type(funcall) end unsupported_operation(funcall) end + def check_custom_func(func, funcall) + args = funcall.args + + unless func.is_a?(Cel::Function) + raise CheckError, "#{func} must respond to #call" unless func.respond_to?(:call) + + func = Cel::Function(&func) + end + + unless func.types.empty? + unsupported_type(funcall) unless func.types.zip(args.map(&method(:call))) + .all? do |expected_type, type| + expected_type == :any || expected_type == type + end + + return func.type + end + + unsupported_operation(funcall) + end + def check_identifier(identifier) - return unless identifier.type == :any + return identifier.type unless identifier.type == :any return TYPES[:type] if Cel::PRIMITIVE_TYPES.include?(identifier.to_sym) id_type = infer_dec_type(identifier.id) @@ -260,10 +334,14 @@ id_type end def check_condition(condition) + if_type = call(condition.if) + + raise CheckError, "`#{condition.if}` must evaluate to a bool" unless if_type == :bool + then_type = call(condition.then) else_type = call(condition.else) return then_type if then_type == else_type @@ -284,11 +362,11 @@ def convert(typ) case typ when Symbol TYPES[typ] or - raise CheckError, "#{typ} is not aa valid type" + raise CheckError, "#{typ} is not a valid type" else typ end end @@ -305,10 +383,10 @@ TYPES[type] end def check_arity(func, args, arity) - return if args.size == arity + return if arity === args.size # rubocop:disable Style/CaseEquality raise CheckError, "`#{func}` invoked with wrong number of arguments (should be #{arity})" end def unsupported_type(op)