lib/torch/tensor.rb in torch-rb-0.16.0 vs lib/torch/tensor.rb in torch-rb-0.17.0

- old
+ new

@@ -55,11 +55,11 @@ def to_a arr = _flat_data if shape.empty? arr else - shape[1..-1].reverse.each do |dim| + shape[1..-1].reverse_each do |dim| arr = arr.each_slice(dim) end arr.to_a end end @@ -158,16 +158,18 @@ end # based on python_variable_indexing.cpp and # https://pytorch.org/cppdocs/notes/tensor_indexing.html def [](*indexes) + indexes = indexes.map { |v| v.is_a?(Array) ? Torch.tensor(v) : v } _index(indexes) end # based on python_variable_indexing.cpp and # https://pytorch.org/cppdocs/notes/tensor_indexing.html def []=(*indexes, value) raise ArgumentError, "Tensor does not support deleting items" if value.nil? + indexes = indexes.map { |v| v.is_a?(Array) ? Torch.tensor(v) : v } value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor) _index_put_custom(indexes, value) end # parser can't handle overlap, so need to handle manually