lib/torch.rb in torch-rb-0.6.0 vs lib/torch.rb in torch-rb-0.7.0
- old
+ new
@@ -236,12 +236,15 @@
float: 6,
float32: 6,
double: 7,
float64: 7,
complex_half: 8,
+ complex32: 8,
complex_float: 9,
+ complex64: 9,
complex_double: 10,
+ complex128: 10,
bool: 11,
qint8: 12,
quint8: 13,
qint32: 14,
bfloat16: 15
@@ -392,9 +395,11 @@
if options[:dtype].nil?
if data.all? { |v| v.is_a?(Integer) }
options[:dtype] = :int64
elsif data.all? { |v| v == true || v == false }
options[:dtype] = :bool
+ elsif data.any? { |v| v.is_a?(Complex) }
+ options[:dtype] = :complex64
end
end
_tensor(data, size, tensor_options(**options))
end