ext/torch/ext.cpp in torch-rb-0.3.0 vs ext/torch/ext.cpp in torch-rb-0.3.1
- old
+ new
@@ -350,11 +350,11 @@
*[](Tensor& self, bool requires_grad) {
return self.set_requires_grad(requires_grad);
})
.define_method(
"_backward",
- *[](Tensor& self, Object gradient) {
- return gradient.is_nil() ? self.backward() : self.backward(from_ruby<torch::Tensor>(gradient));
+ *[](Tensor& self, OptionalTensor gradient, bool create_graph, bool retain_graph) {
+ return self.backward(gradient, create_graph, retain_graph);
})
.define_method(
"grad",
*[](Tensor& self) {
return self.grad();