ext/torch/ext.cpp in torch-rb-0.3.1 vs ext/torch/ext.cpp in torch-rb-0.3.2
- old
+ new
@@ -356,13 +356,19 @@
return self.backward(gradient, create_graph, retain_graph);
})
.define_method(
"grad",
*[](Tensor& self) {
- return self.grad();
+ auto grad = self.grad();
+ return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
})
.define_method(
+ "grad=",
+ *[](Tensor& self, torch::Tensor& grad) {
+ self.grad() = grad;
+ })
+ .define_method(
"_dtype",
*[](Tensor& self) {
return (int) at::typeMetaToScalarType(self.dtype());
})
.define_method(
@@ -578,9 +584,14 @@
.define_method(
"grad",
*[](Parameter& self) {
auto grad = self.grad();
return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
+ })
+ .define_method(
+ "grad=",
+ *[](Parameter& self, torch::Tensor& grad) {
+ self.grad() = grad;
});
Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
.define_constructor(Constructor<torch::Device, std::string>())
.add_handler<torch::Error>(handle_error)