lib/torch/native/parser.rb in torch-rb-0.3.6 vs lib/torch/native/parser.rb in torch-rb-0.3.7
- old
+ new
@@ -4,18 +4,28 @@
def initialize(functions)
@functions = functions
@name = @functions.first.ruby_name
@min_args = @functions.map { |f| f.args.count { |a| a[:pos] && !a[:has_default] } }.min
@max_args = @functions.map { |f| f.args.count { |a| a[:pos] } }.max
+ @int_array_first = @functions.all? { |c| c.args.first && c.args.first[:type] == "int[]" }
end
+ # TODO improve performance
+ # possibly move to C++ (see python_arg_parser.cpp)
def parse(args, options)
candidates = @functions.dup
- # remove nil
- while args.any? && args.last.nil?
- args.pop
+ # TODO check candidates individually to see if they match
+ if @int_array_first
+ int_args = []
+ while args.first.is_a?(Integer)
+ int_args << args.shift
+ end
+ if int_args.any?
+ raise ArgumentError, "argument '#{candidates.first.args.first[:name]}' must be array of ints, but found element of type #{args.first.class.name} at pos #{int_args.size + 1}" if args.any?
+ args.unshift(int_args)
+ end
end
# TODO account for args passed as options here
if args.size < @min_args || args.size > @max_args
expected = String.new(@min_args.to_s)
@@ -23,121 +33,80 @@
return {error: "wrong number of arguments (given #{args.size}, expected #{expected})"}
end
candidates.reject! { |f| args.size > f.args.size }
- # exclude functions missing required options
- candidates.reject! do |func|
- # TODO make more generic
- func.out? && !options[:out]
- end
-
# handle out with multiple
# there should only be one match, so safe to modify all
- out_func = candidates.find { |f| f.out? }
- if out_func && out_func.out_size > 1 && options[:out]
- out_args = out_func.args.last(2).map { |a| a[:name] }
- out_args.zip(options.delete(:out)).each do |k, v|
- options[k.to_sym] = v
+ if options[:out]
+ if (out_func = candidates.find { |f| f.out? }) && out_func.out_size > 1
+ out_args = out_func.args.last(2).map { |a| a[:name] }
+ out_args.zip(options.delete(:out)).each do |k, v|
+ options[k] = v
+ end
+ candidates = [out_func]
end
- candidates = [out_func]
+ else
+ # exclude functions missing required options
+ candidates.reject!(&:out?)
end
- # exclude functions where options don't match
- options.each do |k, v|
- candidates.select! do |func|
- func.args.any? { |a| a[:name] == k.to_s }
- end
- # TODO show all bad keywords at once like Ruby?
- return {error: "unknown keyword: #{k}"} if candidates.empty?
- end
+ final_values = nil
- final_values = {}
-
# check args
- candidates.select! do |func|
+ while (func = candidates.shift)
good = true
- values = args.zip(func.args).map { |a, fa| [fa[:name], a] }.to_h
- values.merge!(options.map { |k, v| [k.to_s, v] }.to_h)
- func.args.each do |fa|
- values[fa[:name]] = fa[:default] if values[fa[:name]].nil?
+ # set values
+ # TODO use array instead of hash?
+ values = {}
+ args.each_with_index do |a, i|
+ values[func.arg_names[i]] = a
end
+ options.each do |k, v|
+ values[k] = v
+ end
+ func.arg_defaults.each do |k, v|
+ values[k] = v unless values.key?(k)
+ end
+ func.int_array_lengths.each do |k, len|
+ values[k] = [values[k]] * len if values[k].is_a?(Integer)
+ end
- arg_types = func.args.map { |a| [a[:name], a[:type]] }.to_h
+ arg_checkers = func.arg_checkers
values.each_key do |k|
- v = values[k]
- t = arg_types[k].split("(").first
-
- good =
- case t
- when "Tensor"
- v.is_a?(Tensor)
- when "Tensor?"
- v.nil? || v.is_a?(Tensor)
- when "Tensor[]", "Tensor?[]"
- v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) }
- when "int"
- if k == "reduction"
- v.is_a?(String)
- else
- v.is_a?(Integer)
- end
- when "int?"
- v.is_a?(Integer) || v.nil?
- when "float?"
- v.is_a?(Numeric) || v.nil?
- when "bool?"
- v == true || v == false || v.nil?
- when "float"
- v.is_a?(Numeric)
- when /int\[.*\]/
- if v.is_a?(Integer)
- size = t[4..-2]
- raise Error, "Unknown size: #{size}. Please report a bug with #{@name}." unless size =~ /\A\d+\z/
- v = [v] * size.to_i
- values[k] = v
- end
- v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) }
- when "Scalar"
- v.is_a?(Numeric)
- when "Scalar?"
- v.is_a?(Numeric) || v.nil?
- when "ScalarType"
- false # not supported yet
- when "ScalarType?"
- v.nil?
- when "bool"
- v == true || v == false
- when "str"
- v.is_a?(String)
- else
- raise Error, "Unknown argument type: #{arg_types[k]}. Please report a bug with #{@name}."
+ unless arg_checkers.key?(k)
+ good = false
+ if candidates.empty?
+ # TODO show all bad keywords at once like Ruby?
+ return {error: "unknown keyword: #{k}"}
end
+ break
+ end
- if !good
- if candidates.size == 1
- k = "input" if k == "self"
+ unless arg_checkers[k].call(values[k])
+ good = false
+ if candidates.empty?
+ t = func.arg_types[k]
+ k = :input if k == :self
return {error: "#{@name}(): argument '#{k}' must be #{t}"}
end
break
end
end
if good
final_values = values
+ break
end
-
- good
end
- if candidates.size != 1
+ unless final_values
raise Error, "This should never happen. Please report a bug with #{@name}."
end
- func = candidates.first
- args = func.args.map { |a| final_values[a[:name]] }
+ args = func.arg_names.map { |k| final_values[k] }
args << TensorOptions.new.dtype(6) if func.tensor_options
{
name: func.cpp_name,
args: args
}