lib/torch/native/generator.rb in torch-rb-0.2.5 vs lib/torch/native/generator.rb in torch-rb-0.2.6

- old
+ new

@@ -16,11 +16,11 @@ def grouped_functions functions = functions() # skip functions - skip_args = ["bool[3]", "Dimname", "MemoryFormat", "Layout", "Storage", "ConstQuantizerPtr"] + skip_args = ["bool[3]", "Dimname", "Layout", "Storage", "ConstQuantizerPtr"] # remove functions functions.reject! do |f| f.ruby_name.start_with?("_") || f.ruby_name.end_with?("_backward") || @@ -29,11 +29,11 @@ # separate out into todo todo_functions, functions = functions.partition do |f| f.args.any? do |a| - a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?"].include?(a[:type]) || + a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?", "Tensor?[]"].include?(a[:type]) || skip_args.any? { |sa| a[:type].include?(sa) } || # native_functions.yaml is missing size argument for normal # https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html (f.base_name == "normal" && !f.out?) end @@ -109,9 +109,12 @@ # TODO better signature "OptionalTensor" when "ScalarType?" "OptionalScalarType" when "Tensor[]" + "TensorList" + when "Tensor?[]" + # TODO make optional "TensorList" when "int" "int64_t" when "float" "double"