ext/torch/ext.cpp in torch-rb-0.3.1 vs ext/torch/ext.cpp in torch-rb-0.3.2

- old
+ new

@@ -356,13 +356,19 @@ return self.backward(gradient, create_graph, retain_graph); }) .define_method( "grad", *[](Tensor& self) { - return self.grad(); + auto grad = self.grad(); + return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil; }) .define_method( + "grad=", + *[](Tensor& self, torch::Tensor& grad) { + self.grad() = grad; + }) + .define_method( "_dtype", *[](Tensor& self) { return (int) at::typeMetaToScalarType(self.dtype()); }) .define_method( @@ -578,9 +584,14 @@ .define_method( "grad", *[](Parameter& self) { auto grad = self.grad(); return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil; + }) + .define_method( + "grad=", + *[](Parameter& self, torch::Tensor& grad) { + self.grad() = grad; }); Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device") .define_constructor(Constructor<torch::Device, std::string>()) .add_handler<torch::Error>(handle_error)