ext/torch/wrap_outputs.h in torch-rb-0.6.0 vs ext/torch/wrap_outputs.h in torch-rb-0.7.0
- old
+ new
@@ -1,99 +1,106 @@
#pragma once
#include <torch/torch.h>
-#include <rice/Object.hpp>
+#include <rice/rice.hpp>
-inline Object wrap(bool x) {
- return to_ruby<bool>(x);
+inline VALUE wrap(bool x) {
+ return Rice::detail::To_Ruby<bool>().convert(x);
}
-inline Object wrap(int64_t x) {
- return to_ruby<int64_t>(x);
+inline VALUE wrap(int64_t x) {
+ return Rice::detail::To_Ruby<int64_t>().convert(x);
}
-inline Object wrap(double x) {
- return to_ruby<double>(x);
+inline VALUE wrap(double x) {
+ return Rice::detail::To_Ruby<double>().convert(x);
}
-inline Object wrap(torch::Tensor x) {
- return to_ruby<torch::Tensor>(x);
+inline VALUE wrap(torch::Tensor x) {
+ return Rice::detail::To_Ruby<torch::Tensor>().convert(x);
}
-inline Object wrap(torch::Scalar x) {
- return to_ruby<torch::Scalar>(x);
+inline VALUE wrap(torch::Scalar x) {
+ return Rice::detail::To_Ruby<torch::Scalar>().convert(x);
}
-inline Object wrap(torch::ScalarType x) {
- return to_ruby<torch::ScalarType>(x);
+inline VALUE wrap(torch::ScalarType x) {
+ return Rice::detail::To_Ruby<torch::ScalarType>().convert(x);
}
-inline Object wrap(torch::QScheme x) {
- return to_ruby<torch::QScheme>(x);
+inline VALUE wrap(torch::QScheme x) {
+ return Rice::detail::To_Ruby<torch::QScheme>().convert(x);
}
-inline Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
- Array a;
- a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
- a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
- return Object(a);
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
+ return rb_ary_new3(
+ 2,
+ Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+ Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x))
+ );
}
-inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
- Array a;
- a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
- a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
- a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
- return Object(a);
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
+ return rb_ary_new3(
+ 3,
+ Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+ Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+ Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x))
+ );
}
-inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
- Array a;
- a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
- a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
- a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
- a.push(to_ruby<torch::Tensor>(std::get<3>(x)));
- return Object(a);
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
+ return rb_ary_new3(
+ 4,
+ Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+ Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+ Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)),
+ Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<3>(x))
+ );
}
-inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
- Array a;
- a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
- a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
- a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
- a.push(to_ruby<torch::Tensor>(std::get<3>(x)));
- a.push(to_ruby<torch::Tensor>(std::get<4>(x)));
- return Object(a);
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
+ return rb_ary_new3(
+ 5,
+ Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+ Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+ Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)),
+ Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<3>(x)),
+ Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<4>(x))
+ );
}
-inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
- Array a;
- a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
- a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
- a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
- a.push(to_ruby<int64_t>(std::get<3>(x)));
- return Object(a);
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
+ return rb_ary_new3(
+ 4,
+ Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+ Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+ Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)),
+ Rice::detail::To_Ruby<int64_t>().convert(std::get<3>(x))
+ );
}
-inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
- Array a;
- a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
- a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
- a.push(to_ruby<double>(std::get<2>(x)));
- a.push(to_ruby<int64_t>(std::get<3>(x)));
- return Object(a);
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
+ return rb_ary_new3(
+ 4,
+ Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+ Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+ Rice::detail::To_Ruby<double>().convert(std::get<2>(x)),
+ Rice::detail::To_Ruby<int64_t>().convert(std::get<3>(x))
+ );
}
-inline Object wrap(torch::TensorList x) {
- Array a;
- for (auto& t : x) {
- a.push(to_ruby<torch::Tensor>(t));
+inline VALUE wrap(torch::TensorList x) {
+ auto a = rb_ary_new2(x.size());
+ for (auto t : x) {
+ rb_ary_push(a, Rice::detail::To_Ruby<torch::Tensor>().convert(t));
}
- return Object(a);
+ return a;
}
-inline Object wrap(std::tuple<double, double> x) {
- Array a;
- a.push(to_ruby<double>(std::get<0>(x)));
- a.push(to_ruby<double>(std::get<1>(x)));
- return Object(a);
+inline VALUE wrap(std::tuple<double, double> x) {
+ return rb_ary_new3(
+ 2,
+ Rice::detail::To_Ruby<double>().convert(std::get<0>(x)),
+ Rice::detail::To_Ruby<double>().convert(std::get<1>(x))
+ );
}