lib/torch/native/generator.rb in torch-rb-0.1.8 vs lib/torch/native/generator.rb in torch-rb-0.2.0
- old
+ new
@@ -16,27 +16,28 @@
def grouped_functions
functions = functions()
# skip functions
- skip_binding = ["unique_dim_consecutive", "einsum", "normal"]
skip_args = ["bool[3]", "Dimname", "MemoryFormat", "Layout", "Storage", "ConstQuantizerPtr"]
# remove functions
functions.reject! do |f|
f.ruby_name.start_with?("_") ||
f.ruby_name.end_with?("_backward") ||
- skip_binding.include?(f.ruby_name) ||
f.args.any? { |a| a[:type].include?("Dimname") }
end
# separate out into todo
todo_functions, functions =
functions.partition do |f|
f.args.any? do |a|
a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?"].include?(a[:type]) ||
- skip_args.any? { |sa| a[:type].include?(sa) }
+ skip_args.any? { |sa| a[:type].include?(sa) } ||
+ # native_functions.yaml is missing size argument for normal
+ # https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
+ (f.base_name == "normal" && !f.out?)
end
end
# generate additional functions for optional arguments
# there may be a better way to do this
@@ -117,9 +118,11 @@
"double"
when /\Aint\[/
"IntArrayRef"
when /Tensor\(\S!?\)/
"Tensor &"
+ when "str"
+ "std::string"
else
a[:type]
end
t = "MyReduction" if a[:name] == "reduction" && t == "int64_t"