lib/torch/native/generator.rb in torch-rb-0.3.6 vs lib/torch/native/generator.rb in torch-rb-0.3.7

- old
+ new

@@ -70,20 +70,22 @@ #include <torch/torch.h> #include <rice/Module.hpp> #include "templates.hpp" +%{functions} + void add_%{type}_functions(Module m) { - m - %{functions}; + %{add_functions} } TEMPLATE cpp_defs = [] + add_defs = [] functions.sort_by(&:cpp_name).each do |func| fargs = func.args.dup #.select { |a| a[:type] != "Generator?" } - fargs << {name: "options", type: "TensorOptions"} if func.tensor_options + fargs << {name: :options, type: "TensorOptions"} if func.tensor_options cpp_args = [] fargs.each do |a| t = case a[:type] @@ -92,15 +94,13 @@ when "Tensor?" # TODO better signature "OptionalTensor" when "ScalarType?" "torch::optional<ScalarType>" - when "Tensor[]" - "TensorList" - when "Tensor?[]" + when "Tensor[]", "Tensor?[]" # TODO make optional - "TensorList" + "std::vector<Tensor>" when "int" "int64_t" when "int?" "torch::optional<int64_t>" when "float?" @@ -110,46 +110,56 @@ when "Scalar?" "torch::optional<torch::Scalar>" when "float" "double" when /\Aint\[/ - "IntArrayRef" + "std::vector<int64_t>" when /Tensor\(\S!?\)/ "Tensor &" when "str" "std::string" when "TensorOptions" "const torch::TensorOptions &" - else + when "Layout?" + "torch::optional<Layout>" + when "Device?" + "torch::optional<Device>" + when "Scalar", "bool", "ScalarType", "Layout", "Device", "Storage" a[:type] + else + raise "Unknown type: #{a[:type]}" end - t = "MyReduction" if a[:name] == "reduction" && t == "int64_t" + t = "MyReduction" if a[:name] == :reduction && t == "int64_t" cpp_args << [t, a[:name]].join(" ").sub("& ", "&") end dispatch = func.out? ? "#{func.base_name}_out" : func.base_name args = fargs.map { |a| a[:name] } args.unshift(*args.pop(func.out_size)) if func.out? - args.delete("self") if def_method == :define_method + args.delete(:self) if def_method == :define_method prefix = def_method == :define_method ? "self." : "torch::" body = "#{prefix}#{dispatch}(#{args.join(", ")})" - if func.ret_size > 1 || func.ret_array? + if func.cpp_name == "_fill_diagonal_" + body = "to_ruby<torch::Tensor>(#{body})" + elsif !func.ret_void? body = "wrap(#{body})" end - cpp_defs << ".#{def_method}( - \"#{func.cpp_name}\", - *[](#{cpp_args.join(", ")}) { - return #{body}; - })" + cpp_defs << "// #{func.func} +static #{func.ret_void? ? "void" : "Object"} #{type}#{func.cpp_name}(#{cpp_args.join(", ")}) +{ + return #{body}; +}" + + add_defs << "m.#{def_method}(\"#{func.cpp_name}\", #{type}#{func.cpp_name});" end hpp_contents = hpp_template % {type: type} - cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n ")} + cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n\n"), add_functions: add_defs.join("\n ")} path = File.expand_path("../../../ext/torch", __dir__) File.write("#{path}/#{type}_functions.hpp", hpp_contents) File.write("#{path}/#{type}_functions.cpp", cpp_contents) end