ext/torch/ext.cpp in torch-rb-0.2.4 vs ext/torch/ext.cpp in torch-rb-0.2.5
- old
+ new
@@ -21,28 +21,32 @@
class Parameter: public torch::autograd::Variable {
  public:
    Parameter(Tensor&& t) : torch::autograd::Variable(t) { }
};

-void handle_error(c10::Error const & ex)
+void handle_error(torch::Error const & ex)
{
  throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
}

extern "C"
void Init_ext()
{
  Module rb_mTorch = define_module("Torch");
+ rb_mTorch.add_handler<torch::Error>(handle_error);
  add_torch_functions(rb_mTorch);

  Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor");
+ rb_cTensor.add_handler<torch::Error>(handle_error);
  add_tensor_functions(rb_cTensor);

  Module rb_mNN = define_module_under(rb_mTorch, "NN");
+ rb_mNN.add_handler<torch::Error>(handle_error);
  add_nn_functions(rb_mNN);

  Module rb_mRandom = define_module_under(rb_mTorch, "Random")
+   .add_handler<torch::Error>(handle_error)
    .define_singleton_method(
      "initial_seed",
      *[]() {
        return at::detail::getDefaultCPUGenerator()->current_seed();
      })
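Note: 0.2.5 makes two related changes here. The handler now takes torch::Error (which, as far as I can tell, libtorch treats as equivalent to the c10::Error it replaces), and it is registered on each module and class right after it is defined instead of only on a couple of method chains. Conceptually, Rice then wraps every bound call roughly like this (a sketch of the mechanism, not Rice's actual code):

  try {
    // ... invoke the wrapped libtorch function ...
  } catch (torch::Error const & ex) {
    handle_error(ex);  // surfaces as a Ruby RuntimeError with the libtorch message
  }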
@@ -53,10 +57,11 @@
        return at::detail::getDefaultCPUGenerator()->seed();
      });

  // https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
  Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
+   .add_handler<torch::Error>(handle_error)
    .define_constructor(Constructor<torch::IValue>())
    .define_method("bool?", &torch::IValue::isBool)
    .define_method("bool_list?", &torch::IValue::isBoolList)
    .define_method("capsule?", &torch::IValue::isCapsule)
    .define_method("custom_class?", &torch::IValue::isCustomClass)
@@ -315,11 +320,10 @@
      }
      return t.reshape(size);
    });

  rb_cTensor
-   .add_handler<c10::Error>(handle_error)
    .define_method("cuda?", &torch::Tensor::is_cuda)
    .define_method("sparse?", &torch::Tensor::is_sparse)
    .define_method("quantized?", &torch::Tensor::is_quantized)
    .define_method("dim", &torch::Tensor::dim)
    .define_method("numel", &torch::Tensor::numel)
@@ -373,10 +377,25 @@
        std::stringstream s;
        s << self.device();
        return s.str();
      })
    .define_method(
+     "_data_str",
+     *[](Tensor& self) {
+       Tensor tensor = self;
+
+       // move to CPU to get data
+       if (tensor.device().type() != torch::kCPU) {
+         torch::Device device("cpu");
+         tensor = tensor.to(device);
+       }
+
+       auto data_ptr = (const char *) tensor.data_ptr();
+       return std::string(data_ptr, tensor.numel() * tensor.element_size());
+     })
+   // TODO figure out a better way to do this
+   .define_method(
      "_flat_data",
      *[](Tensor& self) {
        Tensor tensor = self;

        // move to CPU to get data
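Note: the new _data_str above returns the tensor's raw bytes as a binary Ruby string, copying GPU tensors to the CPU first. Reading numel() * element_size() bytes straight from data_ptr() assumes the tensor's storage is contiguous. For illustration, the inverse direction could look like this (a hypothetical helper under the same contiguity assumption, not part of this diff; needs <cstring>):

  Tensor tensor_from_string(const std::string& data, torch::IntArrayRef shape, torch::ScalarType dtype) {
    // allocate a contiguous tensor, then copy the raw bytes back in
    Tensor t = torch::empty(shape, torch::TensorOptions().dtype(dtype));
    std::memcpy(t.data_ptr(), data.data(), data.size());
    return t;
  }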
@@ -386,50 +405,44 @@
        }

        Array a;
        auto dtype = tensor.dtype();

+       Tensor view = tensor.reshape({tensor.numel()});
+
        // TODO DRY if someone knows C++
        if (dtype == torch::kByte) {
-         uint8_t* data = tensor.data_ptr<uint8_t>();
          for (int i = 0; i < tensor.numel(); i++) {
-           a.push(data[i]);
+           a.push(view[i].item().to<uint8_t>());
          }
        } else if (dtype == torch::kChar) {
-         int8_t* data = tensor.data_ptr<int8_t>();
          for (int i = 0; i < tensor.numel(); i++) {
-           a.push(to_ruby<int>(data[i]));
+           a.push(to_ruby<int>(view[i].item().to<int8_t>()));
          }
        } else if (dtype == torch::kShort) {
-         int16_t* data = tensor.data_ptr<int16_t>();
          for (int i = 0; i < tensor.numel(); i++) {
-           a.push(data[i]);
+           a.push(view[i].item().to<int16_t>());
          }
        } else if (dtype == torch::kInt) {
-         int32_t* data = tensor.data_ptr<int32_t>();
          for (int i = 0; i < tensor.numel(); i++) {
-           a.push(data[i]);
+           a.push(view[i].item().to<int32_t>());
          }
        } else if (dtype == torch::kLong) {
-         int64_t* data = tensor.data_ptr<int64_t>();
          for (int i = 0; i < tensor.numel(); i++) {
-           a.push(data[i]);
+           a.push(view[i].item().to<int64_t>());
          }
        } else if (dtype == torch::kFloat) {
-         float* data = tensor.data_ptr<float>();
          for (int i = 0; i < tensor.numel(); i++) {
-           a.push(data[i]);
+           a.push(view[i].item().to<float>());
          }
        } else if (dtype == torch::kDouble) {
-         double* data = tensor.data_ptr<double>();
          for (int i = 0; i < tensor.numel(); i++) {
-           a.push(data[i]);
+           a.push(view[i].item().to<double>());
          }
        } else if (dtype == torch::kBool) {
-         bool* data = tensor.data_ptr<bool>();
          for (int i = 0; i < tensor.numel(); i++) {
-           a.push(data[i] ? True : False);
+           a.push(view[i].item().to<bool>() ? True : False);
          }
        } else {
          throw std::runtime_error("Unsupported type");
        }
        return a;
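Note: _flat_data now reads each element through a flattened view with item() instead of indexing a raw data_ptr(). That is slower per element, but it goes through libtorch's indexing machinery, so it should also behave correctly for tensors whose storage is not laid out contiguously. On the "TODO DRY" comment: one possible consolidation (a sketch, not part of this release) is a templated helper, with the kChar branch (to_ruby<int>) and kBool branch (True/False) still special-cased:

  template <typename T>
  void push_flat(Array& a, const Tensor& view) {
    // walk the flattened view and push each element as a Ruby object
    for (int64_t i = 0; i < view.numel(); i++) {
      a.push(view[i].item().to<T>());
    }
  }

  // each plain branch then collapses to e.g.
  // if (dtype == torch::kFloat) { push_flat<float>(a, view); }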
@@ -447,11 +460,11 @@
        auto var = data.set_requires_grad(requires_grad);
        return Parameter(std::move(var));
      });

  Class rb_cTensorOptions = define_class_under<torch::TensorOptions>(rb_mTorch, "TensorOptions")
-   .add_handler<c10::Error>(handle_error)
+   .add_handler<torch::Error>(handle_error)
    .define_constructor(Constructor<torch::TensorOptions>())
    .define_method(
      "dtype",
      *[](torch::TensorOptions& self, int dtype) {
        return self.dtype((torch::ScalarType) dtype);
@@ -553,19 +566,21 @@
      *[](Tensor tensor, double sparsity, double std) {
        return torch::nn::init::sparse_(tensor, sparsity, std);
      });

  Class rb_cParameter = define_class_under<Parameter, torch::Tensor>(rb_mNN, "Parameter")
+   .add_handler<torch::Error>(handle_error)
    .define_method(
      "grad",
      *[](Parameter& self) {
        auto grad = self.grad();
        return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
      });

  Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
    .define_constructor(Constructor<torch::Device, std::string>())
+   .add_handler<torch::Error>(handle_error)
    .define_method("index", &torch::Device::index)
    .define_method("index?", &torch::Device::has_index)
    .define_method(
      "type",
      *[](torch::Device& self) {
@@ -573,8 +588,9 @@
        s << self.type();
        return s.str();
      });

  Module rb_mCUDA = define_module_under(rb_mTorch, "CUDA")
+   .add_handler<torch::Error>(handle_error)
    .define_singleton_method("available?", &torch::cuda::is_available)
    .define_singleton_method("device_count", &torch::cuda::device_count);
}
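Note: for context, the last two bindings expose torch::cuda directly. The same calls from C++ look like this (a minimal sketch, not from this diff):

  // pick a device the way a caller of these bindings might
  torch::Device device(torch::cuda::is_available() ? "cuda" : "cpu");
  Tensor t = torch::ones({2, 3}, torch::TensorOptions().device(device));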