lib/torch/tensor.rb in torch-rb-0.1.5 vs lib/torch/tensor.rb in torch-rb-0.1.6
- old
+ new
@@ -43,12 +43,13 @@
def shape
dim.times.map { |i| size(i) }
end
- def view(*size)
- _view(size)
+ # mirror Python len()
+ def length
+ size(0)
end
def item
if numel != 1
raise Error, "only one element tensors can be converted to Ruby scalars"
@@ -84,40 +85,28 @@
enum = DTYPE_TO_ENUM[dtype]
raise Error, "Unknown type: #{dtype}" unless enum
_type(enum)
end
- # start temp operations
+ def reshape(*size)
+ # Python doesn't check if size == 1, just ignores later arguments
+ size = size.first if size.size == 1 && size.first.is_a?(Array)
+ _reshape(size)
+ end
+ def view(*size)
+ size = size.first if size.size == 1 && size.first.is_a?(Array)
+ _view(size)
+ end
+
+ # value and other are swapped for some methods
def add!(value = 1, other)
if other.is_a?(Numeric)
_add__scalar(other, value)
else
- # need to use alpha for sparse tensors instead of multiplying
_add__tensor(other, value)
end
end
-
- def mul!(other)
- if other.is_a?(Numeric)
- _mul__scalar(other)
- else
- _mul__tensor(other)
- end
- end
-
- # operations
- %w(log_softmax mean softmax sum topk).each do |op|
- define_method(op) do |*args, **options, &block|
- if options.any?
- Torch.send(op, self, *args, **options, &block)
- else
- Torch.send(op, self, *args, &block)
- end
- end
- end
-
- # end temp operations
def +(other)
add(other)
end