lib/torch/tensor.rb in torch-rb-0.2.7 vs lib/torch/tensor.rb in torch-rb-0.3.0
- old
+ new
@@ -45,14 +45,19 @@
end
arr.to_a
end
end
- # TODO support dtype
- def to(device, non_blocking: false, copy: false)
+ def to(device = nil, dtype: nil, non_blocking: false, copy: false)
+ device ||= self.device
device = Device.new(device) if device.is_a?(String)
- _to(device, _dtype, non_blocking, copy)
+
+ dtype ||= self.dtype
+ enum = DTYPE_TO_ENUM[dtype]
+ raise Error, "Unknown type: #{dtype}" unless enum
+
+ _to(device, enum, non_blocking, copy)
end
def cpu
to("cpu")
end
@@ -213,15 +218,15 @@
raise ArgumentError, "Tensor does not support deleting items" if value.nil?
value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
if index.is_a?(Numeric)
- copy_to(_select_int(0, index), value)
+ index_put!([Torch.tensor(index)], value)
elsif index.is_a?(Range)
finish = index.end
finish += 1 unless index.exclude_end?
- copy_to(_slice_tensor(0, index.begin, finish, 1), value)
+ _slice_tensor(0, index.begin, finish, 1).copy!(value)
elsif index.is_a?(Tensor)
index_put!([index], value)
else
raise Error, "Unsupported index type: #{index.class.name}"
end
@@ -251,14 +256,8 @@
end
def clamp!(min, max)
_clamp_min_(min)
_clamp_max_(max)
- end
-
- private
-
- def copy_to(dst, src)
- dst.copy!(src)
end
end
end