codegen/generate_functions.rb in torch-rb-0.11.2 vs codegen/generate_functions.rb in torch-rb-0.12.0
- old
+ new
@@ -13,10 +13,11 @@
generate_files("nn", :define_singleton_method, functions[:nn])
generate_files("fft", :define_singleton_method, functions[:fft])
generate_files("linalg", :define_singleton_method, functions[:linalg])
generate_files("special", :define_singleton_method, functions[:special])
generate_files("sparse", :define_singleton_method, functions[:sparse])
+ # TODO generate nested
end
def load_functions
path = File.expand_path("native_functions.yaml", __dir__)
YAML.load_file(path).map { |f| Function.new(f) }.sort_by(&:name)
@@ -38,26 +39,28 @@
# need to add to ext.cpp
f.base_name == "index_put" ||
# not supported yet
f.func.include?("Dimname") ||
f.func.include?("ConstQuantizerPtr") ||
- f.func.include?("SymInt") ||
# TODO fix LibTorch 1.12 changes
f.base_name == "histogramdd" ||
f.base_name == "nested_tensor" ||
f.base_name == "split_copy" ||
f.base_name == "split_with_sizes_copy" ||
- f.base_name == "unbind_copy"
+ f.base_name == "unbind_copy" ||
+ # TODO fix LibTorch 1.13 changes
+ f.base_name == "native_channel_shuffle"
end
end
def group_functions(functions)
nn_functions, other_functions = functions.partition { |f| f.python_module == "nn" }
linalg_functions, other_functions = other_functions.partition { |f| f.python_module == "linalg" }
fft_functions, other_functions = other_functions.partition { |f| f.python_module == "fft" }
special_functions, other_functions = other_functions.partition { |f| f.python_module == "special" }
sparse_functions, other_functions = other_functions.partition { |f| f.python_module == "sparse" }
+ nested_functions, other_functions = other_functions.partition { |f| f.python_module == "nested" }
unexpected_functions, other_functions = other_functions.partition { |f| f.python_module }
torch_functions = other_functions.select { |f| f.variants.include?("function") }
tensor_functions = other_functions.select { |f| f.variants.include?("method") }
if unexpected_functions.any?
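
(Aside, not part of the diff: the new nested partition follows the same pattern as the other python_module groups above, peeling off matching functions and passing the remainder to the next partition. A tiny illustration with stand-in objects; the names here are made up for the example.)

  # stand-in structs, only to show how the partition chain behaves
  Fn = Struct.new(:name, :python_module)
  functions = [
    Fn.new("linalg_solve", "linalg"),
    Fn.new("to_padded_tensor", "nested"),
    Fn.new("add", nil)
  ]
  nested_functions, other_functions = functions.partition { |f| f.python_module == "nested" }
  nested_functions.map(&:name) # => ["to_padded_tensor"]
  other_functions.map(&:name)  # => ["linalg_solve", "add"]
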
@@ -70,11 +73,12 @@
tensor: tensor_functions,
nn: nn_functions,
linalg: linalg_functions,
fft: fft_functions,
special: special_functions,
- sparse: sparse_functions
+ sparse: sparse_functions,
+ nested: nested_functions
}
end
def generate_files(type, def_method, functions)
method_defs = []
@@ -385,18 +389,22 @@
"tensorlist"
when "Scalar[]"
"scalarlist"
when /\Aint\[/
"intlist"
+ when /\ASymInt\[/
+ "symintlist"
when "float[]"
"doublelist"
when "Scalar"
"scalar"
when "bool"
"toBool"
when "int"
"toInt64"
+ when "SymInt"
+ "toSymInt"
when "float"
"toDouble"
when "ScalarType"
"scalartype"
when "str"
@@ -435,11 +443,16 @@
def generate_dispatch_code(function, def_method, params, opt_index, remove_self)
# torch::empty sets requires_grad but at::empty doesn't
# https://github.com/pytorch/pytorch/issues/36455
prefix = remove_self ? "self." : (opt_index ? "torch::" : "at::")
- dispatch = function.out? ? "#{function.base_name}_out" : function.base_name
+ dispatch = function.dispatch_name
+ unless dispatch
+ dispatch = function.base_name
+ dispatch += "_symint" if function.func.include?("SymInt")
+ dispatch += "_out" if function.out?
+ end
params = params.map { |v| v[:name] }
params.reject! { |v| v == "self" } if remove_self
params.insert(opt_index, "options") if opt_index
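
(Aside, not part of the diff: the dispatch change prefers an explicit dispatch_name when the function provides one; otherwise the name is derived from base_name, appending "_symint" for schemas that mention SymInt and "_out" for out variants. A small sketch of that fallback naming rule with hypothetical inputs.)

  # hypothetical inputs, just to show the fallback naming rule above
  base_name = "empty"
  has_symint = true    # schema mentions SymInt
  out_variant = true   # function has an out= overload

  dispatch = base_name.dup
  dispatch << "_symint" if has_symint
  dispatch << "_out" if out_variant
  dispatch # => "empty_symint_out"
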
@@ -476,18 +489,26 @@
"TensorList"
when "Scalar[]"
"ScalarList"
when "int"
"int64_t"
+ when "SymInt"
+ "c10::SymInt"
when "float"
"double"
when /\Aint\[/
if param[:optional]
"at::OptionalIntArrayRef"
else
"IntArrayRef"
end
+ when /\ASymInt\[/
+ if param[:optional]
+ "at::OptionalSymIntArrayRef"
+ else
+ "c10::SymIntArrayRef"
+ end
when "float[]"
"ArrayRef<double>"
when "str"
if param[:optional]
"c10::string_view"
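
(Aside, not part of the diff: these branches choose the C++ argument types the generated extension code passes to LibTorch — plain SymInt becomes c10::SymInt, while SymInt[] becomes c10::SymIntArrayRef, or at::OptionalSymIntArrayRef when the schema marks it optional. A toy sketch of that decision, with a hypothetical param hash mirroring the shape used above.)

  # hypothetical param hash, same keys as the ones used in the diff
  param = { name: "size", type: "SymInt[]", optional: false }

  cpp_type =
    case param[:type]
    when "SymInt" then "c10::SymInt"
    when /\ASymInt\[/
      param[:optional] ? "at::OptionalSymIntArrayRef" : "c10::SymIntArrayRef"
    end

  "#{cpp_type} #{param[:name]}" # => "c10::SymIntArrayRef size"
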
@@ -504,11 +525,11 @@
param[:type]
else
raise "Unknown type: #{param[:type]} (#{function.name})"
end
- if param[:optional] && !["Tensor", "Scalar"].include?(param[:type]) && !param[:type].start_with?("int[")
+ if param[:optional] && !["Tensor", "Scalar"].include?(param[:type]) && !param[:type].start_with?("int[") && !param[:type].start_with?("SymInt[")
type = "c10::optional<#{type}>"
end
"#{type} #{param[:name]}"
end
@@ -612,11 +633,15 @@
"ScalarList"
when /\ADimname\[\d*\]\z/
"DirnameList"
when /\Aint\[\d*\]\z/
"IntArrayRef"
+ when /\ASymInt\[\d*\]\z/
+ "SymIntArrayRef"
when "int"
"int64_t"
+ when "SymInt"
+ "c10::SymInt"
when "float"
"double"
when "str"
"std::string"
when "Scalar", "Dimname", "bool", "ScalarType", "Layout", "Device", "Generator", "MemoryFormat", "Storage"