lib/torch/tensor.rb in torch-rb-0.16.0 vs lib/torch/tensor.rb in torch-rb-0.17.0
- old
+ new
@@ -55,11 +55,11 @@
def to_a
arr = _flat_data
if shape.empty?
arr
else
- shape[1..-1].reverse.each do |dim|
+ shape[1..-1].reverse_each do |dim|
arr = arr.each_slice(dim)
end
arr.to_a
end
end
@@ -158,16 +158,18 @@
end
# based on python_variable_indexing.cpp and
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
def [](*indexes)
+ indexes = indexes.map { |v| v.is_a?(Array) ? Torch.tensor(v) : v }
_index(indexes)
end
# based on python_variable_indexing.cpp and
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
def []=(*indexes, value)
raise ArgumentError, "Tensor does not support deleting items" if value.nil?
+ indexes = indexes.map { |v| v.is_a?(Array) ? Torch.tensor(v) : v }
value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
_index_put_custom(indexes, value)
end
# parser can't handle overlap, so need to handle manually