Sha256: 704d785aa9ebbdea47214a12b3731f89ac25640704afa972ae46394279bbdfcb

Contents?: true

Size: 1.94 KB

Versions: 1

Compression:

Stored size: 1.94 KB

Contents

#include <torch/torch.h>
#include <rice/Object.hpp>
#include "templates.hpp"

// Converts a 2-tuple of tensors into a two-element Ruby Array.
// Elements appear in tuple order; each is converted with to_ruby<torch::Tensor>.
Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
  const auto& first = std::get<0>(x);
  const auto& second = std::get<1>(x);

  Array result;
  result.push(to_ruby<torch::Tensor>(first));
  result.push(to_ruby<torch::Tensor>(second));
  return Object(result);
}

// Converts a 3-tuple of tensors into a three-element Ruby Array.
// Elements appear in tuple order; each is converted with to_ruby<torch::Tensor>.
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
  const auto& first = std::get<0>(x);
  const auto& second = std::get<1>(x);
  const auto& third = std::get<2>(x);

  Array result;
  result.push(to_ruby<torch::Tensor>(first));
  result.push(to_ruby<torch::Tensor>(second));
  result.push(to_ruby<torch::Tensor>(third));
  return Object(result);
}

// Converts a 4-tuple of tensors into a four-element Ruby Array.
// Elements appear in tuple order; each is converted with to_ruby<torch::Tensor>.
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
  Array result;
  // Local helper so each tuple slot is appended with one short call.
  auto append = [&result](const torch::Tensor& tensor) {
    result.push(to_ruby<torch::Tensor>(tensor));
  };

  append(std::get<0>(x));
  append(std::get<1>(x));
  append(std::get<2>(x));
  append(std::get<3>(x));
  return Object(result);
}

// Converts a 5-tuple of tensors into a five-element Ruby Array.
// Elements appear in tuple order; each is converted with to_ruby<torch::Tensor>.
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
  Array result;
  // Local helper so each tuple slot is appended with one short call.
  auto append = [&result](const torch::Tensor& tensor) {
    result.push(to_ruby<torch::Tensor>(tensor));
  };

  append(std::get<0>(x));
  append(std::get<1>(x));
  append(std::get<2>(x));
  append(std::get<3>(x));
  append(std::get<4>(x));
  return Object(result);
}

// Converts a (Tensor, Tensor, Tensor, int64_t) tuple into a four-element
// Ruby Array, in tuple order. The three tensors go through
// to_ruby<torch::Tensor>; the trailing integer goes through to_ruby<int64_t>.
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
  const auto& t0 = std::get<0>(x);
  const auto& t1 = std::get<1>(x);
  const auto& t2 = std::get<2>(x);
  const int64_t count = std::get<3>(x);

  Array result;
  result.push(to_ruby<torch::Tensor>(t0));
  result.push(to_ruby<torch::Tensor>(t1));
  result.push(to_ruby<torch::Tensor>(t2));
  result.push(to_ruby<int64_t>(count));
  return Object(result);
}

// Converts a (Tensor, Tensor, double, int64_t) tuple into a four-element
// Ruby Array, in tuple order. The two tensors go through
// to_ruby<torch::Tensor>, the double through to_ruby<double>, and the
// trailing integer through to_ruby<int64_t>.
Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
  const auto& t0 = std::get<0>(x);
  const auto& t1 = std::get<1>(x);
  const double real_value = std::get<2>(x);
  const int64_t int_value = std::get<3>(x);

  Array result;
  result.push(to_ruby<torch::Tensor>(t0));
  result.push(to_ruby<torch::Tensor>(t1));
  result.push(to_ruby<double>(real_value));
  result.push(to_ruby<int64_t>(int_value));
  return Object(result);
}

// Converts a vector of tensors into a Ruby Array of the same length.
// Element order is preserved; each tensor is converted with to_ruby<torch::Tensor>.
// An empty vector yields an empty Array.
Object wrap(std::vector<torch::Tensor> x) {
  Array result;
  for (const auto& tensor : x) {
    result.push(to_ruby<torch::Tensor>(tensor));
  }
  return Object(result);
}

Version data entries

1 entry across 1 version & 1 rubygem

Version Path
torch-rb-0.3.6 ext/torch/templates.cpp