lib/torch/tensor.rb in torch-rb-0.1.2 vs lib/torch/tensor.rb in torch-rb-0.1.3
- old
+ new
@@ -4,14 +4,14 @@
include Inspector
alias_method :requires_grad?, :requires_grad
def self.new(*size)
- if size.first.is_a?(Tensor)
+ if size.length == 1 && size.first.is_a?(Tensor)
size.first
else
- Torch.rand(*size)
+ Torch.empty(*size)
end
end
def dtype
dtype = ENUM_TO_DTYPE[_dtype]
@@ -29,10 +29,16 @@
def to_a
reshape_arr(_data, shape)
end
+ # TODO support dtype
+ def to(device, non_blocking: false, copy: false)
+ device = Device.new(device) if device.is_a?(String)
+ _to(device, _dtype, non_blocking, copy)
+ end
+
def size(dim = nil)
if dim
_size(dim)
else
shape
@@ -52,10 +58,15 @@
raise Error, "only one element tensors can be converted to Ruby scalars"
end
_data.first
end
+ # unsure if this is correct
+ def new
+ Torch.empty(0, dtype: dtype)
+ end
+
def backward(gradient = nil)
if gradient
_backward_gradient(gradient)
else
_backward
@@ -82,12 +93,29 @@
enum = DTYPE_TO_ENUM[dtype]
raise Error, "Unknown type: #{dtype}" unless enum
_type(enum)
end
+ def add!(value = 1, other)
+ if other.is_a?(Numeric)
+ _add_scalar!(other * value)
+ else
+ # need to use alpha for sparse tensors instead of multiplying
+ _add_alpha!(other, value)
+ end
+ end
+
+ def mul!(other)
+ if other.is_a?(Numeric)
+ _mul_scalar!(other)
+ else
+ _mul!(other)
+ end
+ end
+
# operations
- %w(add sub mul div remainder pow neg sum mean num norm min max dot matmul exp log unsqueeze reshape argmax eq).each do |op|
+ %w(abs add argmax div dot eq exp gt log lt matmul max mean min mul neg norm num numel pow remainder reshape sign sqrt sub sum unsqueeze).each do |op|
define_method(op) do |*args, **options, &block|
if options.any?
Torch.send(op, self, *args, **options, &block)
else
Torch.send(op, self, *args, &block)
@@ -125,14 +153,15 @@
def <=>(other)
item <=> other
end
+ # based on python_variable_indexing.cpp
def [](*indexes)
result = self
dim = 0
- indexes.each_with_index do |index|
+ indexes.each do |index|
if index.is_a?(Numeric)
result = result._select(dim, index)
elsif index.is_a?(Range)
finish = index.end
finish += 1 unless index.exclude_end?
@@ -142,9 +171,14 @@
raise Error, "Unsupported index type"
end
end
result
end
+
+ # TODO
+ # based on python_variable_indexing.cpp
+ # def []=(index, value)
+ # end
private
def reshape_arr(arr, dims)
if dims.empty?