lib/torch/tensor.rb in torch-rb-0.3.2 vs lib/torch/tensor.rb in torch-rb-0.3.3
- old
+ new
@@ -186,53 +186,19 @@
end
# based on python_variable_indexing.cpp and
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
def [](*indexes)
- result = self
- dim = 0
- indexes.each do |index|
- if index.is_a?(Numeric)
- result = result._select_int(dim, index)
- elsif index.is_a?(Range)
- finish = index.end
- finish += 1 unless index.exclude_end?
- result = result._slice_tensor(dim, index.begin, finish, 1)
- dim += 1
- elsif index.is_a?(Tensor)
- result = result.index([index])
- elsif index.nil?
- result = result.unsqueeze(dim)
- dim += 1
- elsif index == true
- result = result.unsqueeze(dim)
- # TODO handle false
- else
- raise Error, "Unsupported index type: #{index.class.name}"
- end
- end
- result
+ _index(tensor_indexes(indexes))
end
# based on python_variable_indexing.cpp and
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
- def []=(index, value)
+ def []=(*indexes, value)
raise ArgumentError, "Tensor does not support deleting items" if value.nil?
-
value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
-
- if index.is_a?(Numeric)
- index_put!([Torch.tensor(index)], value)
- elsif index.is_a?(Range)
- finish = index.end
- finish += 1 unless index.exclude_end?
- _slice_tensor(0, index.begin, finish, 1).copy!(value)
- elsif index.is_a?(Tensor)
- index_put!([index], value)
- else
- raise Error, "Unsupported index type: #{index.class.name}"
- end
+ _index_put_custom(tensor_indexes(indexes), value)
end
# native functions that need manually defined
# value and other are swapped for some methods
@@ -242,23 +208,50 @@
else
_add__tensor(other, value)
end
end
- # native functions overlap, so need to handle manually
+ # parser can't handle overlap, so need to handle manually
def random!(*args)
case args.size
when 1
_random__to(*args)
when 2
- _random__from_to(*args)
+ _random__from(*args)
else
_random_(*args)
end
end
def clamp!(min, max)
_clamp_min_(min)
_clamp_max_(max)
+ end
+
+ private
+
+ def tensor_indexes(indexes)
+ indexes.map do |index|
+ case index
+ when Integer
+ TensorIndex.integer(index)
+ when Range
+ finish = index.end
+ if finish == -1 && !index.exclude_end?
+ finish = nil
+ else
+ finish += 1 unless index.exclude_end?
+ end
+ TensorIndex.slice(index.begin, finish)
+ when Tensor
+ TensorIndex.tensor(index)
+ when nil
+ TensorIndex.none
+ when true, false
+ TensorIndex.boolean(index)
+ else
+ raise Error, "Unsupported index type: #{index.class.name}"
+ end
+ end
end
end
end