lib/torch/tensor.rb in torch-rb-0.2.4 vs lib/torch/tensor.rb in torch-rb-0.2.5
- old
+ new
@@ -23,12 +23,21 @@
def to_s
inspect
end
+ # TODO make more performant
def to_a
- reshape_arr(_flat_data, shape)
+ arr = _flat_data
+ if shape.empty?
+ arr
+ else
+ shape[1..-1].reverse.each do |dim|
+ arr = arr.each_slice(dim)
+ end
+ arr.to_a
+ end
end
# TODO support dtype
def to(device, non_blocking: false, copy: false)
device = Device.new(device) if device.is_a?(String)
@@ -62,11 +71,11 @@
def item
if numel != 1
raise Error, "only one element tensors can be converted to Ruby scalars"
end
- _flat_data.first
+ to_a.first
end
def to_i
item.to_i
end
@@ -86,11 +95,11 @@
# TODO read directly from memory
def numo
cls = Torch._dtype_to_numo[dtype]
raise Error, "Cannot convert #{dtype} to Numo" unless cls
- cls.cast(_flat_data).reshape(*shape)
+ cls.from_string(_data_str).reshape(*shape)
end
def new_ones(*size, **options)
Torch.ones_like(Torch.empty(*size), **options)
end
@@ -114,19 +123,10 @@
def view(*size)
size = size.first if size.size == 1 && size.first.is_a?(Array)
_view(size)
end
- # value and other are swapped for some methods
- def add!(value = 1, other)
- if other.is_a?(Numeric)
- _add__scalar(other, value)
- else
- _add__tensor(other, value)
- end
- end
-
def +(other)
add(other)
end
def -(other)
@@ -199,10 +199,21 @@
else
raise Error, "Unsupported index type: #{index.class.name}"
end
end
+ # native functions that need manually defined
+
+ # value and other are swapped for some methods
+ def add!(value = 1, other)
+ if other.is_a?(Numeric)
+ _add__scalar(other, value)
+ else
+ _add__tensor(other, value)
+ end
+ end
+
# native functions overlap, so need to handle manually
def random!(*args)
case args.size
when 1
_random__to(*args)
@@ -215,20 +226,8 @@
private
def copy_to(dst, src)
dst.copy!(src)
- end
-
- def reshape_arr(arr, dims)
- if dims.empty?
- arr
- else
- arr = arr.flatten
- dims[1..-1].reverse.each do |dim|
- arr = arr.each_slice(dim)
- end
- arr.to_a
- end
end
end
end