ext/torch/ext.cpp in torch-rb-0.6.0 vs ext/torch/ext.cpp in torch-rb-0.7.0
- old
+ new
@@ -1,26 +1,35 @@
-#include <rice/Module.hpp>
+#include <torch/torch.h>
+#include <rice/rice.hpp>
+
void init_nn(Rice::Module& m);
-void init_tensor(Rice::Module& m);
+void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions);
void init_torch(Rice::Module& m);
void init_cuda(Rice::Module& m);
void init_device(Rice::Module& m);
-void init_ivalue(Rice::Module& m);
+void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue);
void init_random(Rice::Module& m);
extern "C"
void Init_ext()
{
auto m = Rice::define_module("Torch");
+ // need to define certain classes up front to keep Rice happy
+ auto rb_cIValue = Rice::define_class_under<torch::IValue>(m, "IValue")
+ .define_constructor(Rice::Constructor<torch::IValue>());
+ auto rb_cTensor = Rice::define_class_under<torch::Tensor>(m, "Tensor");
+ auto rb_cTensorOptions = Rice::define_class_under<torch::TensorOptions>(m, "TensorOptions")
+ .define_constructor(Rice::Constructor<torch::TensorOptions>());
+
// keep this order
init_torch(m);
- init_tensor(m);
+ init_tensor(m, rb_cTensor, rb_cTensorOptions);
init_nn(m);
init_cuda(m);
init_device(m);
- init_ivalue(m);
+ init_ivalue(m, rb_cIValue);
init_random(m);
}