lib/torch/tensor.rb in torch-rb-0.17.0 vs lib/torch/tensor.rb in torch-rb-0.17.1
- old
+ new
@@ -130,12 +130,16 @@
Torch.empty(0, dtype: dtype)
end
# TODO read directly from memory
def numo
- cls = Torch._dtype_to_numo[dtype]
- raise Error, "Cannot convert #{dtype} to Numo" unless cls
- cls.from_string(_data_str).reshape(*shape)
+ if dtype == :bool
+ Numo::UInt8.from_string(_data_str).ne(0).reshape(*shape)
+ else
+ cls = Torch._dtype_to_numo[dtype]
+ raise Error, "Cannot convert #{dtype} to Numo" unless cls
+ cls.from_string(_data_str).reshape(*shape)
+ end
end
def requires_grad=(requires_grad)
_requires_grad!(requires_grad)
end