ext/torch/wrap_outputs.h in torch-rb-0.6.0 vs ext/torch/wrap_outputs.h in torch-rb-0.7.0

- old
+ new

@@ -1,99 +1,106 @@ #pragma once #include <torch/torch.h> -#include <rice/Object.hpp> +#include <rice/rice.hpp> -inline Object wrap(bool x) { - return to_ruby<bool>(x); +inline VALUE wrap(bool x) { + return Rice::detail::To_Ruby<bool>().convert(x); } -inline Object wrap(int64_t x) { - return to_ruby<int64_t>(x); +inline VALUE wrap(int64_t x) { + return Rice::detail::To_Ruby<int64_t>().convert(x); } -inline Object wrap(double x) { - return to_ruby<double>(x); +inline VALUE wrap(double x) { + return Rice::detail::To_Ruby<double>().convert(x); } -inline Object wrap(torch::Tensor x) { - return to_ruby<torch::Tensor>(x); +inline VALUE wrap(torch::Tensor x) { + return Rice::detail::To_Ruby<torch::Tensor>().convert(x); } -inline Object wrap(torch::Scalar x) { - return to_ruby<torch::Scalar>(x); +inline VALUE wrap(torch::Scalar x) { + return Rice::detail::To_Ruby<torch::Scalar>().convert(x); } -inline Object wrap(torch::ScalarType x) { - return to_ruby<torch::ScalarType>(x); +inline VALUE wrap(torch::ScalarType x) { + return Rice::detail::To_Ruby<torch::ScalarType>().convert(x); } -inline Object wrap(torch::QScheme x) { - return to_ruby<torch::QScheme>(x); +inline VALUE wrap(torch::QScheme x) { + return Rice::detail::To_Ruby<torch::QScheme>().convert(x); } -inline Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) { - Array a; - a.push(to_ruby<torch::Tensor>(std::get<0>(x))); - a.push(to_ruby<torch::Tensor>(std::get<1>(x))); - return Object(a); +inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor> x) { + return rb_ary_new3( + 2, + Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)), + Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)) + ); } -inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) { - Array a; - a.push(to_ruby<torch::Tensor>(std::get<0>(x))); - a.push(to_ruby<torch::Tensor>(std::get<1>(x))); - a.push(to_ruby<torch::Tensor>(std::get<2>(x))); - return Object(a); +inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) { + return rb_ary_new3( + 3, + Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)), + Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)), + Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)) + ); } -inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) { - Array a; - a.push(to_ruby<torch::Tensor>(std::get<0>(x))); - a.push(to_ruby<torch::Tensor>(std::get<1>(x))); - a.push(to_ruby<torch::Tensor>(std::get<2>(x))); - a.push(to_ruby<torch::Tensor>(std::get<3>(x))); - return Object(a); +inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) { + return rb_ary_new3( + 4, + Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)), + Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)), + Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)), + Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<3>(x)) + ); } -inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) { - Array a; - a.push(to_ruby<torch::Tensor>(std::get<0>(x))); - a.push(to_ruby<torch::Tensor>(std::get<1>(x))); - a.push(to_ruby<torch::Tensor>(std::get<2>(x))); - a.push(to_ruby<torch::Tensor>(std::get<3>(x))); - a.push(to_ruby<torch::Tensor>(std::get<4>(x))); - return Object(a); +inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) { + return rb_ary_new3( + 5, + Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)), + Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)), + Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)), + Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<3>(x)), + Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<4>(x)) + ); } -inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) { - Array a; - a.push(to_ruby<torch::Tensor>(std::get<0>(x))); - a.push(to_ruby<torch::Tensor>(std::get<1>(x))); - a.push(to_ruby<torch::Tensor>(std::get<2>(x))); - a.push(to_ruby<int64_t>(std::get<3>(x))); - return Object(a); +inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) { + return rb_ary_new3( + 4, + Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)), + Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)), + Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)), + Rice::detail::To_Ruby<int64_t>().convert(std::get<3>(x)) + ); } -inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) { - Array a; - a.push(to_ruby<torch::Tensor>(std::get<0>(x))); - a.push(to_ruby<torch::Tensor>(std::get<1>(x))); - a.push(to_ruby<double>(std::get<2>(x))); - a.push(to_ruby<int64_t>(std::get<3>(x))); - return Object(a); +inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) { + return rb_ary_new3( + 4, + Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)), + Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)), + Rice::detail::To_Ruby<double>().convert(std::get<2>(x)), + Rice::detail::To_Ruby<int64_t>().convert(std::get<3>(x)) + ); } -inline Object wrap(torch::TensorList x) { - Array a; - for (auto& t : x) { - a.push(to_ruby<torch::Tensor>(t)); +inline VALUE wrap(torch::TensorList x) { + auto a = rb_ary_new2(x.size()); + for (auto t : x) { + rb_ary_push(a, Rice::detail::To_Ruby<torch::Tensor>().convert(t)); } - return Object(a); + return a; } -inline Object wrap(std::tuple<double, double> x) { - Array a; - a.push(to_ruby<double>(std::get<0>(x))); - a.push(to_ruby<double>(std::get<1>(x))); - return Object(a); +inline VALUE wrap(std::tuple<double, double> x) { + return rb_ary_new3( + 2, + Rice::detail::To_Ruby<double>().convert(std::get<0>(x)), + Rice::detail::To_Ruby<double>().convert(std::get<1>(x)) + ); }