ext/torch/torch.cpp in torch-rb-0.9.1 vs ext/torch/torch.cpp in torch-rb-0.9.2
- old
+ new
@@ -4,10 +4,27 @@
#include "torch_functions.h"
#include "templates.h"
#include "utils.h"
+template<typename T>
+torch::Tensor make_tensor(Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
+ std::vector<T> vec;
+ for (long i = 0; i < a.size(); i++) {
+ vec.push_back(Rice::detail::From_Ruby<T>().convert(a[i].value()));
+ }
+
+ // hack for requires_grad error
+ auto requires_grad = options.requires_grad();
+ torch::Tensor t = torch::tensor(vec, options.requires_grad(c10::nullopt));
+ if (requires_grad) {
+ t.set_requires_grad(true);
+ }
+
+ return t.reshape(size);
+}
+
void init_torch(Rice::Module& m) {
m.add_handler<torch::Error>(handle_error);
add_torch_functions(m);
m.define_singleton_function(
"grad_enabled?",
@@ -59,37 +76,30 @@
})
.define_singleton_function(
"_tensor",
[](Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
auto dtype = options.dtype();
- torch::Tensor t;
- if (dtype == torch::kBool) {
- std::vector<uint8_t> vec;
- for (long i = 0; i < a.size(); i++) {
- vec.push_back(Rice::detail::From_Ruby<bool>().convert(a[i].value()));
- }
- t = torch::tensor(vec, options);
- } else if (dtype == torch::kComplexFloat || dtype == torch::kComplexDouble) {
- // TODO use template
- std::vector<c10::complex<double>> vec;
- Object obj;
- for (long i = 0; i < a.size(); i++) {
- obj = a[i];
- vec.push_back(c10::complex<double>(Rice::detail::From_Ruby<double>().convert(obj.call("real").value()), Rice::detail::From_Ruby<double>().convert(obj.call("imag").value())));
- }
- t = torch::tensor(vec, options);
+ if (dtype == torch::kByte) {
+ return make_tensor<uint8_t>(a, size, options);
+ } else if (dtype == torch::kChar) {
+ return make_tensor<int8_t>(a, size, options);
+ } else if (dtype == torch::kShort) {
+ return make_tensor<int16_t>(a, size, options);
+ } else if (dtype == torch::kInt) {
+ return make_tensor<int32_t>(a, size, options);
+ } else if (dtype == torch::kLong) {
+ return make_tensor<int64_t>(a, size, options);
+ } else if (dtype == torch::kFloat) {
+ return make_tensor<float>(a, size, options);
+ } else if (dtype == torch::kDouble) {
+ return make_tensor<double>(a, size, options);
+ } else if (dtype == torch::kBool) {
+ return make_tensor<uint8_t>(a, size, options);
+ } else if (dtype == torch::kComplexFloat) {
+ return make_tensor<c10::complex<float>>(a, size, options);
+ } else if (dtype == torch::kComplexDouble) {
+ return make_tensor<c10::complex<double>>(a, size, options);
} else {
- std::vector<float> vec;
- for (long i = 0; i < a.size(); i++) {
- vec.push_back(Rice::detail::From_Ruby<float>().convert(a[i].value()));
- }
- // hack for requires_grad error
- if (options.requires_grad()) {
- t = torch::tensor(vec, options.requires_grad(c10::nullopt));
- t.set_requires_grad(true);
- } else {
- t = torch::tensor(vec, options);
- }
+ throw std::runtime_error("Unsupported type");
}
- return t.reshape(size);
});
}