codegen/generate_functions.rb in torch-rb-0.8.3 vs codegen/generate_functions.rb in torch-rb-0.9.0
- old
+ new
@@ -26,10 +26,13 @@
(f.base_name.start_with?("_") && f.base_name != "__lshift__" && f.base_name != "__rshift__") ||
f.base_name.include?("_backward") ||
f.base_name.include?("_forward") ||
f.base_name == "to" ||
f.base_name == "record_stream" ||
+ f.base_name == "is_pinned" ||
+ f.base_name == "pin_memory" ||
+ f.base_name == "fused_moving_avg_obs_fake_quant" ||
# in ext.cpp
f.base_name == "index" ||
f.base_name == "index_put_" ||
# need to add to ext.cpp
f.base_name == "index_put" ||
@@ -385,10 +388,12 @@
else
"optionalTensor"
end
when "generator", "tensorlist", "intlist"
func
+ when "string"
+ "stringViewOptional"
else
"#{func}Optional"
end
end
@@ -422,13 +427,11 @@
when "Tensor"
if param[:optional]
if function.out?
"const Tensor &"
else
- # TODO
- # "const c10::optional<at::Tensor> &"
- "const OptionalTensor &"
+ "const c10::optional<at::Tensor> &"
end
elsif param[:modifier]
if param[:modifier].include?("!") && function.retvals.size > 1
"Tensor &"
else
@@ -448,10 +451,14 @@
when /\Aint\[/
"IntArrayRef"
when "float[]"
"ArrayRef<double>"
when "str"
- "std::string"
+ if param[:optional]
+ "c10::string_view"
+ else
+ "std::string"
+ end
when "Scalar", "bool", "ScalarType", "Layout", "Device", "Storage", "Generator", "MemoryFormat", "Storage"
param[:type]
else
raise "Unknown type: #{param[:type]} (#{function.name})"
end