Sha256: 32a92ef925321dbefff348c037e67819ff44d13464a64de3d492618a15a3bb95
Contents?: true
Size: 1.21 KB
Versions: 3
Compression:
Stored size: 1.21 KB
Contents
#include <torch/torch.h> #include <rice/rice.hpp> #include "utils.h" void init_generator(Rice::Module& m, Rice::Class& rb_cGenerator) { // https://github.com/pytorch/pytorch/blob/master/torch/csrc/Generator.cpp rb_cGenerator .add_handler<torch::Error>(handle_error) .define_singleton_function( "new", []() { // TODO support more devices return torch::make_generator<torch::CPUGeneratorImpl>(); }) .define_method( "device", [](torch::Generator& self) { return self.device(); }) .define_method( "initial_seed", [](torch::Generator& self) { return self.current_seed(); }) .define_method( "manual_seed", [](torch::Generator& self, uint64_t seed) { self.set_current_seed(seed); return self; }) .define_method( "seed", [](torch::Generator& self) { return self.seed(); }) .define_method( "state", [](torch::Generator& self) { return self.get_state(); }) .define_method( "state=", [](torch::Generator& self, const torch::Tensor& state) { self.set_state(state); }); THPGeneratorClass = rb_cGenerator.value(); }
Version data entries
3 entries across 3 versions & 1 rubygems
Version | Path |
---|---|
torch-rb-0.13.0 | ext/torch/generator.cpp |
torch-rb-0.12.2 | ext/torch/generator.cpp |
torch-rb-0.12.1 | ext/torch/generator.cpp |