lib/torch/native/generator.rb in torch-rb-0.2.5 vs lib/torch/native/generator.rb in torch-rb-0.2.6
- old
+ new
@@ -16,11 +16,11 @@
def grouped_functions
functions = functions()
# skip functions
- skip_args = ["bool[3]", "Dimname", "MemoryFormat", "Layout", "Storage", "ConstQuantizerPtr"]
+ skip_args = ["bool[3]", "Dimname", "Layout", "Storage", "ConstQuantizerPtr"]
# remove functions
functions.reject! do |f|
f.ruby_name.start_with?("_") ||
f.ruby_name.end_with?("_backward") ||
@@ -29,11 +29,11 @@
# separate out into todo
todo_functions, functions =
functions.partition do |f|
f.args.any? do |a|
- a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?"].include?(a[:type]) ||
+ a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?", "Tensor?[]"].include?(a[:type]) ||
skip_args.any? { |sa| a[:type].include?(sa) } ||
# native_functions.yaml is missing size argument for normal
# https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
(f.base_name == "normal" && !f.out?)
end
@@ -109,9 +109,12 @@
# TODO better signature
"OptionalTensor"
when "ScalarType?"
"OptionalScalarType"
when "Tensor[]"
+ "TensorList"
+ when "Tensor?[]"
+ # TODO make optional
"TensorList"
when "int"
"int64_t"
when "float"
"double"