lib/torch/tensor.rb in torch-rb-0.2.4 vs lib/torch/tensor.rb in torch-rb-0.2.5

- old
+ new

@@ -23,12 +23,21 @@ def to_s inspect end + # TODO make more performant def to_a - reshape_arr(_flat_data, shape) + arr = _flat_data + if shape.empty? + arr + else + shape[1..-1].reverse.each do |dim| + arr = arr.each_slice(dim) + end + arr.to_a + end end # TODO support dtype def to(device, non_blocking: false, copy: false) device = Device.new(device) if device.is_a?(String) @@ -62,11 +71,11 @@ def item if numel != 1 raise Error, "only one element tensors can be converted to Ruby scalars" end - _flat_data.first + to_a.first end def to_i item.to_i end @@ -86,11 +95,11 @@ # TODO read directly from memory def numo cls = Torch._dtype_to_numo[dtype] raise Error, "Cannot convert #{dtype} to Numo" unless cls - cls.cast(_flat_data).reshape(*shape) + cls.from_string(_data_str).reshape(*shape) end def new_ones(*size, **options) Torch.ones_like(Torch.empty(*size), **options) end @@ -114,19 +123,10 @@ def view(*size) size = size.first if size.size == 1 && size.first.is_a?(Array) _view(size) end - # value and other are swapped for some methods - def add!(value = 1, other) - if other.is_a?(Numeric) - _add__scalar(other, value) - else - _add__tensor(other, value) - end - end - def +(other) add(other) end def -(other) @@ -199,10 +199,21 @@ else raise Error, "Unsupported index type: #{index.class.name}" end end + # native functions that need manually defined + + # value and other are swapped for some methods + def add!(value = 1, other) + if other.is_a?(Numeric) + _add__scalar(other, value) + else + _add__tensor(other, value) + end + end + # native functions overlap, so need to handle manually def random!(*args) case args.size when 1 _random__to(*args) @@ -215,20 +226,8 @@ private def copy_to(dst, src) dst.copy!(src) - end - - def reshape_arr(arr, dims) - if dims.empty? - arr - else - arr = arr.flatten - dims[1..-1].reverse.each do |dim| - arr = arr.each_slice(dim) - end - arr.to_a - end end end end