lib/torch/native/function.rb in torch-rb-0.1.5 vs lib/torch/native/function.rb in torch-rb-0.1.6

- old
+ new

@@ -33,21 +33,52 @@ pos = false next end t, _, k = a.rpartition(" ") k, d = k.split("=") - d = d.to_i if d.to_i.to_s == d - d = true if d == "True" - d = false if d == "False" - d = nil if d == "None" - args << {name: k, type: t, default: d, pos: pos} + has_default = !d.nil? + + if d + d = + case d + when "True" + true + when "False" + false + when "None" + nil + when /\A\-?\d+\z/ + d.to_i + when "[]" + [] + when "[0,1]" + [0, 1] + when /\A\de\-\d+\z/, /\A\d+\.\d+\z/ + d.to_f + when "Mean" + "mean" + when "contiguous_format" + d + when "long" + :long + else + raise "Unknown default: #{d}" + end + end + + next if t == "Generator?" + args << {name: k, type: t, default: d, pos: pos, has_default: has_default} end args end end def out_size @out_size ||= func.split("->").last.count("!") + end + + def ret_size + @ret_size ||= func.split("->").last.split(", ").size end def out? out_size > 0 && base_name[-1] != "_" end