lib/torch/tensor.rb in torch-rb-0.3.5 vs lib/torch/tensor.rb in torch-rb-0.3.6
- old
+ new
@@ -101,15 +101,10 @@
# unsure if this is correct
def new
Torch.empty(0, dtype: dtype)
end
- def backward(gradient = nil, retain_graph: nil, create_graph: false)
- retain_graph = create_graph if retain_graph.nil?
- _backward(gradient, retain_graph, create_graph)
- end
-
# TODO read directly from memory
def numo
cls = Torch._dtype_to_numo[dtype]
raise Error, "Cannot convert #{dtype} to Numo" unless cls
cls.from_string(_data_str).reshape(*shape)
@@ -233,10 +228,10 @@
indexes.map do |index|
case index
when Integer
TensorIndex.integer(index)
when Range
- finish = index.end
+ finish = index.end || -1
if finish == -1 && !index.exclude_end?
finish = nil
else
finish += 1 unless index.exclude_end?
end