lib/torch/tensor.rb in torch-rb-0.2.7 vs lib/torch/tensor.rb in torch-rb-0.3.0

- old
+ new

@@ -45,14 +45,19 @@ end arr.to_a end end - # TODO support dtype - def to(device, non_blocking: false, copy: false) + def to(device = nil, dtype: nil, non_blocking: false, copy: false) + device ||= self.device device = Device.new(device) if device.is_a?(String) - _to(device, _dtype, non_blocking, copy) + + dtype ||= self.dtype + enum = DTYPE_TO_ENUM[dtype] + raise Error, "Unknown type: #{dtype}" unless enum + + _to(device, enum, non_blocking, copy) end def cpu to("cpu") end @@ -213,15 +218,15 @@ raise ArgumentError, "Tensor does not support deleting items" if value.nil? value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor) if index.is_a?(Numeric) - copy_to(_select_int(0, index), value) + index_put!([Torch.tensor(index)], value) elsif index.is_a?(Range) finish = index.end finish += 1 unless index.exclude_end? - copy_to(_slice_tensor(0, index.begin, finish, 1), value) + _slice_tensor(0, index.begin, finish, 1).copy!(value) elsif index.is_a?(Tensor) index_put!([index], value) else raise Error, "Unsupported index type: #{index.class.name}" end @@ -251,14 +256,8 @@ end def clamp!(min, max) _clamp_min_(min) _clamp_max_(max) - end - - private - - def copy_to(dst, src) - dst.copy!(src) end end end