lib/torch/native/parser.rb in torch-rb-0.3.6 vs lib/torch/native/parser.rb in torch-rb-0.3.7

- old
+ new

@@ -4,18 +4,28 @@ def initialize(functions) @functions = functions @name = @functions.first.ruby_name @min_args = @functions.map { |f| f.args.count { |a| a[:pos] && !a[:has_default] } }.min @max_args = @functions.map { |f| f.args.count { |a| a[:pos] } }.max + @int_array_first = @functions.all? { |c| c.args.first && c.args.first[:type] == "int[]" } end + # TODO improve performance + # possibly move to C++ (see python_arg_parser.cpp) def parse(args, options) candidates = @functions.dup - # remove nil - while args.any? && args.last.nil? - args.pop + # TODO check candidates individually to see if they match + if @int_array_first + int_args = [] + while args.first.is_a?(Integer) + int_args << args.shift + end + if int_args.any? + raise ArgumentError, "argument '#{candidates.first.args.first[:name]}' must be array of ints, but found element of type #{args.first.class.name} at pos #{int_args.size + 1}" if args.any? + args.unshift(int_args) + end end # TODO account for args passed as options here if args.size < @min_args || args.size > @max_args expected = String.new(@min_args.to_s) @@ -23,121 +33,80 @@ return {error: "wrong number of arguments (given #{args.size}, expected #{expected})"} end candidates.reject! { |f| args.size > f.args.size } - # exclude functions missing required options - candidates.reject! do |func| - # TODO make more generic - func.out? && !options[:out] - end - # handle out with multiple # there should only be one match, so safe to modify all - out_func = candidates.find { |f| f.out? } - if out_func && out_func.out_size > 1 && options[:out] - out_args = out_func.args.last(2).map { |a| a[:name] } - out_args.zip(options.delete(:out)).each do |k, v| - options[k.to_sym] = v + if options[:out] + if (out_func = candidates.find { |f| f.out? }) && out_func.out_size > 1 + out_args = out_func.args.last(2).map { |a| a[:name] } + out_args.zip(options.delete(:out)).each do |k, v| + options[k] = v + end + candidates = [out_func] end - candidates = [out_func] + else + # exclude functions missing required options + candidates.reject!(&:out?) end - # exclude functions where options don't match - options.each do |k, v| - candidates.select! do |func| - func.args.any? { |a| a[:name] == k.to_s } - end - # TODO show all bad keywords at once like Ruby? - return {error: "unknown keyword: #{k}"} if candidates.empty? - end + final_values = nil - final_values = {} - # check args - candidates.select! do |func| + while (func = candidates.shift) good = true - values = args.zip(func.args).map { |a, fa| [fa[:name], a] }.to_h - values.merge!(options.map { |k, v| [k.to_s, v] }.to_h) - func.args.each do |fa| - values[fa[:name]] = fa[:default] if values[fa[:name]].nil? + # set values + # TODO use array instead of hash? + values = {} + args.each_with_index do |a, i| + values[func.arg_names[i]] = a end + options.each do |k, v| + values[k] = v + end + func.arg_defaults.each do |k, v| + values[k] = v unless values.key?(k) + end + func.int_array_lengths.each do |k, len| + values[k] = [values[k]] * len if values[k].is_a?(Integer) + end - arg_types = func.args.map { |a| [a[:name], a[:type]] }.to_h + arg_checkers = func.arg_checkers values.each_key do |k| - v = values[k] - t = arg_types[k].split("(").first - - good = - case t - when "Tensor" - v.is_a?(Tensor) - when "Tensor?" - v.nil? || v.is_a?(Tensor) - when "Tensor[]", "Tensor?[]" - v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) } - when "int" - if k == "reduction" - v.is_a?(String) - else - v.is_a?(Integer) - end - when "int?" - v.is_a?(Integer) || v.nil? - when "float?" - v.is_a?(Numeric) || v.nil? - when "bool?" - v == true || v == false || v.nil? - when "float" - v.is_a?(Numeric) - when /int\[.*\]/ - if v.is_a?(Integer) - size = t[4..-2] - raise Error, "Unknown size: #{size}. Please report a bug with #{@name}." unless size =~ /\A\d+\z/ - v = [v] * size.to_i - values[k] = v - end - v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) } - when "Scalar" - v.is_a?(Numeric) - when "Scalar?" - v.is_a?(Numeric) || v.nil? - when "ScalarType" - false # not supported yet - when "ScalarType?" - v.nil? - when "bool" - v == true || v == false - when "str" - v.is_a?(String) - else - raise Error, "Unknown argument type: #{arg_types[k]}. Please report a bug with #{@name}." + unless arg_checkers.key?(k) + good = false + if candidates.empty? + # TODO show all bad keywords at once like Ruby? + return {error: "unknown keyword: #{k}"} end + break + end - if !good - if candidates.size == 1 - k = "input" if k == "self" + unless arg_checkers[k].call(values[k]) + good = false + if candidates.empty? + t = func.arg_types[k] + k = :input if k == :self return {error: "#{@name}(): argument '#{k}' must be #{t}"} end break end end if good final_values = values + break end - - good end - if candidates.size != 1 + unless final_values raise Error, "This should never happen. Please report a bug with #{@name}." end - func = candidates.first - args = func.args.map { |a| final_values[a[:name]] } + args = func.arg_names.map { |k| final_values[k] } args << TensorOptions.new.dtype(6) if func.tensor_options { name: func.cpp_name, args: args }