lib/torch/tensor.rb in torch-rb-0.1.5 vs lib/torch/tensor.rb in torch-rb-0.1.6

- old
+ new

@@ -43,12 +43,13 @@ def shape dim.times.map { |i| size(i) } end - def view(*size) - _view(size) + # mirror Python len() + def length + size(0) end def item if numel != 1 raise Error, "only one element tensors can be converted to Ruby scalars" @@ -84,40 +85,28 @@ enum = DTYPE_TO_ENUM[dtype] raise Error, "Unknown type: #{dtype}" unless enum _type(enum) end - # start temp operations + def reshape(*size) + # Python doesn't check if size == 1, just ignores later arguments + size = size.first if size.size == 1 && size.first.is_a?(Array) + _reshape(size) + end + def view(*size) + size = size.first if size.size == 1 && size.first.is_a?(Array) + _view(size) + end + + # value and other are swapped for some methods def add!(value = 1, other) if other.is_a?(Numeric) _add__scalar(other, value) else - # need to use alpha for sparse tensors instead of multiplying _add__tensor(other, value) end end - - def mul!(other) - if other.is_a?(Numeric) - _mul__scalar(other) - else - _mul__tensor(other) - end - end - - # operations - %w(log_softmax mean softmax sum topk).each do |op| - define_method(op) do |*args, **options, &block| - if options.any? - Torch.send(op, self, *args, **options, &block) - else - Torch.send(op, self, *args, &block) - end - end - end - - # end temp operations def +(other) add(other) end