Sha256: a29eafb8b3feacff37f83b5f182041993ce92f3fb20303b7e3d114c3a9191f8b

Contents?: true

Size: 1.17 KB

Versions: 9

Compression:

Stored size: 1.17 KB

Contents

#include <mutex>

#include <torch/torch.h>

#include <rice/rice.hpp>

#include "utils.h"

// Binds torch::Generator to the given Ruby class.
//
// Defines the singleton constructor `new` and the instance methods `device`,
// `initial_seed`, `manual_seed`, `seed`, `state`, and `state=`. It then stores
// the class handle in THPGeneratorClass.
//
// @param m              Module the extension is being initialized into
//                       (not referenced in this function).
// @param rb_cGenerator  Ruby class that receives the generator methods.
void init_generator(Rice::Module& m, Rice::Class& rb_cGenerator) {
  // https://github.com/pytorch/pytorch/blob/master/torch/csrc/Generator.cpp
  //
  // Operations that read or mutate generator state hold the generator's own
  // mutex, as libtorch requires (see ATen's
  // Note [Acquire lock when using random generators]). The upstream Python
  // binding linked above does the same for these calls.
  rb_cGenerator
    .define_singleton_function(
      "new",
      []() {
        // Always creates a CPU generator.
        // TODO support more devices
        return torch::make_generator<torch::CPUGeneratorImpl>();
      })
    .define_method(
      "device",
      [](torch::Generator& self) {
        return self.device();
      })
    .define_method(
      "initial_seed",
      [](torch::Generator& self) {
        // Reports the generator's current_seed(), matching the upstream binding.
        return self.current_seed();
      })
    .define_method(
      "manual_seed",
      [](torch::Generator& self, uint64_t seed) {
        {
          std::lock_guard<std::mutex> lock(self.mutex());
          self.set_current_seed(seed);
        }
        // Returned so calls can be chained, e.g. Generator.new.manual_seed(42).
        return self;
      })
    .define_method(
      "seed",
      [](torch::Generator& self) {
        std::lock_guard<std::mutex> lock(self.mutex());
        return self.seed();
      })
    .define_method(
      "state",
      [](torch::Generator& self) {
        std::lock_guard<std::mutex> lock(self.mutex());
        return self.get_state();
      })
    .define_method(
      "state=",
      [](torch::Generator& self, const torch::Tensor& state) {
        std::lock_guard<std::mutex> lock(self.mutex());
        self.set_state(state);
      });

  // THPGeneratorClass is declared outside this function (presumably in
  // utils.h). Storing the class here presumably lets other bindings wrap
  // torch::Generator values in this Ruby class. NOTE(review): confirm.
  THPGeneratorClass = rb_cGenerator.value();
}

Version data entries

9 entries across 9 versions & 1 rubygems

Version Path
torch-rb-0.18.0 ext/torch/generator.cpp
torch-rb-0.17.1 ext/torch/generator.cpp
torch-rb-0.17.0 ext/torch/generator.cpp
torch-rb-0.16.0 ext/torch/generator.cpp
torch-rb-0.15.0 ext/torch/generator.cpp
torch-rb-0.14.1 ext/torch/generator.cpp
torch-rb-0.14.0 ext/torch/generator.cpp
torch-rb-0.13.2 ext/torch/generator.cpp
torch-rb-0.13.1 ext/torch/generator.cpp