lib/torch/native/generator.rb in torch-rb-0.1.8 vs lib/torch/native/generator.rb in torch-rb-0.2.0

- old
+ new

@@ -16,27 +16,28 @@ def grouped_functions functions = functions() # skip functions - skip_binding = ["unique_dim_consecutive", "einsum", "normal"] skip_args = ["bool[3]", "Dimname", "MemoryFormat", "Layout", "Storage", "ConstQuantizerPtr"] # remove functions functions.reject! do |f| f.ruby_name.start_with?("_") || f.ruby_name.end_with?("_backward") || - skip_binding.include?(f.ruby_name) || f.args.any? { |a| a[:type].include?("Dimname") } end # separate out into todo todo_functions, functions = functions.partition do |f| f.args.any? do |a| a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?"].include?(a[:type]) || - skip_args.any? { |sa| a[:type].include?(sa) } + skip_args.any? { |sa| a[:type].include?(sa) } || + # native_functions.yaml is missing size argument for normal + # https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html + (f.base_name == "normal" && !f.out?) end end # generate additional functions for optional arguments # there may be a better way to do this @@ -117,9 +118,11 @@ "double" when /\Aint\[/ "IntArrayRef" when /Tensor\(\S!?\)/ "Tensor &" + when "str" + "std::string" else a[:type] end t = "MyReduction" if a[:name] == "reduction" && t == "int64_t"