ext/torch/ext.cpp in torch-rb-0.1.0 vs ext/torch/ext.cpp in torch-rb-0.1.1

- old
+ new

@@ -439,23 +439,32 @@ *[](torch::Tensor& self) { Array a; auto dtype = self.dtype(); // TODO DRY if someone knows C++ - // TODO kByte (uint8), kChar (int8), kBool (bool) - if (dtype == torch::kShort) { - short* data = self.data_ptr<short>(); + if (dtype == torch::kByte) { + uint8_t* data = self.data_ptr<uint8_t>(); for (int i = 0; i < self.numel(); i++) { a.push(data[i]); } + } else if (dtype == torch::kChar) { + int8_t* data = self.data_ptr<int8_t>(); + for (int i = 0; i < self.numel(); i++) { + a.push(to_ruby<int>(data[i])); + } + } else if (dtype == torch::kShort) { + int16_t* data = self.data_ptr<int16_t>(); + for (int i = 0; i < self.numel(); i++) { + a.push(data[i]); + } } else if (dtype == torch::kInt) { - int* data = self.data_ptr<int>(); + int32_t* data = self.data_ptr<int32_t>(); for (int i = 0; i < self.numel(); i++) { a.push(data[i]); } } else if (dtype == torch::kLong) { - long long* data = self.data_ptr<long long>(); + int64_t* data = self.data_ptr<int64_t>(); for (int i = 0; i < self.numel(); i++) { a.push(data[i]); } } else if (dtype == torch::kFloat) { float* data = self.data_ptr<float>(); @@ -465,12 +474,15 @@ } else if (dtype == torch::kDouble) { double* data = self.data_ptr<double>(); for (int i = 0; i < self.numel(); i++) { a.push(data[i]); } + } else if (dtype == torch::kBool) { + // bool + throw std::runtime_error("Type not supported yet"); } else { - throw "Unsupported type"; + throw std::runtime_error("Unsupported type"); } return a; }) .define_method( "_size", @@ -497,12 +509,15 @@ "layout", *[](torch::TensorOptions& self, std::string layout) { torch::Layout l; if (layout == "strided") { l = torch::kStrided; + } else if (layout == "sparse") { + l = torch::kSparse; + throw std::runtime_error("Sparse layout not supported yet"); } else { - throw "Unsupported layout"; + throw std::runtime_error("Unsupported layout: " + layout); } return self.layout(l); }) .define_method( "device", @@ -511,10 +526,10 @@ if (device == "cpu") { d = torch::kCPU; } else if (device == "cuda") { d = torch::kCUDA; } else { - throw "Unsupported device"; + throw std::runtime_error("Unsupported device: " + device); } return self.device(d); }) .define_method( "requires_grad",