lib/torch/tensor.rb in torch-rb-0.1.4 vs lib/torch/tensor.rb in torch-rb-0.1.5

- old
+ new

@@ -3,16 +3,12 @@ include Comparable include Inspector alias_method :requires_grad?, :requires_grad - def self.new(*size) - if size.length == 1 && size.first.is_a?(Tensor) - size.first - else - Torch.empty(*size) - end + def self.new(*args) + FloatTensor.new(*args) end def dtype dtype = ENUM_TO_DTYPE[_dtype] raise Error, "Unknown type: #{_dtype}" unless dtype @@ -26,22 +22,22 @@ def to_s inspect end def to_a - reshape_arr(_data, shape) + reshape_arr(_flat_data, shape) end # TODO support dtype def to(device, non_blocking: false, copy: false) device = Device.new(device) if device.is_a?(String) _to(device, _dtype, non_blocking, copy) end def size(dim = nil) if dim - _size(dim) + _size_int(dim) else shape end end @@ -55,11 +51,11 @@ def item if numel != 1 raise Error, "only one element tensors can be converted to Ruby scalars" end - _data.first + _flat_data.first end # unsure if this is correct def new Torch.empty(0, dtype: dtype) @@ -71,11 +67,11 @@ # TODO read directly from memory def numo cls = Torch._dtype_to_numo[dtype] raise Error, "Cannot convert #{dtype} to Numo" unless cls - cls.cast(_data).reshape(*shape) + cls.cast(_flat_data).reshape(*shape) end def new_ones(*size, **options) Torch.ones_like(Torch.empty(*size), **options) end @@ -88,38 +84,42 @@ enum = DTYPE_TO_ENUM[dtype] raise Error, "Unknown type: #{dtype}" unless enum _type(enum) end + # start temp operations + def add!(value = 1, other) if other.is_a?(Numeric) - _add_scalar!(other * value) + _add__scalar(other, value) else # need to use alpha for sparse tensors instead of multiplying - _add_alpha!(other, value) + _add__tensor(other, value) end end def mul!(other) if other.is_a?(Numeric) - _mul_scalar!(other) + _mul__scalar(other) else - _mul!(other) + _mul__tensor(other) end end # operations - %w(abs add argmax div dot eq exp gt log log_softmax lt matmul max mean min mul neg norm num numel pow relu remainder reshape sign softmax sqrt sub sum unsqueeze topk).each do |op| + %w(log_softmax mean softmax sum topk).each do |op| define_method(op) do |*args, **options, &block| if options.any? Torch.send(op, self, *args, **options, &block) else Torch.send(op, self, *args, &block) end end end + # end temp operations + def +(other) add(other) end def -(other) @@ -154,15 +154,15 @@ def [](*indexes) result = self dim = 0 indexes.each do |index| if index.is_a?(Numeric) - result = result._select(dim, index) + result = result._select_int(dim, index) elsif index.is_a?(Range) finish = index.end finish += 1 unless index.exclude_end? - result = result._slice(dim, index.begin, finish, 1) + result = result._slice_tensor(dim, index.begin, finish, 1) dim += 1 elsif index.nil? result = result.unsqueeze(dim) dim += 1 elsif index == true @@ -181,14 +181,14 @@ raise ArgumentError, "Tensor does not support deleting items" if value.nil? value = Torch.tensor(value) unless value.is_a?(Tensor) if index.is_a?(Numeric) - copy_to(_select(0, index), value) + copy_to(_select_int(0, index), value) elsif index.is_a?(Range) finish = index.end finish += 1 unless index.exclude_end? - copy_to(_slice(0, index.begin, finish, 1), value) + copy_to(_slice_tensor(0, index.begin, finish, 1), value) else raise Error, "Unsupported index type: #{index.class.name}" end end