lib/torch/tensor.rb in torch-rb-0.2.5 vs lib/torch/tensor.rb in torch-rb-0.2.6
- old
+ new
@@ -156,11 +156,12 @@
# TODO better compare?
def <=>(other)
item <=> other
end
- # based on python_variable_indexing.cpp
+ # based on python_variable_indexing.cpp and
+ # https://pytorch.org/cppdocs/notes/tensor_indexing.html
def [](*indexes)
result = self
dim = 0
indexes.each do |index|
if index.is_a?(Numeric)
@@ -168,10 +169,12 @@
elsif index.is_a?(Range)
finish = index.end
finish += 1 unless index.exclude_end?
result = result._slice_tensor(dim, index.begin, finish, 1)
dim += 1
+ elsif index.is_a?(Tensor)
+ result = result.index([index])
elsif index.nil?
result = result.unsqueeze(dim)
dim += 1
elsif index == true
result = result.unsqueeze(dim)
@@ -181,23 +184,25 @@
end
end
result
end
- # TODO
- # based on python_variable_indexing.cpp
+ # based on python_variable_indexing.cpp and
+ # https://pytorch.org/cppdocs/notes/tensor_indexing.html
def []=(index, value)
raise ArgumentError, "Tensor does not support deleting items" if value.nil?
- value = Torch.tensor(value) unless value.is_a?(Tensor)
+ value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
if index.is_a?(Numeric)
copy_to(_select_int(0, index), value)
elsif index.is_a?(Range)
finish = index.end
finish += 1 unless index.exclude_end?
copy_to(_slice_tensor(0, index.begin, finish, 1), value)
+ elsif index.is_a?(Tensor)
+ index_put!([index], value)
else
raise Error, "Unsupported index type: #{index.class.name}"
end
end
@@ -220,9 +225,14 @@
when 2
_random__from_to(*args)
else
_random_(*args)
end
+ end
+
+ def clamp!(min, max)
+ _clamp_min_(min)
+ _clamp_max_(max)
end
private
def copy_to(dst, src)