codegen/generate_functions.rb in torch-rb-0.11.2 vs codegen/generate_functions.rb in torch-rb-0.12.0

- old
+ new

@@ -13,10 +13,11 @@ generate_files("nn", :define_singleton_method, functions[:nn]) generate_files("fft", :define_singleton_method, functions[:fft]) generate_files("linalg", :define_singleton_method, functions[:linalg]) generate_files("special", :define_singleton_method, functions[:special]) generate_files("sparse", :define_singleton_method, functions[:sparse]) + # TODO generate nested end def load_functions path = File.expand_path("native_functions.yaml", __dir__) YAML.load_file(path).map { |f| Function.new(f) }.sort_by(&:name) @@ -38,26 +39,28 @@ # need to add to ext.cpp f.base_name == "index_put" || # not supported yet f.func.include?("Dimname") || f.func.include?("ConstQuantizerPtr") || - f.func.include?("SymInt") || # TODO fix LibTorch 1.12 changes f.base_name == "histogramdd" || f.base_name == "nested_tensor" || f.base_name == "split_copy" || f.base_name == "split_with_sizes_copy" || - f.base_name == "unbind_copy" + f.base_name == "unbind_copy" || + # TODO fix LibTorch 1.13 changes + f.base_name == "native_channel_shuffle" end end def group_functions(functions) nn_functions, other_functions = functions.partition { |f| f.python_module == "nn" } linalg_functions, other_functions = other_functions.partition { |f| f.python_module == "linalg" } fft_functions, other_functions = other_functions.partition { |f| f.python_module == "fft" } special_functions, other_functions = other_functions.partition { |f| f.python_module == "special" } sparse_functions, other_functions = other_functions.partition { |f| f.python_module == "sparse" } + nested_functions, other_functions = other_functions.partition { |f| f.python_module == "nested" } unexpected_functions, other_functions = other_functions.partition { |f| f.python_module } torch_functions = other_functions.select { |f| f.variants.include?("function") } tensor_functions = other_functions.select { |f| f.variants.include?("method") } if unexpected_functions.any? @@ -70,11 +73,12 @@ tensor: tensor_functions, nn: nn_functions, linalg: linalg_functions, fft: fft_functions, special: special_functions, - sparse: sparse_functions + sparse: sparse_functions, + nested: nested_functions } end def generate_files(type, def_method, functions) method_defs = [] @@ -385,18 +389,22 @@ "tensorlist" when "Scalar[]" "scalarlist" when /\Aint\[/ "intlist" + when /\ASymInt\[/ + "symintlist" when "float[]" "doublelist" when "Scalar" "scalar" when "bool" "toBool" when "int" "toInt64" + when "SymInt" + "toSymInt" when "float" "toDouble" when "ScalarType" "scalartype" when "str" @@ -435,11 +443,16 @@ def generate_dispatch_code(function, def_method, params, opt_index, remove_self) # torch::empty sets requires_grad by at::empty doesn't # https://github.com/pytorch/pytorch/issues/36455 prefix = remove_self ? "self." : (opt_index ? "torch::" : "at::") - dispatch = function.out? ? "#{function.base_name}_out" : function.base_name + dispatch = function.dispatch_name + unless dispatch + dispatch = function.base_name + dispatch += "_symint" if function.func.include?("SymInt") + dispatch += "_out" if function.out? + end params = params.map { |v| v[:name] } params.reject! { |v| v == "self" } if remove_self params.insert(opt_index, "options") if opt_index @@ -476,18 +489,26 @@ "TensorList" when "Scalar[]" "ScalarList" when "int" "int64_t" + when "SymInt" + "c10::SymInt" when "float" "double" when /\Aint\[/ if param[:optional] "at::OptionalIntArrayRef" else "IntArrayRef" end + when /\ASymInt\[/ + if param[:optional] + "at::OptionalSymIntArrayRef" + else + "c10::SymIntArrayRef" + end when "float[]" "ArrayRef<double>" when "str" if param[:optional] "c10::string_view" @@ -504,11 +525,11 @@ param[:type] else raise "Unknown type: #{param[:type]} (#{function.name})" end - if param[:optional] && !["Tensor", "Scalar"].include?(param[:type]) && !param[:type].start_with?("int[") + if param[:optional] && !["Tensor", "Scalar"].include?(param[:type]) && !param[:type].start_with?("int[") && !param[:type].start_with?("SymInt[") type = "c10::optional<#{type}>" end "#{type} #{param[:name]}" end @@ -612,11 +633,15 @@ "ScalarList" when /\ADimname\[\d*\]\z/ "DirnameList" when /\Aint\[\d*\]\z/ "IntArrayRef" + when /\ASymInt\[\d*\]\z/ + "SymIntArrayRef" when "int" "int64_t" + when "SymInt" + "c10::SymInt" when "float" "double" when "str" "std::string" when "Scalar", "Dimname", "bool", "ScalarType", "Layout", "Device", "Generator", "MemoryFormat", "Storage"