ext/torch/templates.hpp in torch-rb-0.3.6 vs ext/torch/templates.hpp in torch-rb-0.3.7
- old
+ new
@@ -11,53 +11,35 @@
using torch::Device;
using torch::Scalar;
using torch::ScalarType;
using torch::Tensor;
+using torch::IntArrayRef;
+using torch::TensorList;
-// need to wrap torch::IntArrayRef() since
-// it doesn't own underlying data
-class IntArrayRef {
- std::vector<int64_t> vec;
- public:
- IntArrayRef(Object o) {
- Array a = Array(o);
- for (size_t i = 0; i < a.size(); i++) {
- vec.push_back(from_ruby<int64_t>(a[i]));
- }
- }
- operator torch::IntArrayRef() {
- return torch::IntArrayRef(vec);
- }
-};
-
template<>
inline
-IntArrayRef from_ruby<IntArrayRef>(Object x)
+std::vector<int64_t> from_ruby<std::vector<int64_t>>(Object x)
{
- return IntArrayRef(x);
+ Array a = Array(x);
+ std::vector<int64_t> vec(a.size());
+ for (size_t i = 0; i < a.size(); i++) {
+ vec[i] = from_ruby<int64_t>(a[i]);
+ }
+ return vec;
}
-class TensorList {
- std::vector<torch::Tensor> vec;
- public:
- TensorList(Object o) {
- Array a = Array(o);
- for (size_t i = 0; i < a.size(); i++) {
- vec.push_back(from_ruby<torch::Tensor>(a[i]));
- }
- }
- operator torch::TensorList() {
- return torch::TensorList(vec);
- }
-};
-
template<>
inline
-TensorList from_ruby<TensorList>(Object x)
+std::vector<Tensor> from_ruby<std::vector<Tensor>>(Object x)
{
- return TensorList(x);
+ Array a = Array(x);
+ std::vector<Tensor> vec(a.size());
+ for (size_t i = 0; i < a.size(); i++) {
+ vec[i] = from_ruby<Tensor>(a[i]);
+ }
+ return vec;
}
class FanModeType {
std::string s;
public:
@@ -240,9 +222,16 @@
} else {
return torch::optional<Scalar>{from_ruby<Scalar>(x)};
}
}
+Object wrap(bool x);
+Object wrap(int64_t x);
+Object wrap(double x);
+Object wrap(torch::Tensor x);
+Object wrap(torch::Scalar x);
+Object wrap(torch::ScalarType x);
+Object wrap(torch::QScheme x);
Object wrap(std::tuple<torch::Tensor, torch::Tensor> x);
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x);
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x);