lib/torch/native/function.rb in torch-rb-0.1.5 vs lib/torch/native/function.rb in torch-rb-0.1.6
- old
+ new
@@ -33,21 +33,52 @@
pos = false
next
end
t, _, k = a.rpartition(" ")
k, d = k.split("=")
- d = d.to_i if d.to_i.to_s == d
- d = true if d == "True"
- d = false if d == "False"
- d = nil if d == "None"
- args << {name: k, type: t, default: d, pos: pos}
+ has_default = !d.nil?
+
+ if d
+ d =
+ case d
+ when "True"
+ true
+ when "False"
+ false
+ when "None"
+ nil
+ when /\A\-?\d+\z/
+ d.to_i
+ when "[]"
+ []
+ when "[0,1]"
+ [0, 1]
+ when /\A\de\-\d+\z/, /\A\d+\.\d+\z/
+ d.to_f
+ when "Mean"
+ "mean"
+ when "contiguous_format"
+ d
+ when "long"
+ :long
+ else
+ raise "Unknown default: #{d}"
+ end
+ end
+
+ next if t == "Generator?"
+ args << {name: k, type: t, default: d, pos: pos, has_default: has_default}
end
args
end
end
def out_size
@out_size ||= func.split("->").last.count("!")
+ end
+
+ def ret_size
+ @ret_size ||= func.split("->").last.split(", ").size
end
def out?
out_size > 0 && base_name[-1] != "_"
end