#include <torch/torch.h>

#include <rice/Module.hpp>

#include "torch_functions.h"
#include "templates.h"
#include "utils.h"

void init_torch(Rice::Module& m) {
  m.add_handler<torch::Error>(handle_error);
  add_torch_functions(m);

  m.define_singleton_method(
      "grad_enabled?",
      *[]() {
        return torch::GradMode::is_enabled();
      })
    .define_singleton_method(
      "_set_grad_enabled",
      *[](bool enabled) {
        torch::GradMode::set_enabled(enabled);
      })
    .define_singleton_method(
      "manual_seed",
      *[](uint64_t seed) {
        return torch::manual_seed(seed);
      })
    // config
    .define_singleton_method(
      "show_config",
      *[] {
        return torch::show_config();
      })
    .define_singleton_method(
      "parallel_info",
      *[] {
        return torch::get_parallel_info();
      })
    // begin operations
    .define_singleton_method(
      "_save",
      *[](const torch::IValue &value) {
        auto v = torch::pickle_save(value);
        std::string str(v.begin(), v.end());
        return str;
      })
    .define_singleton_method(
      "_load",
      *[](const std::string &s) {
        std::vector<char> v;
        std::copy(s.begin(), s.end(), std::back_inserter(v));
        // https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
        return torch::pickle_load(v);
      })
    .define_singleton_method(
      "_from_blob",
      *[](Rice::String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
        void *data = const_cast<char *>(s.c_str());
        return torch::from_blob(data, size, options);
      })
    .define_singleton_method(
      "_tensor",
      *[](Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
        auto dtype = options.dtype();
        torch::Tensor t;
        if (dtype == torch::kBool) {
          // std::vector<bool> is bit-packed, so collect bools into uint8_t
          std::vector<uint8_t> vec;
          for (long i = 0; i < a.size(); i++) {
            vec.push_back(from_ruby<bool>(a[i]));
          }
          t = torch::tensor(vec, options);
        } else {
          std::vector<float> vec;
          for (long i = 0; i < a.size(); i++) {
            vec.push_back(from_ruby<float>(a[i]));
          }
          // hack for requires_grad error
          if (options.requires_grad()) {
            t = torch::tensor(vec, options.requires_grad(c10::nullopt));
            t.set_requires_grad(true);
          } else {
            t = torch::tensor(vec, options);
          }
        }
        return t.reshape(size);
      });
}
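
// A minimal standalone sketch (not part of the extension; the helper name
// `tensor_with_grad` is hypothetical) of the requires_grad workaround used
// in "_tensor" above: passing requires_grad through TensorOptions into a
// factory function like torch::tensor can raise an error, so the flag is
// cleared for construction and then re-applied via Tensor::set_requires_grad.
inline torch::Tensor tensor_with_grad(const std::vector<float> &vec, torch::TensorOptions options) {
  // build the tensor with requires_grad unset, then flip it on afterwards
  torch::Tensor t = torch::tensor(vec, options.requires_grad(c10::nullopt));
  t.set_requires_grad(true);
  return t;
}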