codegen/generate_functions.rb in torch-rb-0.9.2 vs codegen/generate_functions.rb in torch-rb-0.10.0
- old
+ new
@@ -12,10 +12,11 @@
generate_files("tensor", :define_method, functions[:tensor])
generate_files("nn", :define_singleton_method, functions[:nn])
generate_files("fft", :define_singleton_method, functions[:fft])
generate_files("linalg", :define_singleton_method, functions[:linalg])
generate_files("special", :define_singleton_method, functions[:special])
+ generate_files("sparse", :define_singleton_method, functions[:sparse])
end
def load_functions
path = File.expand_path("native_functions.yaml", __dir__)
YAML.load_file(path).map { |f| Function.new(f) }.sort_by(&:name)
@@ -45,10 +46,11 @@
def group_functions(functions)
nn_functions, other_functions = functions.partition { |f| f.python_module == "nn" }
linalg_functions, other_functions = other_functions.partition { |f| f.python_module == "linalg" }
fft_functions, other_functions = other_functions.partition { |f| f.python_module == "fft" }
special_functions, other_functions = other_functions.partition { |f| f.python_module == "special" }
+ sparse_functions, other_functions = other_functions.partition { |f| f.python_module == "sparse" }
unexpected_functions, other_functions = other_functions.partition { |f| f.python_module }
torch_functions = other_functions.select { |f| f.variants.include?("function") }
tensor_functions = other_functions.select { |f| f.variants.include?("method") }
if unexpected_functions.any?
@@ -60,11 +62,12 @@
torch: torch_functions,
tensor: tensor_functions,
nn: nn_functions,
linalg: linalg_functions,
fft: fft_functions,
- special: special_functions
+ special: special_functions,
+ sparse: sparse_functions
}
end
def generate_files(type, def_method, functions)
method_defs = []
@@ -134,9 +137,10 @@
ruby_name = "_#{ruby_name}" if ["size", "stride", "random!", "stft"].include?(ruby_name)
ruby_name = ruby_name.sub(/\Afft_/, "") if type == "fft"
ruby_name = ruby_name.sub(/\Alinalg_/, "") if type == "linalg"
ruby_name = ruby_name.sub(/\Aspecial_/, "") if type == "special"
+ ruby_name = ruby_name.sub(/\Asparse_/, "") if type == "sparse"
ruby_name = name if name.start_with?("__")
# cast for Ruby < 2.7 https://github.com/thisMagpie/fftw/issues/22#issuecomment-49508900
cast = RUBY_VERSION.to_f > 2.7 ? "" : "(VALUE (*)(...)) "