lib/torch/tensor.rb in torch-rb-0.3.6 vs lib/torch/tensor.rb in torch-rb-0.3.7

- old
+ new

@@ -46,10 +46,15 @@ arr.to_a end end def to(device = nil, dtype: nil, non_blocking: false, copy: false) + if device.is_a?(Symbol) && !dtype + dtype = device + device = nil + end + device ||= self.device device = Device.new(device) if device.is_a?(String) dtype ||= self.dtype enum = DTYPE_TO_ENUM[dtype] @@ -72,14 +77,10 @@ else shape end end - def shape - dim.times.map { |i| size(i) } - end - # mirror Python len() def length size(0) end @@ -117,12 +118,17 @@ def requires_grad!(requires_grad = true) _requires_grad!(requires_grad) end def type(dtype) - enum = DTYPE_TO_ENUM[dtype] - raise Error, "Unknown type: #{dtype}" unless enum - _type(enum) + if dtype.is_a?(Class) + raise Error, "Invalid type: #{dtype}" unless TENSOR_TYPE_CLASSES.include?(dtype) + dtype.new(self) + else + enum = DTYPE_TO_ENUM[dtype] + raise Error, "Invalid type: #{dtype}" unless enum + _type(enum) + end end def reshape(*size) # Python doesn't check if size == 1, just ignores later arguments size = size.first if size.size == 1 && size.first.is_a?(Array)