lib/torch/tensor.rb in torch-rb-0.1.4 vs lib/torch/tensor.rb in torch-rb-0.1.5
- removed (torch-rb 0.1.4)
+ added (torch-rb 0.1.5)
@@ -3,16 +3,12 @@
include Comparable
include Inspector
alias_method :requires_grad?, :requires_grad
- def self.new(*size)
- if size.length == 1 && size.first.is_a?(Tensor)
- size.first
- else
- Torch.empty(*size)
- end
+ def self.new(*args)
+ FloatTensor.new(*args)
end
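Usage sketch (not part of the diff): under 0.1.5, Tensor.new delegates to FloatTensor, mirroring PyTorch's legacy torch.Tensor constructor; the 0.1.4 behavior of unwrapping a tensor argument or calling Torch.empty is gone. Assuming FloatTensor.new(*sizes) allocates uninitialized float32 storage:
t = Torch::Tensor.new(2, 3)  # same as Torch::FloatTensor.new(2, 3)
t.dtype                      # => :float32
t.shape                      # => [2, 3] (contents are uninitialized)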
def dtype
dtype = ENUM_TO_DTYPE[_dtype]
raise Error, "Unknown type: #{_dtype}" unless dtype
@@ -26,22 +22,22 @@
def to_s
inspect
end
def to_a
- reshape_arr(_data, shape)
+ reshape_arr(_flat_data, shape)
end
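The _data to _flat_data rename is internal; to_a still nests the flat element list back into the tensor's shape via reshape_arr:
Torch.tensor([[1, 2], [3, 4]]).to_a # => [[1, 2], [3, 4]]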
# TODO support dtype
def to(device, non_blocking: false, copy: false)
device = Device.new(device) if device.is_a?(String)
_to(device, _dtype, non_blocking, copy)
end
def size(dim = nil)
if dim
- _size(dim)
+ _size_int(dim)
else
shape
end
end
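size keeps its public contract; only the native binding moves to the type-suffixed _size_int. For example:
t = Torch.empty(2, 3)
t.size(0) # => 2 (dispatches to _size_int)
t.size    # => [2, 3] (no dim given, falls back to shape)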
@@ -55,11 +51,11 @@
def item
if numel != 1
raise Error, "only one element tensors can be converted to Ruby scalars"
end
- _data.first
+ _flat_data.first
end
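item is unchanged apart from the flat-data accessor: it returns the sole element as a Ruby scalar and rejects anything larger:
Torch.tensor([42]).item    # => 42
# Torch.tensor([1, 2]).item raises Torch::Error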
# unsure if this is correct
def new
Torch.empty(0, dtype: dtype)
@@ -71,11 +67,11 @@
# TODO read directly from memory
def numo
cls = Torch._dtype_to_numo[dtype]
raise Error, "Cannot convert #{dtype} to Numo" unless cls
- cls.cast(_data).reshape(*shape)
+ cls.cast(_flat_data).reshape(*shape)
end
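numo still round-trips through a flat Ruby array (hence the TODO above about reading memory directly). A sketch, assuming numo-narray is installed and float32 maps to Numo::SFloat:
require "numo/narray"
n = Torch.tensor([[1.0, 2.0], [3.0, 4.0]]).numo
n.class # => Numo::SFloat
n.shape # => [2, 2]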
def new_ones(*size, **options)
Torch.ones_like(Torch.empty(*size), **options)
end
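new_ones borrows the shape from a throwaway empty tensor and forwards options to ones_like; note the receiver's own dtype is not consulted. A sketch, assuming ones_like accepts a dtype option:
t = Torch.empty(2)
t.new_ones(2, 3)                 # 2x3 tensor of ones
t.new_ones(2, 3, dtype: :int64)  # options pass through to ones_like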
@@ -88,38 +84,42 @@
enum = DTYPE_TO_ENUM[dtype]
raise Error, "Unknown type: #{dtype}" unless enum
_type(enum)
end
+ # start temp operations
+
def add!(value = 1, other)
if other.is_a?(Numeric)
- _add_scalar!(other * value)
+ _add__scalar(other, value)
else
# need to use alpha for sparse tensors instead of multiplying
- _add_alpha!(other, value)
+ _add__tensor(other, value)
end
end
def mul!(other)
if other.is_a?(Numeric)
- _mul_scalar!(other)
+ _mul__scalar(other)
else
- _mul!(other)
+ _mul__tensor(other)
end
end
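A sketch of the in-place methods above (flagged temporary, so this surface may change): add! keeps PyTorch's legacy (value, other) ordering, where value scales other, and the scalar/tensor split now resolves to distinct generated natives:
t = Torch.tensor([1.0, 2.0])
t.add!(10)                           # scalar path => [11.0, 12.0]
t.add!(2, Torch.tensor([1.0, 1.0])) # tensor path, alpha = 2 => [13.0, 14.0]
t.mul!(2)                            # scalar multiply in place => [26.0, 28.0]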
# operations
- %w(abs add argmax div dot eq exp gt log log_softmax lt matmul max mean min mul neg norm num numel pow relu remainder reshape sign softmax sqrt sub sum unsqueeze topk).each do |op|
+ %w(log_softmax mean softmax sum topk).each do |op|
define_method(op) do |*args, **options, &block|
if options.any?
Torch.send(op, self, *args, **options, &block)
else
Torch.send(op, self, *args, &block)
end
end
end
+ # end temp operations
+
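The trimmed list (down from roughly 30 ops, most now bound natively elsewhere) still just forwards to the module-level functions, so instance and module calls stay interchangeable. A sketch, assuming these module signatures:
t = Torch.rand(2, 3)
t.sum            # same as Torch.sum(t)
t.log_softmax(1) # same as Torch.log_softmax(t, 1)
t.topk(2)        # same as Torch.topk(t, 2)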
def +(other)
add(other)
end
def -(other)
@@ -154,15 +154,15 @@
def [](*indexes)
result = self
dim = 0
indexes.each do |index|
if index.is_a?(Numeric)
- result = result._select(dim, index)
+ result = result._select_int(dim, index)
elsif index.is_a?(Range)
finish = index.end
finish += 1 unless index.exclude_end?
- result = result._slice(dim, index.begin, finish, 1)
+ result = result._slice_tensor(dim, index.begin, finish, 1)
dim += 1
elsif index.nil?
result = result.unsqueeze(dim)
dim += 1
elsif index == true
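Indexing usage implied by the hunk above (unchanged at the Ruby level; only the native calls gain type suffixes):
t = Torch.tensor([[1, 2, 3], [4, 5, 6]])
t[0].to_a    # => [1, 2, 3]  (Numeric -> _select_int on dim 0)
t[0, 2].item # => 3          (each Numeric consumes one dim)
t[0..1]      # Range -> _slice_tensor; inclusive end bumped to exclusive
t[nil].shape # => [1, 2, 3]  (nil unsqueezes a new leading dim)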
@@ -181,14 +181,14 @@
raise ArgumentError, "Tensor does not support deleting items" if value.nil?
value = Torch.tensor(value) unless value.is_a?(Tensor)
if index.is_a?(Numeric)
- copy_to(_select(0, index), value)
+ copy_to(_select_int(0, index), value)
elsif index.is_a?(Range)
finish = index.end
finish += 1 unless index.exclude_end?
- copy_to(_slice(0, index.begin, finish, 1), value)
+ copy_to(_slice_tensor(0, index.begin, finish, 1), value)
else
raise Error, "Unsupported index type: #{index.class.name}"
end
end
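And the matching assignment path, which copies the value into the selected view. A sketch, assuming copy_to handles same-shape copies:
t = Torch.zeros(2, 3)
t[0] = Torch.tensor([9.0, 9.0, 9.0]) # row copy via _select_int
t.to_a # => [[9.0, 9.0, 9.0], [0.0, 0.0, 0.0]]
# t[0] = nil raises ArgumentError; unsupported index types raise Torch::Error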