lib/torch/native/generator.rb in torch-rb-0.3.5 vs lib/torch/native/generator.rb in torch-rb-0.3.6
- old
+ new
@@ -16,24 +16,23 @@
def grouped_functions
functions = functions()
# skip functions
- skip_args = ["bool[3]", "Dimname", "Layout", "Storage", "ConstQuantizerPtr"]
+ skip_args = ["Layout", "Storage", "ConstQuantizerPtr"]
# remove functions
functions.reject! do |f|
f.ruby_name.start_with?("_") ||
- f.ruby_name.end_with?("_backward") ||
+ f.ruby_name.include?("_backward") ||
f.args.any? { |a| a[:type].include?("Dimname") }
end
# separate out into todo
todo_functions, functions =
functions.partition do |f|
f.args.any? do |a|
- a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?", "Tensor?[]"].include?(a[:type]) ||
skip_args.any? { |sa| a[:type].include?(sa) } ||
# call to 'range' is ambiguous
f.cpp_name == "_range" ||
# native_functions.yaml is missing size argument for normal
# https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
@@ -102,10 +101,16 @@
"TensorList"
when "int"
"int64_t"
when "int?"
"torch::optional<int64_t>"
+ when "float?"
+ "torch::optional<double>"
+ when "bool?"
+ "torch::optional<bool>"
+ when "Scalar?"
+ "torch::optional<torch::Scalar>"
when "float"
"double"
when /\Aint\[/
"IntArrayRef"
when /Tensor\(\S!?\)/
@@ -128,11 +133,11 @@
args.delete("self") if def_method == :define_method
prefix = def_method == :define_method ? "self." : "torch::"
body = "#{prefix}#{dispatch}(#{args.join(", ")})"
- # TODO check type as well
- if func.ret_size > 1
+
+ if func.ret_size > 1 || func.ret_array?
body = "wrap(#{body})"
end
cpp_defs << ".#{def_method}(
\"#{func.cpp_name}\",