ext/torch/torch.cpp in torch-rb-0.9.1 vs ext/torch/torch.cpp in torch-rb-0.9.2

- old
+ new

@@ -4,10 +4,27 @@ #include "torch_functions.h" #include "templates.h" #include "utils.h" +template<typename T> +torch::Tensor make_tensor(Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) { + std::vector<T> vec; + for (long i = 0; i < a.size(); i++) { + vec.push_back(Rice::detail::From_Ruby<T>().convert(a[i].value())); + } + + // hack for requires_grad error + auto requires_grad = options.requires_grad(); + torch::Tensor t = torch::tensor(vec, options.requires_grad(c10::nullopt)); + if (requires_grad) { + t.set_requires_grad(true); + } + + return t.reshape(size); +} + void init_torch(Rice::Module& m) { m.add_handler<torch::Error>(handle_error); add_torch_functions(m); m.define_singleton_function( "grad_enabled?", @@ -59,37 +76,30 @@ }) .define_singleton_function( "_tensor", [](Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) { auto dtype = options.dtype(); - torch::Tensor t; - if (dtype == torch::kBool) { - std::vector<uint8_t> vec; - for (long i = 0; i < a.size(); i++) { - vec.push_back(Rice::detail::From_Ruby<bool>().convert(a[i].value())); - } - t = torch::tensor(vec, options); - } else if (dtype == torch::kComplexFloat || dtype == torch::kComplexDouble) { - // TODO use template - std::vector<c10::complex<double>> vec; - Object obj; - for (long i = 0; i < a.size(); i++) { - obj = a[i]; - vec.push_back(c10::complex<double>(Rice::detail::From_Ruby<double>().convert(obj.call("real").value()), Rice::detail::From_Ruby<double>().convert(obj.call("imag").value()))); - } - t = torch::tensor(vec, options); + if (dtype == torch::kByte) { + return make_tensor<uint8_t>(a, size, options); + } else if (dtype == torch::kChar) { + return make_tensor<int8_t>(a, size, options); + } else if (dtype == torch::kShort) { + return make_tensor<int16_t>(a, size, options); + } else if (dtype == torch::kInt) { + return make_tensor<int32_t>(a, size, options); + } else if (dtype == torch::kLong) { + return make_tensor<int64_t>(a, size, options); + } else if (dtype == torch::kFloat) { + return make_tensor<float>(a, size, options); + } else if (dtype == torch::kDouble) { + return make_tensor<double>(a, size, options); + } else if (dtype == torch::kBool) { + return make_tensor<uint8_t>(a, size, options); + } else if (dtype == torch::kComplexFloat) { + return make_tensor<c10::complex<float>>(a, size, options); + } else if (dtype == torch::kComplexDouble) { + return make_tensor<c10::complex<double>>(a, size, options); } else { - std::vector<float> vec; - for (long i = 0; i < a.size(); i++) { - vec.push_back(Rice::detail::From_Ruby<float>().convert(a[i].value())); - } - // hack for requires_grad error - if (options.requires_grad()) { - t = torch::tensor(vec, options.requires_grad(c10::nullopt)); - t.set_requires_grad(true); - } else { - t = torch::tensor(vec, options); - } + throw std::runtime_error("Unsupported type"); } - return t.reshape(size); }); }