codegen/generate_functions.rb in torch-rb-0.8.0 vs codegen/generate_functions.rb in torch-rb-0.8.1
- old
+ new
@@ -9,10 +9,13 @@
functions = group_functions(functions)
generate_files("torch", :define_singleton_method, functions[:torch])
generate_files("tensor", :define_method, functions[:tensor])
generate_files("nn", :define_singleton_method, functions[:nn])
+ generate_files("fft", :define_singleton_method, functions[:fft])
+ generate_files("linalg", :define_singleton_method, functions[:linalg])
+ generate_files("special", :define_singleton_method, functions[:special])
end
def load_functions
path = File.expand_path("native_functions.yaml", __dir__)
YAML.load_file(path).map { |f| Function.new(f) }.sort_by(&:name)
@@ -36,14 +39,30 @@
end
end
def group_functions(functions)
nn_functions, other_functions = functions.partition { |f| f.python_module == "nn" }
+ linalg_functions, other_functions = other_functions.partition { |f| f.python_module == "linalg" }
+ fft_functions, other_functions = other_functions.partition { |f| f.python_module == "fft" }
+ special_functions, other_functions = other_functions.partition { |f| f.python_module == "special" }
+ unexpected_functions, other_functions = other_functions.partition { |f| f.python_module }
torch_functions = other_functions.select { |f| f.variants.include?("function") }
tensor_functions = other_functions.select { |f| f.variants.include?("method") }
- {torch: torch_functions, tensor: tensor_functions, nn: nn_functions}
+ if unexpected_functions.any?
+ unexpected_modules = unexpected_functions.map(&:python_module).uniq
+ raise "Unexpected modules: #{unexpected_modules.join(", ")}"
+ end
+
+ {
+ torch: torch_functions,
+ tensor: tensor_functions,
+ nn: nn_functions,
+ linalg: linalg_functions,
+ fft: fft_functions,
+ special: special_functions
+ }
end
def generate_files(type, def_method, functions)
method_defs = []
attach_defs = []
@@ -109,15 +128,18 @@
else
name
end
ruby_name = "_#{ruby_name}" if ["size", "stride", "random!", "stft"].include?(ruby_name)
+ ruby_name = ruby_name.sub(/\Afft_/, "") if type == "fft"
+ ruby_name = ruby_name.sub(/\Alinalg_/, "") if type == "linalg"
+ ruby_name = ruby_name.sub(/\Aspecial_/, "") if type == "special"
# cast for Ruby < 2.7 https://github.com/thisMagpie/fftw/issues/22#issuecomment-49508900
cast = RUBY_VERSION.to_f > 2.7 ? "" : "(VALUE (*)(...)) "
- "rb_#{def_method}(m, \"#{ruby_name}\", #{cast}#{type}_#{name}, -1);"
+ "rb_#{def_method}(m, \"#{ruby_name}\", #{cast}#{full_name(name, type)}, -1);"
end
def generate_method_def(name, functions, type, def_method)
assign_self = type == "tensor" ? "\n Tensor& self = Rice::detail::From_Ruby<Tensor&>().convert(self_);" : ""
@@ -126,11 +148,11 @@
max_args = signatures.map { |s| s.count(",") - s.count("*") }.max + 1
dispatches = add_dispatches(functions, def_method)
template = <<~EOS
// #{name}
- static VALUE #{type}_#{name}(int argc, VALUE* argv, VALUE self_)
+ static VALUE #{full_name(name, type)}(int argc, VALUE* argv, VALUE self_)
{
HANDLE_TH_ERRORS#{assign_self}
static RubyArgParser parser({
#{signatures.map(&:inspect).join(",\n ")}
});
@@ -557,6 +579,14 @@
end
type += "[#{param[:list_size]}]" if param[:list_size]
type += "?" if param[:optional]
type
+end
+
+def full_name(name, type)
+ if %w(fft linalg special).include?(type) && name.start_with?("#{type}_")
+ name
+ else
+ "#{type}_#{name}"
+ end
end