ext/torch/ext.cpp in torch-rb-0.3.0 vs ext/torch/ext.cpp in torch-rb-0.3.1

- old
+ new

@@ -350,11 +350,11 @@ *[](Tensor& self, bool requires_grad) { return self.set_requires_grad(requires_grad); }) .define_method( "_backward", - *[](Tensor& self, Object gradient) { - return gradient.is_nil() ? self.backward() : self.backward(from_ruby<torch::Tensor>(gradient)); + *[](Tensor& self, OptionalTensor gradient, bool create_graph, bool retain_graph) { + return self.backward(gradient, create_graph, retain_graph); }) .define_method( "grad", *[](Tensor& self) { return self.grad();