lib/torch/inspector.rb in torch-rb-0.14.1 vs lib/torch/inspector.rb in torch-rb-0.15.0
- old
+ new
@@ -29,13 +29,13 @@
# no valid number, do nothing
return if nonzero_finite_vals.numel == 0
# Convert to double for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU.
- nonzero_finite_abs = nonzero_finite_vals.abs.double
- nonzero_finite_min = nonzero_finite_abs.min.double
- nonzero_finite_max = nonzero_finite_abs.max.double
+ nonzero_finite_abs = tensor_totype(nonzero_finite_vals.abs)
+ nonzero_finite_min = tensor_totype(nonzero_finite_abs.min)
+ nonzero_finite_max = tensor_totype(nonzero_finite_abs.max)
nonzero_finite_vals.each do |value|
if value.item != value.item.ceil
@int_mode = false
break
@@ -104,9 +104,14 @@
else
ret = value.to_s
end
# Ruby throws error when negative, Python doesn't
" " * [@max_width - ret.size, 0].max + ret
+ end
+
+ def tensor_totype(t)
+ dtype = t.mps? ? :float : :double
+ t.to(dtype: dtype)
end
end
def inspect
Torch.no_grad do