ext/torch/tensor.cpp in torch-rb-0.9.1 vs ext/torch/tensor.cpp in torch-rb-0.9.2
- old
+ new
@@ -8,32 +8,10 @@
#include "utils.h"
using namespace Rice;
using torch::indexing::TensorIndex;
-namespace Rice::detail
-{
-  template<typename T>
-  struct Type<c10::complex<T>>
-  {
-    static bool verify()
-    {
-      return true;
-    }
-  };
-
-  template<typename T>
-  class To_Ruby<c10::complex<T>>
-  {
-  public:
-    VALUE convert(c10::complex<T> const& x)
-    {
-      return rb_dbl_complex_new(x.real(), x.imag());
-    }
-  };
-}
-
template<typename T>
Array flat_data(Tensor& tensor) {
  Tensor view = tensor.reshape({tensor.numel()});
  Array a;
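
The specializations deleted above are what taught Rice to turn a c10::complex<T> into a Ruby Complex object (via the C API's rb_dbl_complex_new). flat_data still needs that conversion whenever T is a complex type, which suggests the code was relocated in 0.9.2 rather than dropped; the diff alone doesn't show its new home. Below is a sketch of how a function like flat_data typically finishes and where the conversion fires; the loop body is illustrative rather than the exact torch-rb code, and it assumes the file's existing includes and usings:

template<typename T>
Array flat_data_sketch(Tensor& tensor) {
  Tensor view = tensor.reshape({tensor.numel()});
  Array a;
  for (int64_t i = 0; i < view.numel(); i++) {
    // Array#push converts each element through Rice::detail::To_Ruby<T>;
    // for T = c10::complex<float> that requires a specialization like the
    // one removed above to exist somewhere in the build.
    a.push(view[i].item().to<T>());
  }
  return a;
}
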
@@ -187,13 +165,35 @@
"grad",
[](Tensor& self) {
auto grad = self.grad();
return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Nil;
})
+ // can't use grad=
+ // assignment methods fail with Ruby 3.0
.define_method(
- "grad=",
- [](Tensor& self, torch::Tensor& grad) {
+ "_set_grad",
+ [](Tensor& self, Rice::Object value) {
+ if (value.is_nil()) {
+ self.mutable_grad().reset();
+ return;
+ }
+
+ const auto& grad = Rice::detail::From_Ruby<torch::Tensor>().convert(value.value());
+
+ // TODO support sparse grad
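+        // the validation below matches PyTorch's own grad setter: type, device, then shape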
+        if (!grad.options().type_equal(self.options())) {
+          rb_raise(rb_eArgError, "assigned grad has data of a different type");
+        }
+
+        if (self.is_cuda() && grad.get_device() != self.get_device()) {
+          rb_raise(rb_eArgError, "assigned grad has data located on a different device");
+        }
+
+        if (!self.sizes().equals(grad.sizes())) {
+          rb_raise(rb_eArgError, "assigned grad has data of a different size");
+        }
+
        self.mutable_grad() = grad;
      })
    .define_method(
      "_dtype",
      [](Tensor& self) {
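
The notable part of _set_grad is its signature. With the old typed parameter (torch::Tensor& grad), Rice converts the argument before the lambda runs and rejects nil outright, so a Ruby caller could never clear the gradient; taking a Rice::Object and converting by hand makes nil a valid input. A self-contained sketch of the same pattern on a toy class, assuming Rice 4 (Counter and _set_limit are invented names, not torch-rb API):

#include <optional>
#include <rice/rice.hpp>

struct Counter {
  std::optional<int> limit;
};

extern "C" void Init_counter() {
  Rice::define_class<Counter>("Counter")
    .define_constructor(Rice::Constructor<Counter>())
    .define_method(
      "_set_limit",
      [](Counter& self, Rice::Object value) {
        // nil clears the field instead of raising a TypeError
        if (value.is_nil()) {
          self.limit.reset();
          return;
        }
        // convert explicitly, as _set_grad does above
        self.limit = Rice::detail::From_Ruby<int>().convert(value.value());
      });
}

On the Ruby side, the gem presumably still exposes an idiomatic writer that delegates to the underscored method, since (per the comment above) grad= itself can't be bound directly under Ruby 3.0.
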
@@ -279,10 +279,10 @@
          throw std::runtime_error("Unsupported type");
        }
      })
    .define_method(
      "_to",
-      [](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
+      [](Tensor& self, torch::Device& device, int dtype, bool non_blocking, bool copy) {
        return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
      });
  rb_cTensorOptions
    .add_handler<torch::Error>(handle_error)
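
The final change passes the device to _to by reference rather than by value, so Rice hands the lambda the wrapped Torch::Device object itself instead of converting a copy. The dtype, by contrast, crosses the binding as a plain int and is cast back to the ATen enum, so both sides must agree on c10::ScalarType's numbering. An illustrative direct call showing that round trip (not torch-rb code; assumes libtorch headers):

#include <torch/torch.h>

torch::Tensor to_float32_cpu(torch::Tensor& self) {
  torch::Device device("cpu");                    // the object Rice now passes by reference
  int dtype = static_cast<int>(torch::kFloat32);  // integer form of c10::ScalarType::Float
  return self.to(device, (torch::ScalarType) dtype,
                 /*non_blocking=*/false, /*copy=*/false);
}
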