lib/torch/native/generator.rb in torch-rb-0.3.6 vs lib/torch/native/generator.rb in torch-rb-0.3.7
- old
+ new
@@ -70,20 +70,22 @@
#include <torch/torch.h>
#include <rice/Module.hpp>
#include "templates.hpp"
+%{functions}
+
void add_%{type}_functions(Module m) {
- m
- %{functions};
+ %{add_functions}
}
TEMPLATE
cpp_defs = []
+ add_defs = []
functions.sort_by(&:cpp_name).each do |func|
fargs = func.args.dup #.select { |a| a[:type] != "Generator?" }
- fargs << {name: "options", type: "TensorOptions"} if func.tensor_options
+ fargs << {name: :options, type: "TensorOptions"} if func.tensor_options
cpp_args = []
fargs.each do |a|
t =
case a[:type]
@@ -92,15 +94,13 @@
when "Tensor?"
# TODO better signature
"OptionalTensor"
when "ScalarType?"
"torch::optional<ScalarType>"
- when "Tensor[]"
- "TensorList"
- when "Tensor?[]"
+ when "Tensor[]", "Tensor?[]"
# TODO make optional
- "TensorList"
+ "std::vector<Tensor>"
when "int"
"int64_t"
when "int?"
"torch::optional<int64_t>"
when "float?"
@@ -110,46 +110,56 @@
when "Scalar?"
"torch::optional<torch::Scalar>"
when "float"
"double"
when /\Aint\[/
- "IntArrayRef"
+ "std::vector<int64_t>"
when /Tensor\(\S!?\)/
"Tensor &"
when "str"
"std::string"
when "TensorOptions"
"const torch::TensorOptions &"
- else
+ when "Layout?"
+ "torch::optional<Layout>"
+ when "Device?"
+ "torch::optional<Device>"
+ when "Scalar", "bool", "ScalarType", "Layout", "Device", "Storage"
a[:type]
+ else
+ raise "Unknown type: #{a[:type]}"
end
- t = "MyReduction" if a[:name] == "reduction" && t == "int64_t"
+ t = "MyReduction" if a[:name] == :reduction && t == "int64_t"
cpp_args << [t, a[:name]].join(" ").sub("& ", "&")
end
dispatch = func.out? ? "#{func.base_name}_out" : func.base_name
args = fargs.map { |a| a[:name] }
args.unshift(*args.pop(func.out_size)) if func.out?
- args.delete("self") if def_method == :define_method
+ args.delete(:self) if def_method == :define_method
prefix = def_method == :define_method ? "self." : "torch::"
body = "#{prefix}#{dispatch}(#{args.join(", ")})"
- if func.ret_size > 1 || func.ret_array?
+ if func.cpp_name == "_fill_diagonal_"
+ body = "to_ruby<torch::Tensor>(#{body})"
+ elsif !func.ret_void?
body = "wrap(#{body})"
end
- cpp_defs << ".#{def_method}(
- \"#{func.cpp_name}\",
- *[](#{cpp_args.join(", ")}) {
- return #{body};
- })"
+ cpp_defs << "// #{func.func}
+static #{func.ret_void? ? "void" : "Object"} #{type}#{func.cpp_name}(#{cpp_args.join(", ")})
+{
+ return #{body};
+}"
+
+ add_defs << "m.#{def_method}(\"#{func.cpp_name}\", #{type}#{func.cpp_name});"
end
hpp_contents = hpp_template % {type: type}
- cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n ")}
+ cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n\n"), add_functions: add_defs.join("\n ")}
path = File.expand_path("../../../ext/torch", __dir__)
File.write("#{path}/#{type}_functions.hpp", hpp_contents)
File.write("#{path}/#{type}_functions.cpp", cpp_contents)
end