codegen/generate_functions.rb in torch-rb-0.5.3 vs codegen/generate_functions.rb in torch-rb-0.6.0

- old
+ new

@@ -22,10 +22,11 @@ functions.reject do |f| f.base_name.start_with?("_") || f.base_name.include?("_backward") || f.base_name.include?("_forward") || f.base_name == "to" || + f.base_name == "record_stream" || # in ext.cpp f.base_name == "index" || f.base_name == "index_put_" || # need to add to ext.cpp f.base_name == "index_put" || @@ -59,11 +60,11 @@ // generated by rake generate:functions // do not edit by hand #pragma once - void add_%{type}_functions(Module m); + void add_%{type}_functions(Rice::Module& m); EOS contents = template % {type: type} write_file("#{type}_functions.h", contents) end @@ -79,11 +80,11 @@ #include "ruby_arg_parser.h" #include "templates.h" #include "wrap_outputs.h" %{method_defs} - void add_%{type}_functions(Module m) { + void add_%{type}_functions(Rice::Module& m) { %{attach_defs} } EOS contents = template % { @@ -121,21 +122,22 @@ assign_self = type == "tensor" ? "\n Tensor& self = from_ruby<Tensor&>(self_);" : "" functions = group_overloads(functions, type) signatures = functions.map { |f| f["signature"] } max_args = signatures.map { |s| s.count(",") - s.count("*") }.max + 1 + dispatches = add_dispatches(functions, def_method) template = <<~EOS // #{name} static VALUE #{type}_#{name}(int argc, VALUE* argv, VALUE self_) { HANDLE_TH_ERRORS#{assign_self} static RubyArgParser parser({ #{signatures.map(&:inspect).join(",\n ")} }); - std::vector<VALUE> parsed_args(#{max_args}); - auto _r = parser.parse(self_, argc, argv, parsed_args); - #{add_dispatches(functions, def_method)} + ParsedArgs<#{max_args}> parsed_args; + #{dispatches.include?("_r.") ? "auto _r = " : ""}parser.parse(self_, argc, argv, parsed_args); + #{dispatches} END_HANDLE_TH_ERRORS } EOS end