ext/torch/ext.cpp in torch-rb-0.12.0 vs ext/torch/ext.cpp in torch-rb-0.12.1
- old
+ new
@@ -10,10 +10,11 @@
void init_torch(Rice::Module& m);
void init_backends(Rice::Module& m);
void init_cuda(Rice::Module& m);
void init_device(Rice::Module& m);
+void init_generator(Rice::Module& m, Rice::Class& rb_cGenerator);
void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue);
void init_random(Rice::Module& m);
extern "C"
void Init_ext()
@@ -21,10 +22,11 @@
auto m = Rice::define_module("Torch");
// need to define certain classes up front to keep Rice happy
auto rb_cIValue = Rice::define_class_under<torch::IValue>(m, "IValue")
.define_constructor(Rice::Constructor<torch::IValue>());
+ auto rb_cGenerator = Rice::define_class_under<torch::Generator>(m, "Generator");
auto rb_cTensor = Rice::define_class_under<torch::Tensor>(m, "Tensor");
auto rb_cTensorOptions = Rice::define_class_under<torch::TensorOptions>(m, "TensorOptions")
.define_constructor(Rice::Constructor<torch::TensorOptions>());
// keep this order
@@ -36,8 +38,9 @@
init_special(m);
init_backends(m);
init_cuda(m);
init_device(m);
+ init_generator(m, rb_cGenerator);
init_ivalue(m, rb_cIValue);
init_random(m);
}