lib/torch/tensor.rb in torch-rb-0.1.3 vs lib/torch/tensor.rb in torch-rb-0.1.4

- old
+ new

@@ -64,20 +64,15 @@ def new Torch.empty(0, dtype: dtype) end def backward(gradient = nil) - if gradient - _backward_gradient(gradient) - else - _backward - end + _backward(gradient) end # TODO read directly from memory def numo - raise Error, "Numo not found" unless defined?(Numo::NArray) cls = Torch._dtype_to_numo[dtype] raise Error, "Cannot convert #{dtype} to Numo" unless cls cls.cast(_data).reshape(*shape) end @@ -111,11 +106,11 @@ _mul!(other) end end # operations - %w(abs add argmax div dot eq exp gt log lt matmul max mean min mul neg norm num numel pow remainder reshape sign sqrt sub sum unsqueeze).each do |op| + %w(abs add argmax div dot eq exp gt log log_softmax lt matmul max mean min mul neg norm num numel pow relu remainder reshape sign softmax sqrt sub sum unsqueeze topk).each do |op| define_method(op) do |*args, **options, &block| if options.any? Torch.send(op, self, *args, **options, &block) else Torch.send(op, self, *args, &block) @@ -165,22 +160,45 @@ elsif index.is_a?(Range) finish = index.end finish += 1 unless index.exclude_end? result = result._slice(dim, index.begin, finish, 1) dim += 1 + elsif index.nil? + result = result.unsqueeze(dim) + dim += 1 + elsif index == true + result = result.unsqueeze(dim) + # TODO handle false else - raise Error, "Unsupported index type" + raise Error, "Unsupported index type: #{index.class.name}" end end result end # TODO # based on python_variable_indexing.cpp - # def []=(index, value) - # end + def []=(index, value) + raise ArgumentError, "Tensor does not support deleting items" if value.nil? + value = Torch.tensor(value) unless value.is_a?(Tensor) + + if index.is_a?(Numeric) + copy_to(_select(0, index), value) + elsif index.is_a?(Range) + finish = index.end + finish += 1 unless index.exclude_end? + copy_to(_slice(0, index.begin, finish, 1), value) + else + raise Error, "Unsupported index type: #{index.class.name}" + end + end + private + + def copy_to(dst, src) + dst.copy!(src) + end def reshape_arr(arr, dims) if dims.empty? arr else