lib/torch/tensor.rb in torch-rb-0.1.2 vs lib/torch/tensor.rb in torch-rb-0.1.3

- old
+ new

@@ -4,14 +4,14 @@ include Inspector alias_method :requires_grad?, :requires_grad def self.new(*size) - if size.first.is_a?(Tensor) + if size.length == 1 && size.first.is_a?(Tensor) size.first else - Torch.rand(*size) + Torch.empty(*size) end end def dtype dtype = ENUM_TO_DTYPE[_dtype] @@ -29,10 +29,16 @@ def to_a reshape_arr(_data, shape) end + # TODO support dtype + def to(device, non_blocking: false, copy: false) + device = Device.new(device) if device.is_a?(String) + _to(device, _dtype, non_blocking, copy) + end + def size(dim = nil) if dim _size(dim) else shape @@ -52,10 +58,15 @@ raise Error, "only one element tensors can be converted to Ruby scalars" end _data.first end + # unsure if this is correct + def new + Torch.empty(0, dtype: dtype) + end + def backward(gradient = nil) if gradient _backward_gradient(gradient) else _backward @@ -82,12 +93,29 @@ enum = DTYPE_TO_ENUM[dtype] raise Error, "Unknown type: #{dtype}" unless enum _type(enum) end + def add!(value = 1, other) + if other.is_a?(Numeric) + _add_scalar!(other * value) + else + # need to use alpha for sparse tensors instead of multiplying + _add_alpha!(other, value) + end + end + + def mul!(other) + if other.is_a?(Numeric) + _mul_scalar!(other) + else + _mul!(other) + end + end + # operations - %w(add sub mul div remainder pow neg sum mean num norm min max dot matmul exp log unsqueeze reshape argmax eq).each do |op| + %w(abs add argmax div dot eq exp gt log lt matmul max mean min mul neg norm num numel pow remainder reshape sign sqrt sub sum unsqueeze).each do |op| define_method(op) do |*args, **options, &block| if options.any? Torch.send(op, self, *args, **options, &block) else Torch.send(op, self, *args, &block) @@ -125,14 +153,15 @@ def <=>(other) item <=> other end + # based on python_variable_indexing.cpp def [](*indexes) result = self dim = 0 - indexes.each_with_index do |index| + indexes.each do |index| if index.is_a?(Numeric) result = result._select(dim, index) elsif index.is_a?(Range) finish = index.end finish += 1 unless index.exclude_end? @@ -142,9 +171,14 @@ raise Error, "Unsupported index type" end end result end + + # TODO + # based on python_variable_indexing.cpp + # def []=(index, value) + # end private def reshape_arr(arr, dims) if dims.empty?