lib/torch/tensor.rb in torch-rb-0.1.3 vs lib/torch/tensor.rb in torch-rb-0.1.4
- old
+ new
@@ -64,20 +64,15 @@
def new
Torch.empty(0, dtype: dtype)
end
def backward(gradient = nil)
- if gradient
- _backward_gradient(gradient)
- else
- _backward
- end
+ _backward(gradient)
end
# TODO read directly from memory
def numo
- raise Error, "Numo not found" unless defined?(Numo::NArray)
cls = Torch._dtype_to_numo[dtype]
raise Error, "Cannot convert #{dtype} to Numo" unless cls
cls.cast(_data).reshape(*shape)
end
@@ -111,11 +106,11 @@
_mul!(other)
end
end
# operations
- %w(abs add argmax div dot eq exp gt log lt matmul max mean min mul neg norm num numel pow remainder reshape sign sqrt sub sum unsqueeze).each do |op|
+ %w(abs add argmax div dot eq exp gt log log_softmax lt matmul max mean min mul neg norm num numel pow relu remainder reshape sign softmax sqrt sub sum unsqueeze topk).each do |op|
define_method(op) do |*args, **options, &block|
if options.any?
Torch.send(op, self, *args, **options, &block)
else
Torch.send(op, self, *args, &block)
@@ -165,22 +160,45 @@
elsif index.is_a?(Range)
finish = index.end
finish += 1 unless index.exclude_end?
result = result._slice(dim, index.begin, finish, 1)
dim += 1
+ elsif index.nil?
+ result = result.unsqueeze(dim)
+ dim += 1
+ elsif index == true
+ result = result.unsqueeze(dim)
+ # TODO handle false
else
- raise Error, "Unsupported index type"
+ raise Error, "Unsupported index type: #{index.class.name}"
end
end
result
end
# TODO
# based on python_variable_indexing.cpp
- # def []=(index, value)
- # end
+ def []=(index, value)
+ raise ArgumentError, "Tensor does not support deleting items" if value.nil?
+ value = Torch.tensor(value) unless value.is_a?(Tensor)
+
+ if index.is_a?(Numeric)
+ copy_to(_select(0, index), value)
+ elsif index.is_a?(Range)
+ finish = index.end
+ finish += 1 unless index.exclude_end?
+ copy_to(_slice(0, index.begin, finish, 1), value)
+ else
+ raise Error, "Unsupported index type: #{index.class.name}"
+ end
+ end
+
private
+
+ def copy_to(dst, src)
+ dst.copy!(src)
+ end
def reshape_arr(arr, dims)
if dims.empty?
arr
else