ext/torch/ext.cpp in torch-rb-0.1.0 vs ext/torch/ext.cpp in torch-rb-0.1.1
- old
+ new
@@ -439,23 +439,32 @@
       *[](torch::Tensor& self) {
         Array a;
         auto dtype = self.dtype();
         // TODO DRY if someone knows C++
-        // TODO kByte (uint8), kChar (int8), kBool (bool)
-        if (dtype == torch::kShort) {
-          short* data = self.data_ptr<short>();
+        if (dtype == torch::kByte) {
+          uint8_t* data = self.data_ptr<uint8_t>();
           for (int i = 0; i < self.numel(); i++) {
             a.push(data[i]);
           }
+        } else if (dtype == torch::kChar) {
+          int8_t* data = self.data_ptr<int8_t>();
+          for (int i = 0; i < self.numel(); i++) {
+            a.push(to_ruby<int>(data[i]));
+          }
+        } else if (dtype == torch::kShort) {
+          int16_t* data = self.data_ptr<int16_t>();
+          for (int i = 0; i < self.numel(); i++) {
+            a.push(data[i]);
+          }
         } else if (dtype == torch::kInt) {
-          int* data = self.data_ptr<int>();
+          int32_t* data = self.data_ptr<int32_t>();
           for (int i = 0; i < self.numel(); i++) {
             a.push(data[i]);
           }
         } else if (dtype == torch::kLong) {
-          long long* data = self.data_ptr<long long>();
+          int64_t* data = self.data_ptr<int64_t>();
           for (int i = 0; i < self.numel(); i++) {
             a.push(data[i]);
           }
         } else if (dtype == torch::kFloat) {
           float* data = self.data_ptr<float>();
@@ -465,12 +474,15 @@
         } else if (dtype == torch::kDouble) {
           double* data = self.data_ptr<double>();
           for (int i = 0; i < self.numel(); i++) {
             a.push(data[i]);
           }
+        } else if (dtype == torch::kBool) {
+          // bool
+          throw std::runtime_error("Type not supported yet");
         } else {
-          throw "Unsupported type";
+          throw std::runtime_error("Unsupported type");
         }
         return a;
       })
     .define_method(
       "_size",
@@ -497,12 +509,15 @@
"layout",
*[](torch::TensorOptions& self, std::string layout) {
torch::Layout l;
if (layout == "strided") {
l = torch::kStrided;
+ } else if (layout == "sparse") {
+ l = torch::kSparse;
+ throw std::runtime_error("Sparse layout not supported yet");
} else {
- throw "Unsupported layout";
+ throw std::runtime_error("Unsupported layout: " + layout);
}
return self.layout(l);
})
.define_method(
"device",
@@ -511,10 +526,10 @@
         if (device == "cpu") {
           d = torch::kCPU;
         } else if (device == "cuda") {
           d = torch::kCUDA;
         } else {
-          throw "Unsupported device";
+          throw std::runtime_error("Unsupported device: " + device);
         }
         return self.device(d);
       })
     .define_method(
       "requires_grad",