codegen/generate_functions.rb in torch-rb-0.4.2 vs codegen/generate_functions.rb in torch-rb-0.5.0
- old
+ new
@@ -326,10 +326,12 @@
"tensor"
when "Tensor[]"
"tensorlist"
when /\Aint\[/
"intlist"
+ when "float[]"
+ "doublelist"
when "Scalar"
"scalar"
when "bool"
"toBool"
when "int"
@@ -417,10 +419,12 @@
"int64_t"
when "float"
"double"
when /\Aint\[/
"IntArrayRef"
+ when "float[]"
+ "ArrayRef<double>"
when "str"
"std::string"
when "Scalar", "bool", "ScalarType", "Layout", "Device", "Storage", "Generator", "MemoryFormat", "Storage"
param[:type]
else
@@ -464,11 +468,13 @@
when ["Tensor", "Tensor", "Tensor", "Tensor"]
"std::tuple<Tensor,Tensor,Tensor,Tensor>"
when ["Tensor", "Tensor", "Tensor", "Tensor", "Tensor"]
"std::tuple<Tensor,Tensor,Tensor,Tensor,Tensor>"
when ["Tensor", "Tensor", "float", "int"]
- "std::tuple<Tensor,Tensor,float,int>"
+ "std::tuple<Tensor,Tensor,double,int>"
+ when ["float", "float"]
+ "std::tuple<double,double>"
else
raise "Unknown retvals: #{types}"
end
end
@@ -537,9 +543,11 @@
"double"
when "str"
"std::string"
when "Scalar", "Dimname", "bool", "ScalarType", "Layout", "Device", "Generator", "MemoryFormat", "Storage"
param[:type]
+ when "float[]"
+ "ArrayRef<double>"
else
raise "Unknown type: #{param[:type]}"
end
type += "[#{param[:list_size]}]" if param[:list_size]