ext/torch/ext.cpp in torch-rb-0.4.2 vs ext/torch/ext.cpp in torch-rb-0.5.0
--- ext/torch/ext.cpp (torch-rb 0.4.2, old)
+++ ext/torch/ext.cpp (torch-rb 0.5.0, new)
@@ -347,20 +347,10 @@
"contiguous?",
*[](Tensor& self) {
return self.is_contiguous();
})
.define_method(
- "addcmul!",
- *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
- return self.addcmul_(tensor1, tensor2, value);
- })
- .define_method(
- "addcdiv!",
- *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
- return self.addcdiv_(tensor1, tensor2, value);
- })
- .define_method(
"_requires_grad!",
*[](Tensor& self, bool requires_grad) {
return self.set_requires_grad(requires_grad);
})
.define_method(
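
The removed addcmul! and addcdiv! lambdas took Ruby's value-first argument order and forwarded it to libtorch's tensors-first order; in 0.5.0 these methods are presumably covered by torch-rb's generated bindings instead, making the hand-written versions redundant. For reference, a minimal sketch of the underlying libtorch in-place calls the old lambdas wrapped, using only the stock torch/torch.h API:

    // addcmul_: self += value * tensor1 * tensor2
    // addcdiv_: self += value * tensor1 / tensor2
    // Note the scalar comes last in the C++ signatures, while the old
    // Ruby-facing methods accepted it first.
    #include <torch/torch.h>

    int main() {
      torch::Tensor self = torch::ones({2, 2});
      torch::Tensor t1 = torch::full({2, 2}, 2.0);
      torch::Tensor t2 = torch::full({2, 2}, 3.0);
      self.addcmul_(t1, t2, 0.5);  // each element: 1 + 0.5 * 2 * 3 = 4
      self.addcdiv_(t1, t2, 0.5);  // each element: 4 + 0.5 * 2 / 3
      return 0;
    }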
@@ -370,11 +360,11 @@
return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
})
.define_method(
"grad=",
*[](Tensor& self, torch::Tensor& grad) {
- self.grad() = grad;
+ self.mutable_grad() = grad;
})
.define_method(
"_dtype",
*[](Tensor& self) {
return (int) at::typeMetaToScalarType(self.dtype());
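
The grad= setter now assigns through mutable_grad() rather than grad(). Newer libtorch releases deprecate mutating a tensor's gradient through grad(), keeping it as the read accessor, and add mutable_grad(), which returns the non-const Tensor& needed for assignment. A minimal sketch of the distinction, assuming a libtorch version (1.5 or later) where mutable_grad() exists:

    #include <cassert>
    #include <torch/torch.h>

    int main() {
      torch::Tensor t = torch::ones({2, 2}, torch::requires_grad());
      // grad() is for reading; it returns an undefined tensor until a
      // gradient has been computed or assigned (hence the defined()
      // checks in the grad readers above).
      assert(!t.grad().defined());
      // mutable_grad() returns Tensor&, so a new gradient can be
      // assigned in place -- this is what the grad= bindings do.
      t.mutable_grad() = torch::zeros({2, 2});
      assert(t.grad().defined());
      return 0;
    }

The same swap is applied below for the Parameter class, whose grad= binding is otherwise identical.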
@@ -607,10 +597,10 @@
return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
})
.define_method(
"grad=",
*[](Parameter& self, torch::Tensor& grad) {
- self.grad() = grad;
+ self.mutable_grad() = grad;
});
Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
.add_handler<torch::Error>(handle_error)
.define_constructor(Constructor<torch::Device, std::string>())