ext/torch/ext.cpp in torch-rb-0.1.7 vs ext/torch/ext.cpp in torch-rb-0.1.8

- old (line removed in torch-rb 0.1.8)
+ new (line added in torch-rb 0.1.8)

@@ -14,10 +14,16 @@ #include "tensor_functions.hpp" #include "nn_functions.hpp" using namespace Rice; +// need to make a distinction between parameters and tensors +class Parameter: public torch::autograd::Variable { + public: + Parameter(Tensor&& t) : torch::autograd::Variable(t) { } +}; + extern "C" void Init_ext() { Module rb_mTorch = define_module("Torch"); add_torch_functions(rb_mTorch); @@ -134,20 +140,27 @@ } else { std::vector<float> vec; for (size_t i = 0; i < a.size(); i++) { vec.push_back(from_ruby<float>(a[i])); } - t = torch::tensor(vec, options); + // hack for requires_grad error + if (options.requires_grad()) { + t = torch::tensor(vec, options.requires_grad(c10::nullopt)); + t.set_requires_grad(true); + } else { + t = torch::tensor(vec, options); + } } return t.reshape(size); }); rb_cTensor .define_method("cuda?", &torch::Tensor::is_cuda) .define_method("sparse?", &torch::Tensor::is_sparse) .define_method("quantized?", &torch::Tensor::is_quantized) .define_method("dim", &torch::Tensor::dim) + .define_method("numel", &torch::Tensor::numel) .define_method("element_size", &torch::Tensor::element_size) .define_method("requires_grad", &torch::Tensor::requires_grad) .define_method( "addcmul!", *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) { @@ -258,11 +271,11 @@ "_make_subclass", *[](Tensor& rd, bool requires_grad) { auto data = torch::autograd::as_variable_ref(rd).detach(); data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true); auto var = data.set_requires_grad(requires_grad); - return torch::autograd::Variable(std::move(var)); + return Parameter(std::move(var)); }); Class rb_cTensorOptions = define_class_under<torch::TensorOptions>(rb_mTorch, "TensorOptions") .define_constructor(Constructor<torch::TensorOptions>()) .define_method( @@ -373,13 +386,13 @@ "_sparse!", *[](Tensor tensor, double sparsity, double std) { return torch::nn::init::sparse_(tensor, sparsity, std); }); - Class rb_cParameter = 
define_class_under<torch::autograd::Variable, torch::Tensor>(rb_mNN, "Parameter") + Class rb_cParameter = define_class_under<Parameter, torch::Tensor>(rb_mNN, "Parameter") .define_method( "grad", - *[](torch::autograd::Variable& self) { + *[](Parameter& self) { auto grad = self.grad(); return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil; }); Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")