lib/torch/tensor.rb in torch-rb-0.3.2 vs lib/torch/tensor.rb in torch-rb-0.3.3

- old
+ new

@@ -186,53 +186,19 @@ end # based on python_variable_indexing.cpp and # https://pytorch.org/cppdocs/notes/tensor_indexing.html def [](*indexes) - result = self - dim = 0 - indexes.each do |index| - if index.is_a?(Numeric) - result = result._select_int(dim, index) - elsif index.is_a?(Range) - finish = index.end - finish += 1 unless index.exclude_end? - result = result._slice_tensor(dim, index.begin, finish, 1) - dim += 1 - elsif index.is_a?(Tensor) - result = result.index([index]) - elsif index.nil? - result = result.unsqueeze(dim) - dim += 1 - elsif index == true - result = result.unsqueeze(dim) - # TODO handle false - else - raise Error, "Unsupported index type: #{index.class.name}" - end - end - result + _index(tensor_indexes(indexes)) end # based on python_variable_indexing.cpp and # https://pytorch.org/cppdocs/notes/tensor_indexing.html - def []=(index, value) + def []=(*indexes, value) raise ArgumentError, "Tensor does not support deleting items" if value.nil? - value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor) - - if index.is_a?(Numeric) - index_put!([Torch.tensor(index)], value) - elsif index.is_a?(Range) - finish = index.end - finish += 1 unless index.exclude_end? - _slice_tensor(0, index.begin, finish, 1).copy!(value) - elsif index.is_a?(Tensor) - index_put!([index], value) - else - raise Error, "Unsupported index type: #{index.class.name}" - end + _index_put_custom(tensor_indexes(indexes), value) end # native functions that need manually defined # value and other are swapped for some methods @@ -242,23 +208,50 @@ else _add__tensor(other, value) end end - # native functions overlap, so need to handle manually + # parser can't handle overlap, so need to handle manually def random!(*args) case args.size when 1 _random__to(*args) when 2 - _random__from_to(*args) + _random__from(*args) else _random_(*args) end end def clamp!(min, max) _clamp_min_(min) _clamp_max_(max) + end + + private + + def tensor_indexes(indexes) + indexes.map do |index| + case index + when Integer + TensorIndex.integer(index) + when Range + finish = index.end + if finish == -1 && !index.exclude_end? + finish = nil + else + finish += 1 unless index.exclude_end? + end + TensorIndex.slice(index.begin, finish) + when Tensor + TensorIndex.tensor(index) + when nil + TensorIndex.none + when true, false + TensorIndex.boolean(index) + else + raise Error, "Unsupported index type: #{index.class.name}" + end + end end end end