#include <torch/torch.h>

#include <rice/rice.hpp>

#include "torch_functions.h"
#include "templates.h"
#include "utils.h"

// Builds a tensor with element type T from a Ruby array, then reshapes it to size
template<typename T>
torch::Tensor make_tensor(Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
  std::vector<T> vec;
  for (long i = 0; i < a.size(); i++) {
    vec.push_back(Rice::detail::From_Ruby<T>().convert(a[i].value()));
  }

  // hack for requires_grad error: strip requires_grad from options,
  // then set it on the tensor afterwards
  auto requires_grad = options.requires_grad();
  torch::Tensor t = torch::tensor(vec, options.requires_grad(c10::nullopt));
  if (requires_grad) {
    t.set_requires_grad(true);
  }

  return t.reshape(size);
}

// Registers the Torch singleton functions on the Ruby module
void init_torch(Rice::Module& m) {
  m.add_handler<torch::Error>(handle_error);
  add_torch_functions(m);

  m.define_singleton_function(
      "grad_enabled?",
      []() {
        return torch::GradMode::is_enabled();
      })
    .define_singleton_function(
      "_set_grad_enabled",
      [](bool enabled) {
        torch::GradMode::set_enabled(enabled);
      })
    .define_singleton_function(
      "manual_seed",
      [](uint64_t seed) {
        return torch::manual_seed(seed);
      })
    // config
    .define_singleton_function(
      "show_config",
      [] {
        return torch::show_config();
      })
    .define_singleton_function(
      "parallel_info",
      [] {
        return torch::get_parallel_info();
      })
    // begin operations
    .define_singleton_function(
      "_save",
      [](const torch::IValue &value) {
        auto v = torch::pickle_save(value);
        std::string str(v.begin(), v.end());
        return str;
      })
    .define_singleton_function(
      "_load",
      [](const std::string &s) {
        std::vector<char> v;
        std::copy(s.begin(), s.end(), std::back_inserter(v));
        // https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
        return torch::pickle_load(v);
      })
    .define_singleton_function(
      "_from_blob",
      [](Rice::String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
        void *data = const_cast<char *>(s.c_str());
        return torch::from_blob(data, size, options);
      })
    .define_singleton_function(
      "_tensor",
      [](Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
        // dispatch on the requested dtype to the matching C++ element type
        auto dtype = options.dtype();
        if (dtype == torch::kByte) {
          return make_tensor<uint8_t>(a, size, options);
        } else if (dtype == torch::kChar) {
          return make_tensor<int8_t>(a, size, options);
        } else if (dtype == torch::kShort) {
          return make_tensor<int16_t>(a, size, options);
        } else if (dtype == torch::kInt) {
          return make_tensor<int32_t>(a, size, options);
        } else if (dtype == torch::kLong) {
          return make_tensor<int64_t>(a, size, options);
        } else if (dtype == torch::kFloat) {
          return make_tensor<float>(a, size, options);
        } else if (dtype == torch::kDouble) {
          return make_tensor<double>(a, size, options);
        } else if (dtype == torch::kBool) {
          return make_tensor<bool>(a, size, options);
        } else if (dtype == torch::kComplexFloat) {
          return make_tensor<c10::complex<float>>(a, size, options);
        } else if (dtype == torch::kComplexDouble) {
          return make_tensor<c10::complex<double>>(a, size, options);
        } else {
          throw std::runtime_error("Unsupported type");
        }
      });
}
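
// ---------------------------------------------------------------------------
// Usage sketch (assumption; not part of this translation unit). init_torch is
// expected to be called from the extension's entry point with the Ruby module
// the bindings attach to. The entry-point name and module name below are
// assumptions for illustration only:
//
//   extern "C" void Init_ext() {
//     auto m = Rice::define_module("Torch");
//     init_torch(m);
//   }
//
// Once loaded, the singleton functions registered above map directly to Ruby
// calls such as Torch.grad_enabled? or Torch.manual_seed(0).
// ---------------------------------------------------------------------------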