ext/torch/tensor.cpp in torch-rb-0.9.1 vs ext/torch/tensor.cpp in torch-rb-0.9.2

- old
+ new

@@ -8,32 +8,10 @@ #include "utils.h" using namespace Rice; using torch::indexing::TensorIndex; -namespace Rice::detail -{ - template<typename T> - struct Type<c10::complex<T>> - { - static bool verify() - { - return true; - } - }; - - template<typename T> - class To_Ruby<c10::complex<T>> - { - public: - VALUE convert(c10::complex<T> const& x) - { - return rb_dbl_complex_new(x.real(), x.imag()); - } - }; -} - template<typename T> Array flat_data(Tensor& tensor) { Tensor view = tensor.reshape({tensor.numel()}); Array a; @@ -187,13 +165,35 @@ "grad", [](Tensor& self) { auto grad = self.grad(); return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Nil; }) + // can't use grad= + // assignment methods fail with Ruby 3.0 .define_method( - "grad=", - [](Tensor& self, torch::Tensor& grad) { + "_set_grad", + [](Tensor& self, Rice::Object value) { + if (value.is_nil()) { + self.mutable_grad().reset(); + return; + } + + const auto& grad = Rice::detail::From_Ruby<torch::Tensor>().convert(value.value()); + + // TODO support sparse grad + if (!grad.options().type_equal(self.options())) { + rb_raise(rb_eArgError, "assigned grad has data of a different type"); + } + + if (self.is_cuda() && grad.get_device() != self.get_device()) { + rb_raise(rb_eArgError, "assigned grad has data located on a different device"); + } + + if (!self.sizes().equals(grad.sizes())) { + rb_raise(rb_eArgError, "assigned grad has data of a different size"); + } + self.mutable_grad() = grad; }) .define_method( "_dtype", [](Tensor& self) { @@ -279,10 +279,10 @@ throw std::runtime_error("Unsupported type"); } }) .define_method( "_to", - [](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) { + [](Tensor& self, torch::Device& device, int dtype, bool non_blocking, bool copy) { return self.to(device, (torch::ScalarType) dtype, non_blocking, copy); }); rb_cTensorOptions .add_handler<torch::Error>(handle_error)