ext/torch/ext.cpp in torch-rb-0.1.8 vs ext/torch/ext.cpp in torch-rb-0.2.0

- old
+ new

@@ -129,16 +129,19 @@ void *data = const_cast<char *>(s.c_str()); return torch::from_blob(data, size, options); }) .define_singleton_method( "_tensor", - *[](Object o, IntArrayRef size, const torch::TensorOptions &options) { - Array a = Array(o); + *[](Array a, IntArrayRef size, const torch::TensorOptions &options) { auto dtype = options.dtype(); torch::Tensor t; if (dtype == torch::kBool) { - throw std::runtime_error("Cannot create bool from tensor method yet"); + std::vector<uint8_t> vec; + for (size_t i = 0; i < a.size(); i++) { + vec.push_back(from_ruby<bool>(a[i])); + } + t = torch::tensor(vec, options); } else { std::vector<float> vec; for (size_t i = 0; i < a.size(); i++) { vec.push_back(from_ruby<float>(a[i])); } @@ -211,52 +214,60 @@ return s.str(); }) .define_method( "_flat_data", *[](Tensor& self) { + Tensor tensor = self; + + // move to CPU to get data + if (tensor.device().type() != torch::kCPU) { + torch::Device device("cpu"); + tensor = tensor.to(device); + } + Array a; - auto dtype = self.dtype(); + auto dtype = tensor.dtype(); // TODO DRY if someone knows C++ if (dtype == torch::kByte) { - uint8_t* data = self.data_ptr<uint8_t>(); - for (int i = 0; i < self.numel(); i++) { + uint8_t* data = tensor.data_ptr<uint8_t>(); + for (int i = 0; i < tensor.numel(); i++) { a.push(data[i]); } } else if (dtype == torch::kChar) { - int8_t* data = self.data_ptr<int8_t>(); - for (int i = 0; i < self.numel(); i++) { + int8_t* data = tensor.data_ptr<int8_t>(); + for (int i = 0; i < tensor.numel(); i++) { a.push(to_ruby<int>(data[i])); } } else if (dtype == torch::kShort) { - int16_t* data = self.data_ptr<int16_t>(); - for (int i = 0; i < self.numel(); i++) { + int16_t* data = tensor.data_ptr<int16_t>(); + for (int i = 0; i < tensor.numel(); i++) { a.push(data[i]); } } else if (dtype == torch::kInt) { - int32_t* data = self.data_ptr<int32_t>(); - for (int i = 0; i < self.numel(); i++) { + int32_t* data = tensor.data_ptr<int32_t>(); + for (int i = 0; i < tensor.numel(); i++) { a.push(data[i]); } } else if (dtype == torch::kLong) { - int64_t* data = self.data_ptr<int64_t>(); - for (int i = 0; i < self.numel(); i++) { + int64_t* data = tensor.data_ptr<int64_t>(); + for (int i = 0; i < tensor.numel(); i++) { a.push(data[i]); } } else if (dtype == torch::kFloat) { - float* data = self.data_ptr<float>(); - for (int i = 0; i < self.numel(); i++) { + float* data = tensor.data_ptr<float>(); + for (int i = 0; i < tensor.numel(); i++) { a.push(data[i]); } } else if (dtype == torch::kDouble) { - double* data = self.data_ptr<double>(); - for (int i = 0; i < self.numel(); i++) { + double* data = tensor.data_ptr<double>(); + for (int i = 0; i < tensor.numel(); i++) { a.push(data[i]); } } else if (dtype == torch::kBool) { - bool* data = self.data_ptr<bool>(); - for (int i = 0; i < self.numel(); i++) { + bool* data = tensor.data_ptr<bool>(); + for (int i = 0; i < tensor.numel(); i++) { a.push(data[i] ? True : False); } } else { throw std::runtime_error("Unsupported type"); } @@ -298,18 +309,16 @@ return self.layout(l); }) .define_method( "device", *[](torch::TensorOptions& self, std::string device) { - torch::DeviceType d; - if (device == "cpu") { - d = torch::kCPU; - } else if (device == "cuda") { - d = torch::kCUDA; - } else { - throw std::runtime_error("Unsupported device: " + device); + try { + // needed to catch exception + torch::Device d(device); + return self.device(d); + } catch (const c10::Error& error) { + throw std::runtime_error(error.what_without_backtrace()); } - return self.device(d); }) .define_method( "requires_grad", *[](torch::TensorOptions& self, bool requires_grad) { return self.requires_grad(requires_grad);