codegen/generate_functions.rb in torch-rb-0.4.2 vs codegen/generate_functions.rb in torch-rb-0.5.0

- old
+ new

@@ -326,10 +326,12 @@ "tensor" when "Tensor[]" "tensorlist" when /\Aint\[/ "intlist" + when "float[]" + "doublelist" when "Scalar" "scalar" when "bool" "toBool" when "int" @@ -417,10 +419,12 @@ "int64_t" when "float" "double" when /\Aint\[/ "IntArrayRef" + when "float[]" + "ArrayRef<double>" when "str" "std::string" when "Scalar", "bool", "ScalarType", "Layout", "Device", "Storage", "Generator", "MemoryFormat", "Storage" param[:type] else @@ -464,11 +468,13 @@ when ["Tensor", "Tensor", "Tensor", "Tensor"] "std::tuple<Tensor,Tensor,Tensor,Tensor>" when ["Tensor", "Tensor", "Tensor", "Tensor", "Tensor"] "std::tuple<Tensor,Tensor,Tensor,Tensor,Tensor>" when ["Tensor", "Tensor", "float", "int"] - "std::tuple<Tensor,Tensor,float,int>" + "std::tuple<Tensor,Tensor,double,int>" + when ["float", "float"] + "std::tuple<double,double>" else raise "Unknown retvals: #{types}" end end @@ -537,9 +543,11 @@ "double" when "str" "std::string" when "Scalar", "Dimname", "bool", "ScalarType", "Layout", "Device", "Generator", "MemoryFormat", "Storage" param[:type] + when "float[]" + "ArrayRef<double>" else raise "Unknown type: #{param[:type]}" end type += "[#{param[:list_size]}]" if param[:list_size]