lib/torch/tensor.rb in torch-rb-0.3.6 vs lib/torch/tensor.rb in torch-rb-0.3.7
- old
+ new
@@ -46,10 +46,15 @@
arr.to_a
end
end
def to(device = nil, dtype: nil, non_blocking: false, copy: false)
+ if device.is_a?(Symbol) && !dtype
+ dtype = device
+ device = nil
+ end
+
device ||= self.device
device = Device.new(device) if device.is_a?(String)
dtype ||= self.dtype
enum = DTYPE_TO_ENUM[dtype]
@@ -72,14 +77,10 @@
else
shape
end
end
- def shape
- dim.times.map { |i| size(i) }
- end
-
# mirror Python len()
def length
size(0)
end
@@ -117,12 +118,17 @@
def requires_grad!(requires_grad = true)
_requires_grad!(requires_grad)
end
def type(dtype)
- enum = DTYPE_TO_ENUM[dtype]
- raise Error, "Unknown type: #{dtype}" unless enum
- _type(enum)
+ if dtype.is_a?(Class)
+ raise Error, "Invalid type: #{dtype}" unless TENSOR_TYPE_CLASSES.include?(dtype)
+ dtype.new(self)
+ else
+ enum = DTYPE_TO_ENUM[dtype]
+ raise Error, "Invalid type: #{dtype}" unless enum
+ _type(enum)
+ end
end
def reshape(*size)
# Python doesn't check if size == 1, just ignores later arguments
size = size.first if size.size == 1 && size.first.is_a?(Array)