lib/torch/tensor.rb in torch-rb-0.3.0 vs lib/torch/tensor.rb in torch-rb-0.3.1

- old
+ new

@@ -101,11 +101,12 @@ # unsure if this is correct def new Torch.empty(0, dtype: dtype) end - def backward(gradient = nil) - _backward(gradient) + def backward(gradient = nil, retain_graph: nil, create_graph: false) + retain_graph = create_graph if retain_graph.nil? + _backward(gradient, retain_graph, create_graph) end # TODO read directly from memory def numo cls = Torch._dtype_to_numo[dtype]