ext/torch/ext.cpp in torch-rb-0.1.8 vs ext/torch/ext.cpp in torch-rb-0.2.0
- old
+ new
@@ -129,16 +129,19 @@
void *data = const_cast<char *>(s.c_str());
return torch::from_blob(data, size, options);
})
.define_singleton_method(
"_tensor",
- *[](Object o, IntArrayRef size, const torch::TensorOptions &options) {
- Array a = Array(o);
+ *[](Array a, IntArrayRef size, const torch::TensorOptions &options) {
auto dtype = options.dtype();
torch::Tensor t;
if (dtype == torch::kBool) {
- throw std::runtime_error("Cannot create bool from tensor method yet");
+ std::vector<uint8_t> vec;
+ for (size_t i = 0; i < a.size(); i++) {
+ vec.push_back(from_ruby<bool>(a[i]));
+ }
+ t = torch::tensor(vec, options);
} else {
std::vector<float> vec;
for (size_t i = 0; i < a.size(); i++) {
vec.push_back(from_ruby<float>(a[i]));
}
@@ -211,52 +214,60 @@
return s.str();
})
.define_method(
"_flat_data",
*[](Tensor& self) {
+ Tensor tensor = self;
+
+ // move to CPU to get data
+ if (tensor.device().type() != torch::kCPU) {
+ torch::Device device("cpu");
+ tensor = tensor.to(device);
+ }
+
Array a;
- auto dtype = self.dtype();
+ auto dtype = tensor.dtype();
// TODO DRY if someone knows C++
if (dtype == torch::kByte) {
- uint8_t* data = self.data_ptr<uint8_t>();
- for (int i = 0; i < self.numel(); i++) {
+ uint8_t* data = tensor.data_ptr<uint8_t>();
+ for (int i = 0; i < tensor.numel(); i++) {
a.push(data[i]);
}
} else if (dtype == torch::kChar) {
- int8_t* data = self.data_ptr<int8_t>();
- for (int i = 0; i < self.numel(); i++) {
+ int8_t* data = tensor.data_ptr<int8_t>();
+ for (int i = 0; i < tensor.numel(); i++) {
a.push(to_ruby<int>(data[i]));
}
} else if (dtype == torch::kShort) {
- int16_t* data = self.data_ptr<int16_t>();
- for (int i = 0; i < self.numel(); i++) {
+ int16_t* data = tensor.data_ptr<int16_t>();
+ for (int i = 0; i < tensor.numel(); i++) {
a.push(data[i]);
}
} else if (dtype == torch::kInt) {
- int32_t* data = self.data_ptr<int32_t>();
- for (int i = 0; i < self.numel(); i++) {
+ int32_t* data = tensor.data_ptr<int32_t>();
+ for (int i = 0; i < tensor.numel(); i++) {
a.push(data[i]);
}
} else if (dtype == torch::kLong) {
- int64_t* data = self.data_ptr<int64_t>();
- for (int i = 0; i < self.numel(); i++) {
+ int64_t* data = tensor.data_ptr<int64_t>();
+ for (int i = 0; i < tensor.numel(); i++) {
a.push(data[i]);
}
} else if (dtype == torch::kFloat) {
- float* data = self.data_ptr<float>();
- for (int i = 0; i < self.numel(); i++) {
+ float* data = tensor.data_ptr<float>();
+ for (int i = 0; i < tensor.numel(); i++) {
a.push(data[i]);
}
} else if (dtype == torch::kDouble) {
- double* data = self.data_ptr<double>();
- for (int i = 0; i < self.numel(); i++) {
+ double* data = tensor.data_ptr<double>();
+ for (int i = 0; i < tensor.numel(); i++) {
a.push(data[i]);
}
} else if (dtype == torch::kBool) {
- bool* data = self.data_ptr<bool>();
- for (int i = 0; i < self.numel(); i++) {
+ bool* data = tensor.data_ptr<bool>();
+ for (int i = 0; i < tensor.numel(); i++) {
a.push(data[i] ? True : False);
}
} else {
throw std::runtime_error("Unsupported type");
}
@@ -298,18 +309,16 @@
return self.layout(l);
})
.define_method(
"device",
*[](torch::TensorOptions& self, std::string device) {
- torch::DeviceType d;
- if (device == "cpu") {
- d = torch::kCPU;
- } else if (device == "cuda") {
- d = torch::kCUDA;
- } else {
- throw std::runtime_error("Unsupported device: " + device);
+ try {
+ // needed to catch exception
+ torch::Device d(device);
+ return self.device(d);
+ } catch (const c10::Error& error) {
+ throw std::runtime_error(error.what_without_backtrace());
}
- return self.device(d);
})
.define_method(
"requires_grad",
*[](torch::TensorOptions& self, bool requires_grad) {
return self.requires_grad(requires_grad);
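
The device setter no longer hand-parses the string: torch::Device's string constructor accepts names such as "cpu", "cuda", or "cuda:0" (so device indices now work), and it throws c10::Error for anything it cannot parse, which the new code converts into a std::runtime_error that Rice can then translate into a Ruby exception. A minimal sketch of the same pattern (with_device is an illustrative name):

#include <torch/torch.h>
#include <stdexcept>
#include <string>

// Apply a device given by name to a TensorOptions value, translating a
// parse failure from c10::Error into std::runtime_error.
torch::TensorOptions with_device(torch::TensorOptions options, const std::string& name) {
  try {
    torch::Device d(name);  // understands "cpu", "cuda", "cuda:1", ...
    return options.device(d);
  } catch (const c10::Error& error) {
    // drop the C++ backtrace so callers get a clean message
    throw std::runtime_error(error.what_without_backtrace());
  }
}
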