lib/torch/tensor.rb in torch-rb-0.3.7 vs lib/torch/tensor.rb in torch-rb-0.4.0
- old (line removed in 0.4.0)
+ new (line added in 0.4.0)
@@ -6,10 +6,22 @@
alias_method :requires_grad?, :requires_grad
alias_method :ndim, :dim
alias_method :ndimension, :dim
+ # use alias_method for performance
+ alias_method :+, :add
+ alias_method :-, :sub
+ alias_method :*, :mul
+ alias_method :/, :div
+ alias_method :%, :remainder
+ alias_method :**, :pow
+ alias_method :-@, :neg
+ alias_method :&, :logical_and
+ alias_method :|, :logical_or
+ alias_method :^, :logical_xor
+
def self.new(*args)
FloatTensor.new(*args)
end
def dtype
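
A minimal usage sketch of the aliased operators above (not taken from the diff; it assumes only the public Torch.tensor API):

  require "torch"

  a = Torch.tensor([1.0, 2.0, 3.0])
  b = Torch.tensor([10.0, 20.0, 30.0])

  a + b    # dispatches directly to Tensor#add via alias_method
  a * 2    # scalar operands still broadcast through mul
  a ** 2   # alias for pow
  -a       # unary minus is now an alias for neg

Defining the operators with alias_method skips the extra Ruby method call that the wrapper definitions removed further down used to add on every operator.
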
@@ -71,16 +83,24 @@
to("cuda")
end
def size(dim = nil)
if dim
- _size_int(dim)
+ _size(dim)
else
shape
end
end
+ def stride(dim = nil)
+ if dim
+ _stride(dim)
+ else
+ _strides
+ end
+ end
+
# mirror Python len()
def length
size(0)
end
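
The new stride method mirrors size: with a dimension it returns a single stride via _stride, without one it returns the stride of every dimension via _strides. A rough sketch, assuming a contiguous tensor built with the public API:

  t = Torch.arange(12).reshape(3, 4)

  t.shape      # => [3, 4]
  t.size(1)    # => 4
  t.length     # => 3, same as size(0), mirroring Python's len()
  t.stride     # => [4, 1] for this contiguous 3x4 tensor
  t.stride(0)  # => 4, elements to skip between consecutive rows
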
@@ -128,131 +148,31 @@
raise Error, "Invalid type: #{dtype}" unless enum
_type(enum)
end
end
- def reshape(*size)
- # Python doesn't check if size == 1, just ignores later arguments
- size = size.first if size.size == 1 && size.first.is_a?(Array)
- _reshape(size)
- end
-
- def view(*size)
- size = size.first if size.size == 1 && size.first.is_a?(Array)
- _view(size)
- end
-
- def +(other)
- add(other)
- end
-
- def -(other)
- sub(other)
- end
-
- def *(other)
- mul(other)
- end
-
- def /(other)
- div(other)
- end
-
- def %(other)
- remainder(other)
- end
-
- def **(other)
- pow(other)
- end
-
- def -@
- neg
- end
-
- def &(other)
- logical_and(other)
- end
-
- def |(other)
- logical_or(other)
- end
-
- def ^(other)
- logical_xor(other)
- end
-
# TODO better compare?
def <=>(other)
item <=> other
end
# based on python_variable_indexing.cpp and
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
def [](*indexes)
- _index(tensor_indexes(indexes))
+ _index(indexes)
end
# based on python_variable_indexing.cpp and
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
def []=(*indexes, value)
raise ArgumentError, "Tensor does not support deleting items" if value.nil?
value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
- _index_put_custom(tensor_indexes(indexes), value)
+ _index_put_custom(indexes, value)
end
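
The Ruby-side conversion of indexes to TensorIndex objects (the removed tensor_indexes helper below) is gone; the raw indexes are now passed straight to _index and _index_put_custom, so the conversion presumably happens in the native extension. Judging by the removed helper, the accepted index types (Integer, Range, Tensor, nil, true/false) stay the same. A small sketch:

  t = Torch.arange(9).reshape(3, 3)

  t[0]         # Integer: first row
  t[0..1, 2]   # Range plus Integer: column 2 of the first two rows
  t[1, 1] = 0  # scalar values are wrapped with Torch.tensor(value, dtype: dtype)
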
- # native functions that need manually defined
-
- # value and other are swapped for some methods
- def add!(value = 1, other)
- if other.is_a?(Numeric)
- _add__scalar(other, value)
- else
- _add__tensor(other, value)
- end
- end
-
# parser can't handle overlap, so need to handle manually
def random!(*args)
- case args.size
- when 1
- _random__to(*args)
- when 2
- _random__from(*args)
- else
- _random_(*args)
- end
- end
-
- def clamp!(min, max)
- _clamp_min_(min)
- _clamp_max_(max)
- end
-
- private
-
- def tensor_indexes(indexes)
- indexes.map do |index|
- case index
- when Integer
- TensorIndex.integer(index)
- when Range
- finish = index.end || -1
- if finish == -1 && !index.exclude_end?
- finish = nil
- else
- finish += 1 unless index.exclude_end?
- end
- TensorIndex.slice(index.begin, finish)
- when Tensor
- TensorIndex.tensor(index)
- when nil
- TensorIndex.none
- when true, false
- TensorIndex.boolean(index)
- else
- raise Error, "Unsupported index type: #{index.class.name}"
- end
- end
+ return _random!(0, *args) if args.size == 1
+ _random!(*args)
end
end
end
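
Of the manually defined native methods, only random! keeps a Ruby-side shim, because its overloads (from/to, to, and no argument) overlap; add! and clamp! presumably come from the generated bindings now. With a single argument the new code forwards to _random!(0, to), so the argument is an exclusive upper bound, matching PyTorch's Tensor.random_. A hedged sketch:

  t = Torch.empty(4, dtype: :int64)

  t.random!(10)     # one argument: integers drawn from [0, 10)
  t.random!(5, 10)  # two arguments: integers drawn from [5, 10)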