ext/torch/templates.hpp in torch-rb-0.3.6 vs ext/torch/templates.hpp in torch-rb-0.3.7

- old
+ new

@@ -11,53 +11,35 @@ using torch::Device; using torch::Scalar; using torch::ScalarType; using torch::Tensor; +using torch::IntArrayRef; +using torch::TensorList; -// need to wrap torch::IntArrayRef() since -// it doesn't own underlying data -class IntArrayRef { - std::vector<int64_t> vec; - public: - IntArrayRef(Object o) { - Array a = Array(o); - for (size_t i = 0; i < a.size(); i++) { - vec.push_back(from_ruby<int64_t>(a[i])); - } - } - operator torch::IntArrayRef() { - return torch::IntArrayRef(vec); - } -}; - template<> inline -IntArrayRef from_ruby<IntArrayRef>(Object x) +std::vector<int64_t> from_ruby<std::vector<int64_t>>(Object x) { - return IntArrayRef(x); + Array a = Array(x); + std::vector<int64_t> vec(a.size()); + for (size_t i = 0; i < a.size(); i++) { + vec[i] = from_ruby<int64_t>(a[i]); + } + return vec; } -class TensorList { - std::vector<torch::Tensor> vec; - public: - TensorList(Object o) { - Array a = Array(o); - for (size_t i = 0; i < a.size(); i++) { - vec.push_back(from_ruby<torch::Tensor>(a[i])); - } - } - operator torch::TensorList() { - return torch::TensorList(vec); - } -}; - template<> inline -TensorList from_ruby<TensorList>(Object x) +std::vector<Tensor> from_ruby<std::vector<Tensor>>(Object x) { - return TensorList(x); + Array a = Array(x); + std::vector<Tensor> vec(a.size()); + for (size_t i = 0; i < a.size(); i++) { + vec[i] = from_ruby<Tensor>(a[i]); + } + return vec; } class FanModeType { std::string s; public: @@ -240,9 +222,16 @@ } else { return torch::optional<Scalar>{from_ruby<Scalar>(x)}; } } +Object wrap(bool x); +Object wrap(int64_t x); +Object wrap(double x); +Object wrap(torch::Tensor x); +Object wrap(torch::Scalar x); +Object wrap(torch::ScalarType x); +Object wrap(torch::QScheme x); Object wrap(std::tuple<torch::Tensor, torch::Tensor> x); Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x); Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x); Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x); Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x);