ext/torch/templates.cpp in torch-rb-0.3.6 vs ext/torch/templates.cpp in torch-rb-0.3.7

- old
+ new

@@ -1,8 +1,36 @@ #include <torch/torch.h> #include <rice/Object.hpp> #include "templates.hpp" +Object wrap(bool x) { + return to_ruby<bool>(x); +} + +Object wrap(int64_t x) { + return to_ruby<int64_t>(x); +} + +Object wrap(double x) { + return to_ruby<double>(x); +} + +Object wrap(torch::Tensor x) { + return to_ruby<torch::Tensor>(x); +} + +Object wrap(torch::Scalar x) { + return to_ruby<torch::Scalar>(x); +} + +Object wrap(torch::ScalarType x) { + return to_ruby<torch::ScalarType>(x); +} + +Object wrap(torch::QScheme x) { + return to_ruby<torch::QScheme>(x); +} + Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) { Array a; a.push(to_ruby<torch::Tensor>(std::get<0>(x))); a.push(to_ruby<torch::Tensor>(std::get<1>(x))); return Object(a);