codegen/generate_functions.rb in torch-rb-0.9.2 vs codegen/generate_functions.rb in torch-rb-0.10.0

- old
+ new

@@ -12,10 +12,11 @@ generate_files("tensor", :define_method, functions[:tensor]) generate_files("nn", :define_singleton_method, functions[:nn]) generate_files("fft", :define_singleton_method, functions[:fft]) generate_files("linalg", :define_singleton_method, functions[:linalg]) generate_files("special", :define_singleton_method, functions[:special]) + generate_files("sparse", :define_singleton_method, functions[:sparse]) end def load_functions path = File.expand_path("native_functions.yaml", __dir__) YAML.load_file(path).map { |f| Function.new(f) }.sort_by(&:name) @@ -45,10 +46,11 @@ def group_functions(functions) nn_functions, other_functions = functions.partition { |f| f.python_module == "nn" } linalg_functions, other_functions = other_functions.partition { |f| f.python_module == "linalg" } fft_functions, other_functions = other_functions.partition { |f| f.python_module == "fft" } special_functions, other_functions = other_functions.partition { |f| f.python_module == "special" } + sparse_functions, other_functions = other_functions.partition { |f| f.python_module == "sparse" } unexpected_functions, other_functions = other_functions.partition { |f| f.python_module } torch_functions = other_functions.select { |f| f.variants.include?("function") } tensor_functions = other_functions.select { |f| f.variants.include?("method") } if unexpected_functions.any? @@ -60,11 +62,12 @@ torch: torch_functions, tensor: tensor_functions, nn: nn_functions, linalg: linalg_functions, fft: fft_functions, - special: special_functions + special: special_functions, + sparse: sparse_functions } end def generate_files(type, def_method, functions) method_defs = [] @@ -134,9 +137,10 @@ ruby_name = "_#{ruby_name}" if ["size", "stride", "random!", "stft"].include?(ruby_name) ruby_name = ruby_name.sub(/\Afft_/, "") if type == "fft" ruby_name = ruby_name.sub(/\Alinalg_/, "") if type == "linalg" ruby_name = ruby_name.sub(/\Aspecial_/, "") if type == "special" + ruby_name = ruby_name.sub(/\Asparse_/, "") if type == "sparse" ruby_name = name if name.start_with?("__") # cast for Ruby < 2.7 https://github.com/thisMagpie/fftw/issues/22#issuecomment-49508900 cast = RUBY_VERSION.to_f > 2.7 ? "" : "(VALUE (*)(...)) "