lib/torch/native/generator.rb in torch-rb-0.1.5 vs lib/torch/native/generator.rb in torch-rb-0.1.6
- old
+ new
@@ -15,27 +15,38 @@
end
def grouped_functions
functions = functions()
- # remove functions
+ # skip functions
skip_binding = ["unique_dim_consecutive", "einsum", "normal"]
- skip_args = ["bool[3]", "Dimname", "ScalarType", "MemoryFormat", "Storage", "ConstQuantizerPtr"]
- functions.reject! { |f| f.ruby_name.start_with?("_") || f.ruby_name.end_with?("_backward") || skip_binding.include?(f.ruby_name) }
+ skip_args = ["bool[3]", "Dimname", "MemoryFormat", "Layout", "Storage", "ConstQuantizerPtr"]
+
+ # remove functions
+ functions.reject! do |f|
+ f.ruby_name.start_with?("_") ||
+ f.ruby_name.end_with?("_backward") ||
+ skip_binding.include?(f.ruby_name) ||
+ f.args.any? { |a| a[:type].include?("Dimname") }
+ end
+
+ # separate out into todo
todo_functions, functions =
functions.partition do |f|
f.args.any? do |a|
- a[:type].include?("?") && !["Tensor?", "Generator?", "int?"].include?(a[:type]) ||
+ a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?"].include?(a[:type]) ||
skip_args.any? { |sa| a[:type].include?(sa) }
end
end
# generate additional functions for optional arguments
# there may be a better way to do this
optional_functions, functions = functions.partition { |f| f.args.any? { |a| a[:type] == "int?" } }
optional_functions.each do |f|
- next if f.ruby_name.start_with?("avg_pool") || f.ruby_name == "cross"
+ next if f.ruby_name == "cross"
+ next if f.ruby_name.start_with?("avg_pool") && f.out?
+
opt_args = f.args.select { |a| a[:type] == "int?" }
if opt_args.size == 1
sep = f.name.include?(".") ? "_" : "."
f1 = Function.new(f.function.merge("func" => f.func.sub("(", "#{sep}#{opt_args.first[:name]}(").gsub("int?", "int")))
# TODO only remove some arguments
@@ -83,21 +94,23 @@
}
TEMPLATE
cpp_defs = []
functions.sort_by(&:cpp_name).each do |func|
- fargs = func.args.select { |a| a[:type] != "Generator?" }
+ fargs = func.args #.select { |a| a[:type] != "Generator?" }
cpp_args = []
fargs.each do |a|
t =
case a[:type]
when "Tensor"
"const Tensor &"
when "Tensor?"
# TODO better signature
"OptionalTensor"
+ when "ScalarType?"
+ "OptionalScalarType"
when "Tensor[]"
"TensorList"
when "int"
"int64_t"
when "float"
@@ -119,13 +132,19 @@
args.unshift(*args.pop(func.out_size)) if func.out?
args.delete("self") if def_method == :define_method
prefix = def_method == :define_method ? "self." : "torch::"
+ body = "#{prefix}#{dispatch}(#{args.join(", ")})"
+ # TODO check type as well
+ if func.ret_size > 1
+ body = "wrap(#{body})"
+ end
+
cpp_defs << ".#{def_method}(
\"#{func.cpp_name}\",
*[](#{cpp_args.join(", ")}) {
- return #{prefix}#{dispatch}(#{args.join(", ")});
+ return #{body};
})"
end
hpp_contents = hpp_template % {type: type}
cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n ")}