lib/torch/tensor.rb in torch-rb-0.3.0 vs lib/torch/tensor.rb in torch-rb-0.3.1
- old
+ new
@@ -101,11 +101,12 @@
# unsure if this is correct
def new
Torch.empty(0, dtype: dtype)
end
- def backward(gradient = nil)
- _backward(gradient)
+ def backward(gradient = nil, retain_graph: nil, create_graph: false)
+ retain_graph = create_graph if retain_graph.nil?
+ _backward(gradient, retain_graph, create_graph)
end
# TODO read directly from memory
def numo
cls = Torch._dtype_to_numo[dtype]