codegen/generate_functions.rb in torch-rb-0.10.2 vs codegen/generate_functions.rb in torch-rb-0.11.0
- old
+ new
@@ -37,11 +37,18 @@
f.base_name == "index_put_" ||
# need to add to ext.cpp
f.base_name == "index_put" ||
# not supported yet
f.func.include?("Dimname") ||
- f.func.include?("ConstQuantizerPtr")
+ f.func.include?("ConstQuantizerPtr") ||
+ f.func.include?("SymInt") ||
+ # TODO fix LibTorch 1.12 changes
+ f.base_name == "histogramdd" ||
+ f.base_name == "nested_tensor" ||
+ f.base_name == "split_copy" ||
+ f.base_name == "split_with_sizes_copy" ||
+ f.base_name == "unbind_copy"
end
end
def group_functions(functions)
nn_functions, other_functions = functions.partition { |f| f.python_module == "nn" }
@@ -248,11 +255,11 @@
params, opt_params = split_opt_params(params)
opt_index = opt_params.map { |v| v[:position] }.min if opt_params.any?
cpp_params = generate_dispatch_params(function, params)
if opt_index
- cpp_params.insert(remove_self ? opt_index + 1 : opt_index, "const TensorOptions & options")
+ cpp_params.insert(remove_self ? opt_index + 1 : opt_index, "TensorOptions options")
end
retval = generate_dispatch_retval(function)
dispatch_code = generate_dispatch_code(function, def_method, params, opt_index, remove_self)
function_code = generate_function_code(function, cpp_name, params, opt_index, remove_self)
@@ -408,11 +415,11 @@
if function.out?
"tensor"
else
"optionalTensor"
end
- when "generator", "tensorlist", "intlist"
+ when "generator", "tensorlist"
func
when "string"
"stringViewOptional"
else
"#{func}Optional"
@@ -469,26 +476,36 @@
when "int"
"int64_t"
when "float"
"double"
when /\Aint\[/
- "IntArrayRef"
+ if param[:optional]
+ "at::OptionalIntArrayRef"
+ else
+ "IntArrayRef"
+ end
when "float[]"
"ArrayRef<double>"
when "str"
if param[:optional]
"c10::string_view"
else
"std::string"
end
- when "Scalar", "bool", "ScalarType", "Layout", "Device", "Storage", "Generator", "MemoryFormat", "Storage"
+ when "Scalar"
+ if param[:optional]
+ "const c10::optional<Scalar> &"
+ else
+ "const Scalar &"
+ end
+ when "bool", "ScalarType", "Layout", "Device", "Storage", "Generator", "MemoryFormat", "Storage"
param[:type]
else
raise "Unknown type: #{param[:type]} (#{function.name})"
end
- if param[:optional] && param[:type] != "Tensor"
+ if param[:optional] && !["Tensor", "Scalar"].include?(param[:type]) && !param[:type].start_with?("int[")
type = "c10::optional<#{type}>"
end
"#{type} #{param[:name]}"
end
@@ -527,10 +544,10 @@
when ["Tensor", "Tensor", "float", "int"]
"std::tuple<Tensor,Tensor,double,int>"
when ["float", "float"]
"std::tuple<double,double>"
else
- raise "Unknown retvals: #{types}"
+ raise "Unknown retvals: #{types} (#{function.name})"
end
end
def generate_signature(function, type, skip_out: false)
params = function.params.dup