codegen/generate_functions.rb in torch-rb-0.8.0 vs codegen/generate_functions.rb in torch-rb-0.8.1

- old
+ new

@@ -9,10 +9,13 @@ functions = group_functions(functions) generate_files("torch", :define_singleton_method, functions[:torch]) generate_files("tensor", :define_method, functions[:tensor]) generate_files("nn", :define_singleton_method, functions[:nn]) + generate_files("fft", :define_singleton_method, functions[:fft]) + generate_files("linalg", :define_singleton_method, functions[:linalg]) + generate_files("special", :define_singleton_method, functions[:special]) end def load_functions path = File.expand_path("native_functions.yaml", __dir__) YAML.load_file(path).map { |f| Function.new(f) }.sort_by(&:name) @@ -36,14 +39,30 @@ end end def group_functions(functions) nn_functions, other_functions = functions.partition { |f| f.python_module == "nn" } + linalg_functions, other_functions = other_functions.partition { |f| f.python_module == "linalg" } + fft_functions, other_functions = other_functions.partition { |f| f.python_module == "fft" } + special_functions, other_functions = other_functions.partition { |f| f.python_module == "special" } + unexpected_functions, other_functions = other_functions.partition { |f| f.python_module } torch_functions = other_functions.select { |f| f.variants.include?("function") } tensor_functions = other_functions.select { |f| f.variants.include?("method") } - {torch: torch_functions, tensor: tensor_functions, nn: nn_functions} + if unexpected_functions.any? + unexpected_modules = unexpected_functions.map(&:python_module).uniq + raise "Unexpected modules: #{unexpected_modules.join(", ")}" + end + + { + torch: torch_functions, + tensor: tensor_functions, + nn: nn_functions, + linalg: linalg_functions, + fft: fft_functions, + special: special_functions + } end def generate_files(type, def_method, functions) method_defs = [] attach_defs = [] @@ -109,15 +128,18 @@ else name end ruby_name = "_#{ruby_name}" if ["size", "stride", "random!", "stft"].include?(ruby_name) + ruby_name = ruby_name.sub(/\Afft_/, "") if type == "fft" + ruby_name = ruby_name.sub(/\Alinalg_/, "") if type == "linalg" + ruby_name = ruby_name.sub(/\Aspecial_/, "") if type == "special" # cast for Ruby < 2.7 https://github.com/thisMagpie/fftw/issues/22#issuecomment-49508900 cast = RUBY_VERSION.to_f > 2.7 ? "" : "(VALUE (*)(...)) " - "rb_#{def_method}(m, \"#{ruby_name}\", #{cast}#{type}_#{name}, -1);" + "rb_#{def_method}(m, \"#{ruby_name}\", #{cast}#{full_name(name, type)}, -1);" end def generate_method_def(name, functions, type, def_method) assign_self = type == "tensor" ? "\n Tensor& self = Rice::detail::From_Ruby<Tensor&>().convert(self_);" : "" @@ -126,11 +148,11 @@ max_args = signatures.map { |s| s.count(",") - s.count("*") }.max + 1 dispatches = add_dispatches(functions, def_method) template = <<~EOS // #{name} - static VALUE #{type}_#{name}(int argc, VALUE* argv, VALUE self_) + static VALUE #{full_name(name, type)}(int argc, VALUE* argv, VALUE self_) { HANDLE_TH_ERRORS#{assign_self} static RubyArgParser parser({ #{signatures.map(&:inspect).join(",\n ")} }); @@ -557,6 +579,14 @@ end type += "[#{param[:list_size]}]" if param[:list_size] type += "?" if param[:optional] type +end + +def full_name(name, type) + if %w(fft linalg special).include?(type) && name.start_with?("#{type}_") + name + else + "#{type}_#{name}" + end end