lib/torch/tensor.rb in torch-rb-0.17.0 vs lib/torch/tensor.rb in torch-rb-0.17.1

- old
+ new

@@ -130,12 +130,16 @@ Torch.empty(0, dtype: dtype) end # TODO read directly from memory def numo - cls = Torch._dtype_to_numo[dtype] - raise Error, "Cannot convert #{dtype} to Numo" unless cls - cls.from_string(_data_str).reshape(*shape) + if dtype == :bool + Numo::UInt8.from_string(_data_str).ne(0).reshape(*shape) + else + cls = Torch._dtype_to_numo[dtype] + raise Error, "Cannot convert #{dtype} to Numo" unless cls + cls.from_string(_data_str).reshape(*shape) + end end def requires_grad=(requires_grad) _requires_grad!(requires_grad) end