lib/torch/native/parser.rb in torch-rb-0.1.5 vs lib/torch/native/parser.rb in torch-rb-0.1.6
- old
+ new
@@ -2,69 +2,106 @@
module Native
class Parser
def initialize(functions)
@functions = functions
@name = @functions.first.ruby_name
- @min_args = @functions.map { |f| f.args.count { |a| a[:pos] && a[:default].nil? } }.min
+ @min_args = @functions.map { |f| f.args.count { |a| a[:pos] && !a[:has_default] } }.min
@max_args = @functions.map { |f| f.args.count { |a| a[:pos] } }.max
end
def parse(args, options)
candidates = @functions.dup
+ # remove nil
+ while args.any? && args.last.nil?
+ args.pop
+ end
+
+ # TODO account for args passed as options here
if args.size < @min_args || args.size > @max_args
expected = String.new(@min_args.to_s)
expected += "..#{@max_args}" if @max_args != @min_args
return {error: "wrong number of arguments (given #{args.size}, expected #{expected})"}
end
+ candidates.reject! { |f| args.size > f.args.size }
+
+ # exclude functions missing required options
+ candidates.reject! do |func|
+ # TODO make more generic
+ func.out? && !options[:out]
+ end
+
+ # handle out with multiple
+ # there should only be one match, so safe to modify all
+ out_func = candidates.find { |f| f.out? }
+ if out_func && out_func.out_size > 1 && options[:out]
+ out_args = out_func.args.last(2).map { |a| a[:name] }
+ out_args.zip(options.delete(:out)).each do |k, v|
+ options[k.to_sym] = v
+ end
+ candidates = [out_func]
+ end
+
# exclude functions where options don't match
options.each do |k, v|
candidates.select! do |func|
func.args.any? { |a| a[:name] == k.to_s }
end
# TODO show all bad keywords at once like Ruby?
return {error: "unknown keyword: #{k}"} if candidates.empty?
end
- # exclude functions missing required options
- candidates.reject! do |func|
- # TODO make more generic
- func.out? && !options[:out]
- end
-
final_values = {}
# check args
candidates.select! do |func|
good = true
values = args.zip(func.args).map { |a, fa| [fa[:name], a] }.to_h
values.merge!(options.map { |k, v| [k.to_s, v] }.to_h)
func.args.each do |fa|
- values[fa[:name]] ||= fa[:default]
+ values[fa[:name]] = fa[:default] if values[fa[:name]].nil?
end
arg_types = func.args.map { |a| [a[:name], a[:type]] }.to_h
- values.each do |k, v|
+ values.each_key do |k|
+ v = values[k]
t = arg_types[k].split("(").first
+
good =
case t
when "Tensor"
v.is_a?(Tensor)
+ when "Tensor?"
+ v.nil? || v.is_a?(Tensor)
when "Tensor[]"
- v.all? { |v2| v2.is_a?(Tensor) }
+ v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) }
when "int"
- v.is_a?(Integer)
- when "int[]"
- v.all? { |v2| v2.is_a?(Integer) }
+ if k == "reduction"
+ v.is_a?(String)
+ else
+ v.is_a?(Integer)
+ end
+ when "float"
+ v.is_a?(Numeric)
+ when /int\[.*\]/
+ if v.is_a?(Integer)
+ size = t[4..-2]
+ raise Error, "Unknown size: #{size}. Please report a bug with #{@name}." unless size =~ /\A\d+\z/
+ v = [v] * size.to_i
+ values[k] = v
+ end
+ v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) }
when "Scalar"
v.is_a?(Numeric)
+ when "ScalarType?"
+ v.nil?
when "bool"
v == true || v == false
else
- raise Error, "Unknown argument type: #{arg_types[k]}. Please report a bug with #{@name}"
+ raise Error, "Unknown argument type: #{arg_types[k]}. Please report a bug with #{@name}."
end
if !good
if candidates.size == 1
k = "input" if k == "self"