lib/torch/tensor.rb in torch-rb-0.3.7 vs lib/torch/tensor.rb in torch-rb-0.4.0

- old
+ new

@@ -6,10 +6,22 @@ alias_method :requires_grad?, :requires_grad alias_method :ndim, :dim alias_method :ndimension, :dim + # use alias_method for performance + alias_method :+, :add + alias_method :-, :sub + alias_method :*, :mul + alias_method :/, :div + alias_method :%, :remainder + alias_method :**, :pow + alias_method :-@, :neg + alias_method :&, :logical_and + alias_method :|, :logical_or + alias_method :^, :logical_xor + def self.new(*args) FloatTensor.new(*args) end def dtype @@ -71,16 +83,24 @@ to("cuda") end def size(dim = nil) if dim - _size_int(dim) + _size(dim) else shape end end + def stride(dim = nil) + if dim + _stride(dim) + else + _strides + end + end + # mirror Python len() def length size(0) end @@ -128,131 +148,31 @@ raise Error, "Invalid type: #{dtype}" unless enum _type(enum) end end - def reshape(*size) - # Python doesn't check if size == 1, just ignores later arguments - size = size.first if size.size == 1 && size.first.is_a?(Array) - _reshape(size) - end - - def view(*size) - size = size.first if size.size == 1 && size.first.is_a?(Array) - _view(size) - end - - def +(other) - add(other) - end - - def -(other) - sub(other) - end - - def *(other) - mul(other) - end - - def /(other) - div(other) - end - - def %(other) - remainder(other) - end - - def **(other) - pow(other) - end - - def -@ - neg - end - - def &(other) - logical_and(other) - end - - def |(other) - logical_or(other) - end - - def ^(other) - logical_xor(other) - end - # TODO better compare? def <=>(other) item <=> other end # based on python_variable_indexing.cpp and # https://pytorch.org/cppdocs/notes/tensor_indexing.html def [](*indexes) - _index(tensor_indexes(indexes)) + _index(indexes) end # based on python_variable_indexing.cpp and # https://pytorch.org/cppdocs/notes/tensor_indexing.html def []=(*indexes, value) raise ArgumentError, "Tensor does not support deleting items" if value.nil? value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor) - _index_put_custom(tensor_indexes(indexes), value) + _index_put_custom(indexes, value) end - # native functions that need manually defined - - # value and other are swapped for some methods - def add!(value = 1, other) - if other.is_a?(Numeric) - _add__scalar(other, value) - else - _add__tensor(other, value) - end - end - # parser can't handle overlap, so need to handle manually def random!(*args) - case args.size - when 1 - _random__to(*args) - when 2 - _random__from(*args) - else - _random_(*args) - end - end - - def clamp!(min, max) - _clamp_min_(min) - _clamp_max_(max) - end - - private - - def tensor_indexes(indexes) - indexes.map do |index| - case index - when Integer - TensorIndex.integer(index) - when Range - finish = index.end || -1 - if finish == -1 && !index.exclude_end? - finish = nil - else - finish += 1 unless index.exclude_end? - end - TensorIndex.slice(index.begin, finish) - when Tensor - TensorIndex.tensor(index) - when nil - TensorIndex.none - when true, false - TensorIndex.boolean(index) - else - raise Error, "Unsupported index type: #{index.class.name}" - end - end + return _random!(0, *args) if args.size == 1 + _random!(*args) end end end