lib/torch/tensor.rb in torch-rb-0.2.5 vs lib/torch/tensor.rb in torch-rb-0.2.6

- old
+ new

@@ -156,11 +156,12 @@ # TODO better compare? def <=>(other) item <=> other end - # based on python_variable_indexing.cpp + # based on python_variable_indexing.cpp and + # https://pytorch.org/cppdocs/notes/tensor_indexing.html def [](*indexes) result = self dim = 0 indexes.each do |index| if index.is_a?(Numeric) @@ -168,10 +169,12 @@ elsif index.is_a?(Range) finish = index.end finish += 1 unless index.exclude_end? result = result._slice_tensor(dim, index.begin, finish, 1) dim += 1 + elsif index.is_a?(Tensor) + result = result.index([index]) elsif index.nil? result = result.unsqueeze(dim) dim += 1 elsif index == true result = result.unsqueeze(dim) @@ -181,23 +184,25 @@ end end result end - # TODO - # based on python_variable_indexing.cpp + # based on python_variable_indexing.cpp and + # https://pytorch.org/cppdocs/notes/tensor_indexing.html def []=(index, value) raise ArgumentError, "Tensor does not support deleting items" if value.nil? - value = Torch.tensor(value) unless value.is_a?(Tensor) + value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor) if index.is_a?(Numeric) copy_to(_select_int(0, index), value) elsif index.is_a?(Range) finish = index.end finish += 1 unless index.exclude_end? copy_to(_slice_tensor(0, index.begin, finish, 1), value) + elsif index.is_a?(Tensor) + index_put!([index], value) else raise Error, "Unsupported index type: #{index.class.name}" end end @@ -220,9 +225,14 @@ when 2 _random__from_to(*args) else _random_(*args) end + end + + def clamp!(min, max) + _clamp_min_(min) + _clamp_max_(max) end private def copy_to(dst, src)