codegen/generate_functions.rb in torch-rb-0.8.3 vs codegen/generate_functions.rb in torch-rb-0.9.0

- old
+ new

@@ -26,10 +26,13 @@ (f.base_name.start_with?("_") && f.base_name != "__lshift__" && f.base_name != "__rshift__") || f.base_name.include?("_backward") || f.base_name.include?("_forward") || f.base_name == "to" || f.base_name == "record_stream" || + f.base_name == "is_pinned" || + f.base_name == "pin_memory" || + f.base_name == "fused_moving_avg_obs_fake_quant" || # in ext.cpp f.base_name == "index" || f.base_name == "index_put_" || # need to add to ext.cpp f.base_name == "index_put" || @@ -385,10 +388,12 @@ else "optionalTensor" end when "generator", "tensorlist", "intlist" func + when "string" + "stringViewOptional" else "#{func}Optional" end end @@ -422,13 +427,11 @@ when "Tensor" if param[:optional] if function.out? "const Tensor &" else - # TODO - # "const c10::optional<at::Tensor> &" - "const OptionalTensor &" + "const c10::optional<at::Tensor> &" end elsif param[:modifier] if param[:modifier].include?("!") && function.retvals.size > 1 "Tensor &" else @@ -448,10 +451,14 @@ when /\Aint\[/ "IntArrayRef" when "float[]" "ArrayRef<double>" when "str" - "std::string" + if param[:optional] + "c10::string_view" + else + "std::string" + end when "Scalar", "bool", "ScalarType", "Layout", "Device", "Storage", "Generator", "MemoryFormat", "Storage" param[:type] else raise "Unknown type: #{param[:type]} (#{function.name})" end