lib/torch/native/generator.rb in torch-rb-0.3.2 vs lib/torch/native/generator.rb in torch-rb-0.3.3
- old
+ new
@@ -31,34 +31,18 @@
todo_functions, functions =
functions.partition do |f|
f.args.any? do |a|
a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?", "Tensor?[]"].include?(a[:type]) ||
skip_args.any? { |sa| a[:type].include?(sa) } ||
+ # call to 'range' is ambiguous
+ f.cpp_name == "_range" ||
# native_functions.yaml is missing size argument for normal
# https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
(f.base_name == "normal" && !f.out?)
end
end
- # generate additional functions for optional arguments
- # there may be a better way to do this
- optional_functions, functions = functions.partition { |f| f.args.any? { |a| a[:type] == "int?" } }
- optional_functions.each do |f|
- next if f.ruby_name == "cross"
- next if f.ruby_name.start_with?("avg_pool") && f.out?
-
- opt_args = f.args.select { |a| a[:type] == "int?" }
- if opt_args.size == 1
- sep = f.name.include?(".") ? "_" : "."
- f1 = Function.new(f.function.merge("func" => f.func.sub("(", "#{sep}#{opt_args.first[:name]}(").gsub("int?", "int")))
- # TODO only remove some arguments
- f2 = Function.new(f.function.merge("func" => f.func.sub(/, int\?.+\) ->/, ") ->")))
- functions << f1
- functions << f2
- end
- end
-
# todo_functions.each do |f|
# puts f.func
# puts
# end
@@ -95,11 +79,12 @@
}
TEMPLATE
cpp_defs = []
functions.sort_by(&:cpp_name).each do |func|
- fargs = func.args #.select { |a| a[:type] != "Generator?" }
+ fargs = func.args.dup #.select { |a| a[:type] != "Generator?" }
+ fargs << {name: "options", type: "TensorOptions"} if func.tensor_options
cpp_args = []
fargs.each do |a|
t =
case a[:type]
@@ -107,25 +92,29 @@
"const Tensor &"
when "Tensor?"
# TODO better signature
"OptionalTensor"
when "ScalarType?"
- "OptionalScalarType"
+ "torch::optional<ScalarType>"
when "Tensor[]"
"TensorList"
when "Tensor?[]"
# TODO make optional
"TensorList"
when "int"
"int64_t"
+ when "int?"
+ "torch::optional<int64_t>"
when "float"
"double"
when /\Aint\[/
"IntArrayRef"
when /Tensor\(\S!?\)/
"Tensor &"
when "str"
"std::string"
+ when "TensorOptions"
+ "const torch::TensorOptions &"
else
a[:type]
end
t = "MyReduction" if a[:name] == "reduction" && t == "int64_t"