codegen/generate_functions.rb in torch-rb-0.10.2 vs codegen/generate_functions.rb in torch-rb-0.11.0

- old
+ new

@@ -37,11 +37,18 @@ f.base_name == "index_put_" || # need to add to ext.cpp f.base_name == "index_put" || # not supported yet f.func.include?("Dimname") || - f.func.include?("ConstQuantizerPtr") + f.func.include?("ConstQuantizerPtr") || + f.func.include?("SymInt") || + # TODO fix LibTorch 1.12 changes + f.base_name == "histogramdd" || + f.base_name == "nested_tensor" || + f.base_name == "split_copy" || + f.base_name == "split_with_sizes_copy" || + f.base_name == "unbind_copy" end end def group_functions(functions) nn_functions, other_functions = functions.partition { |f| f.python_module == "nn" } @@ -248,11 +255,11 @@ params, opt_params = split_opt_params(params) opt_index = opt_params.map { |v| v[:position] }.min if opt_params.any? cpp_params = generate_dispatch_params(function, params) if opt_index - cpp_params.insert(remove_self ? opt_index + 1 : opt_index, "const TensorOptions & options") + cpp_params.insert(remove_self ? opt_index + 1 : opt_index, "TensorOptions options") end retval = generate_dispatch_retval(function) dispatch_code = generate_dispatch_code(function, def_method, params, opt_index, remove_self) function_code = generate_function_code(function, cpp_name, params, opt_index, remove_self) @@ -408,11 +415,11 @@ if function.out? "tensor" else "optionalTensor" end - when "generator", "tensorlist", "intlist" + when "generator", "tensorlist" func when "string" "stringViewOptional" else "#{func}Optional" @@ -469,26 +476,36 @@ when "int" "int64_t" when "float" "double" when /\Aint\[/ - "IntArrayRef" + if param[:optional] + "at::OptionalIntArrayRef" + else + "IntArrayRef" + end when "float[]" "ArrayRef<double>" when "str" if param[:optional] "c10::string_view" else "std::string" end - when "Scalar", "bool", "ScalarType", "Layout", "Device", "Storage", "Generator", "MemoryFormat", "Storage" + when "Scalar" + if param[:optional] + "const c10::optional<Scalar> &" + else + "const Scalar &" + end + when "bool", "ScalarType", "Layout", "Device", "Storage", "Generator", "MemoryFormat", "Storage" param[:type] else raise "Unknown type: #{param[:type]} (#{function.name})" end - if param[:optional] && param[:type] != "Tensor" + if param[:optional] && !["Tensor", "Scalar"].include?(param[:type]) && !param[:type].start_with?("int[") type = "c10::optional<#{type}>" end "#{type} #{param[:name]}" end @@ -527,10 +544,10 @@ when ["Tensor", "Tensor", "float", "int"] "std::tuple<Tensor,Tensor,double,int>" when ["float", "float"] "std::tuple<double,double>" else - raise "Unknown retvals: #{types}" + raise "Unknown retvals: #{types} (#{function.name})" end end def generate_signature(function, type, skip_out: false) params = function.params.dup