ext/torch/torch.cpp in torch-rb-0.6.0 vs ext/torch/torch.cpp in torch-rb-0.7.0

- old
+ new

@@ -1,78 +1,87 @@ #include <torch/torch.h> -#include <rice/Module.hpp> +#include <rice/rice.hpp> #include "torch_functions.h" #include "templates.h" #include "utils.h" void init_torch(Rice::Module& m) { m.add_handler<torch::Error>(handle_error); add_torch_functions(m); - m.define_singleton_method( + m.define_singleton_function( "grad_enabled?", - *[]() { + []() { return torch::GradMode::is_enabled(); }) - .define_singleton_method( + .define_singleton_function( "_set_grad_enabled", - *[](bool enabled) { + [](bool enabled) { torch::GradMode::set_enabled(enabled); }) - .define_singleton_method( + .define_singleton_function( "manual_seed", - *[](uint64_t seed) { + [](uint64_t seed) { return torch::manual_seed(seed); }) // config - .define_singleton_method( + .define_singleton_function( "show_config", - *[] { + [] { return torch::show_config(); }) - .define_singleton_method( + .define_singleton_function( "parallel_info", - *[] { + [] { return torch::get_parallel_info(); }) // begin operations - .define_singleton_method( + .define_singleton_function( "_save", - *[](const torch::IValue &value) { + [](const torch::IValue &value) { auto v = torch::pickle_save(value); std::string str(v.begin(), v.end()); return str; }) - .define_singleton_method( + .define_singleton_function( "_load", - *[](const std::string &s) { + [](const std::string &s) { std::vector<char> v; std::copy(s.begin(), s.end(), std::back_inserter(v)); // https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701 return torch::pickle_load(v); }) - .define_singleton_method( + .define_singleton_function( "_from_blob", - *[](Rice::String s, std::vector<int64_t> size, const torch::TensorOptions &options) { + [](Rice::String s, std::vector<int64_t> size, const torch::TensorOptions &options) { void *data = const_cast<char *>(s.c_str()); return torch::from_blob(data, size, options); }) - .define_singleton_method( + .define_singleton_function( "_tensor", - *[](Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) { + [](Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) { auto dtype = options.dtype(); torch::Tensor t; if (dtype == torch::kBool) { std::vector<uint8_t> vec; for (long i = 0; i < a.size(); i++) { - vec.push_back(from_ruby<bool>(a[i])); + vec.push_back(Rice::detail::From_Ruby<bool>().convert(a[i].value())); } t = torch::tensor(vec, options); + } else if (dtype == torch::kComplexFloat || dtype == torch::kComplexDouble) { + // TODO use template + std::vector<c10::complex<double>> vec; + Object obj; + for (long i = 0; i < a.size(); i++) { + obj = a[i]; + vec.push_back(c10::complex<double>(Rice::detail::From_Ruby<double>().convert(obj.call("real").value()), Rice::detail::From_Ruby<double>().convert(obj.call("imag").value()))); + } + t = torch::tensor(vec, options); } else { std::vector<float> vec; for (long i = 0; i < a.size(); i++) { - vec.push_back(from_ruby<float>(a[i])); + vec.push_back(Rice::detail::From_Ruby<float>().convert(a[i].value())); } // hack for requires_grad error if (options.requires_grad()) { t = torch::tensor(vec, options.requires_grad(c10::nullopt)); t.set_requires_grad(true);