codegen/generate_functions.rb in torch-rb-0.5.3 vs codegen/generate_functions.rb in torch-rb-0.6.0
- old
+ new
@@ -22,10 +22,11 @@
functions.reject do |f|
f.base_name.start_with?("_") ||
f.base_name.include?("_backward") ||
f.base_name.include?("_forward") ||
f.base_name == "to" ||
+ f.base_name == "record_stream" ||
# in ext.cpp
f.base_name == "index" ||
f.base_name == "index_put_" ||
# need to add to ext.cpp
f.base_name == "index_put" ||
@@ -59,11 +60,11 @@
// generated by rake generate:functions
// do not edit by hand
#pragma once
- void add_%{type}_functions(Module m);
+ void add_%{type}_functions(Rice::Module& m);
EOS
contents = template % {type: type}
write_file("#{type}_functions.h", contents)
end
@@ -79,11 +80,11 @@
#include "ruby_arg_parser.h"
#include "templates.h"
#include "wrap_outputs.h"
%{method_defs}
- void add_%{type}_functions(Module m) {
+ void add_%{type}_functions(Rice::Module& m) {
%{attach_defs}
}
EOS
contents = template % {
@@ -121,21 +122,22 @@
assign_self = type == "tensor" ? "\n Tensor& self = from_ruby<Tensor&>(self_);" : ""
functions = group_overloads(functions, type)
signatures = functions.map { |f| f["signature"] }
max_args = signatures.map { |s| s.count(",") - s.count("*") }.max + 1
+ dispatches = add_dispatches(functions, def_method)
template = <<~EOS
// #{name}
static VALUE #{type}_#{name}(int argc, VALUE* argv, VALUE self_)
{
HANDLE_TH_ERRORS#{assign_self}
static RubyArgParser parser({
#{signatures.map(&:inspect).join(",\n ")}
});
- std::vector<VALUE> parsed_args(#{max_args});
- auto _r = parser.parse(self_, argc, argv, parsed_args);
- #{add_dispatches(functions, def_method)}
+ ParsedArgs<#{max_args}> parsed_args;
+ #{dispatches.include?("_r.") ? "auto _r = " : ""}parser.parse(self_, argc, argv, parsed_args);
+ #{dispatches}
END_HANDLE_TH_ERRORS
}
EOS
end