lib/torch/tensor.rb in torch-rb-0.1.1 vs lib/torch/tensor.rb in torch-rb-0.1.2

- old
+ new

@@ -26,11 +26,11 @@ def to_s inspect end def to_a - reshape(_data, shape) + reshape_arr(_data, shape) end def size(dim = nil) if dim _size(dim) @@ -52,12 +52,16 @@ raise Error, "only one element tensors can be converted to Ruby scalars" end _data.first end - def data - Torch.tensor(to_a) + def backward(gradient = nil) + if gradient + _backward_gradient(gradient) + else + _backward + end end # TODO read directly from memory def numo raise Error, "Numo not found" unless defined?(Numo::NArray) @@ -72,12 +76,18 @@ def requires_grad!(requires_grad = true) _requires_grad!(requires_grad) end + def type(dtype) + enum = DTYPE_TO_ENUM[dtype] + raise Error, "Unknown type: #{dtype}" unless enum + _type(enum) + end + # operations - %w(add sub mul div remainder pow neg sum mean num norm min max dot matmul exp log unsqueeze).each do |op| + %w(add sub mul div remainder pow neg sum mean num norm min max dot matmul exp log unsqueeze reshape argmax eq).each do |op| define_method(op) do |*args, **options, &block| if options.any? Torch.send(op, self, *args, **options, &block) else Torch.send(op, self, *args, &block) @@ -115,21 +125,30 @@ def <=>(other) item <=> other end - # TODO use accessor C++ method - def [](index, *args) - v = _access(index) - args.each do |i| - v = v._access(i) + def [](*indexes) + result = self + dim = 0 + indexes.each_with_index do |index| + if index.is_a?(Numeric) + result = result._select(dim, index) + elsif index.is_a?(Range) + finish = index.end + finish += 1 unless index.exclude_end? + result = result._slice(dim, index.begin, finish, 1) + dim += 1 + else + raise Error, "Unsupported index type" + end end - v + result end private - def reshape(arr, dims) + def reshape_arr(arr, dims) if dims.empty? arr else arr = arr.flatten dims[1..-1].reverse.each do |dim|