lib/torch/tensor.rb in torch-rb-0.1.1 vs lib/torch/tensor.rb in torch-rb-0.1.2
- old
+ new
@@ -26,11 +26,11 @@
# Human-readable form of the tensor. Delegates to #inspect so that
# string conversion and inspection render identically.
def to_s
  self.inspect
end
# Converts the tensor to a nested Ruby Array mirroring its shape.
# Pulls the flat element list from the native binding, then lets the
# private reshape_arr helper rebuild the nesting from #shape.
def to_a
  flat = _data
  reshape_arr(flat, shape)
end
def size(dim = nil)
if dim
_size(dim)
@@ -52,12 +52,16 @@
raise Error, "only one element tensors can be converted to Ruby scalars"
end
_data.first
end
# Runs autograd back-propagation from this tensor.
#
# gradient - optional tensor of upstream gradients; when given, the
#            native gradient-aware entry point is used, otherwise the
#            plain scalar backward is invoked.
def backward(gradient = nil)
  gradient ? _backward_gradient(gradient) : _backward
end
# TODO read directly from memory
def numo
raise Error, "Numo not found" unless defined?(Numo::NArray)
@@ -72,12 +76,18 @@
# Toggles gradient tracking on this tensor in place.
#
# requires_grad - true (default) to start recording operations for
#                 autograd, false to stop. Forwards straight to the
#                 native binding; return value is whatever it yields.
def requires_grad!(requires_grad = true)
  _requires_grad!(requires_grad)
end
# Casts the tensor to another dtype.
#
# dtype - a key of DTYPE_TO_ENUM (e.g. :float64). Raises Torch::Error
#         when the dtype has no native enum mapping.
def type(dtype)
  enum = DTYPE_TO_ENUM[dtype]
  if enum
    _type(enum)
  else
    raise Error, "Unknown type: #{dtype}"
  end
end
# operations
- %w(add sub mul div remainder pow neg sum mean num norm min max dot matmul exp log unsqueeze).each do |op|
+ %w(add sub mul div remainder pow neg sum mean num norm min max dot matmul exp log unsqueeze reshape argmax eq).each do |op|
define_method(op) do |*args, **options, &block|
if options.any?
Torch.send(op, self, *args, **options, &block)
else
Torch.send(op, self, *args, &block)
@@ -115,21 +125,30 @@
# Orders tensors by their scalar value via #item, enabling Comparable
# semantics against numbers. Note #item raises for multi-element
# tensors, so this is only meaningful for one-element tensors.
def <=>(other)
  value = item
  value <=> other
end
# Indexes the tensor, NumPy-style, with integers and ranges.
#
# A Numeric index selects along the current dimension (the dimension
# is removed, so `dim` is NOT advanced); a Range slices it (the
# dimension survives, so `dim` advances). Inclusive range ends are
# converted to the exclusive bound the native _slice expects.
# Any other index type raises Torch::Error.
#
# NOTE(review): an endless range (e.g. t[0..]) would make `index.end`
# nil and crash on `finish += 1` — unsupported here, confirm callers
# never pass one.
#
# Fix vs. original: the loop used `each_with_index` while binding only
# one block parameter, silently discarding the yielded position and
# misleadingly shadowing the real dimension cursor `dim`. Plain #each
# expresses the intent; behavior is unchanged.
def [](*indexes)
  result = self
  dim = 0
  indexes.each do |index|
    if index.is_a?(Numeric)
      # _select drops the indexed dimension, so dim stays put.
      result = result._select(dim, index)
    elsif index.is_a?(Range)
      finish = index.end
      finish += 1 unless index.exclude_end?
      result = result._slice(dim, index.begin, finish, 1)
      dim += 1
    else
      raise Error, "Unsupported index type"
    end
  end
  result
end
private
- def reshape(arr, dims)
+ def reshape_arr(arr, dims)
if dims.empty?
arr
else
arr = arr.flatten
dims[1..-1].reverse.each do |dim|