ext/torch/templates.h in torch-rb-0.6.0 vs ext/torch/templates.h in torch-rb-0.7.0
- old
+ new
@@ -2,12 +2,11 @@
#ifdef isfinite
#undef isfinite
#endif
-#include <rice/Array.hpp>
-#include <rice/Object.hpp>
+#include <rice/rice.hpp>
using namespace Rice;
using torch::Device;
using torch::Scalar;
@@ -21,10 +20,13 @@
using torch::IntArrayRef;
using torch::ArrayRef;
using torch::TensorList;
using torch::Storage;
+using torch::nn::init::FanModeType;
+using torch::nn::init::NonlinearityType;
+
#define HANDLE_TH_ERRORS \
try {
#define END_HANDLE_TH_ERRORS \
} catch (const torch::Error& ex) { \
@@ -36,65 +38,72 @@
}
#define RETURN_NIL \
return Qnil;
-template<>
-inline
-std::vector<int64_t> from_ruby<std::vector<int64_t>>(Object x)
-{
- Array a = Array(x);
- std::vector<int64_t> vec(a.size());
- for (long i = 0; i < a.size(); i++) {
- vec[i] = from_ruby<int64_t>(a[i]);
- }
- return vec;
-}
+class OptionalTensor {
+ torch::Tensor value;
+ public:
+ OptionalTensor(Object o) {
+ if (o.is_nil()) {
+ value = {};
+ } else {
+ value = Rice::detail::From_Ruby<torch::Tensor>().convert(o.value());
+ }
+ }
+ OptionalTensor(torch::Tensor o) {
+ value = o;
+ }
+ operator torch::Tensor() const {
+ return value;
+ }
+};
-template<>
-inline
-std::vector<Tensor> from_ruby<std::vector<Tensor>>(Object x)
+namespace Rice::detail
{
- Array a = Array(x);
- std::vector<Tensor> vec(a.size());
- for (long i = 0; i < a.size(); i++) {
- vec[i] = from_ruby<Tensor>(a[i]);
- }
- return vec;
-}
+ template<>
+ struct Type<FanModeType>
+ {
+ static bool verify()
+ {
+ return true;
+ }
+ };
-class FanModeType {
- std::string s;
+ template<>
+ class From_Ruby<FanModeType>
+ {
public:
- FanModeType(Object o) {
- s = String(o).str();
- }
- operator torch::nn::init::FanModeType() {
+ FanModeType convert(VALUE x)
+ {
+ auto s = String(x).str();
if (s == "fan_in") {
return torch::kFanIn;
} else if (s == "fan_out") {
return torch::kFanOut;
} else {
throw std::runtime_error("Unsupported nonlinearity type: " + s);
}
}
-};
+ };
-template<>
-inline
-FanModeType from_ruby<FanModeType>(Object x)
-{
- return FanModeType(x);
-}
+ template<>
+ struct Type<NonlinearityType>
+ {
+ static bool verify()
+ {
+ return true;
+ }
+ };
-class NonlinearityType {
- std::string s;
+ template<>
+ class From_Ruby<NonlinearityType>
+ {
public:
- NonlinearityType(Object o) {
- s = String(o).str();
- }
- operator torch::nn::init::NonlinearityType() {
+ NonlinearityType convert(VALUE x)
+ {
+ auto s = String(x).str();
if (s == "linear") {
return torch::kLinear;
} else if (s == "conv1d") {
return torch::kConv1D;
} else if (s == "conv2d") {
@@ -117,104 +126,72 @@
return torch::kLeakyReLU;
} else {
throw std::runtime_error("Unsupported nonlinearity type: " + s);
}
}
-};
+ };
-template<>
-inline
-NonlinearityType from_ruby<NonlinearityType>(Object x)
-{
- return NonlinearityType(x);
-}
+ template<>
+ struct Type<OptionalTensor>
+ {
+ static bool verify()
+ {
+ return true;
+ }
+ };
-class OptionalTensor {
- torch::Tensor value;
+ template<>
+ class From_Ruby<OptionalTensor>
+ {
public:
- OptionalTensor(Object o) {
- if (o.is_nil()) {
- value = {};
+ OptionalTensor convert(VALUE x)
+ {
+ return OptionalTensor(x);
+ }
+ };
+
+ template<>
+ struct Type<Scalar>
+ {
+ static bool verify()
+ {
+ return true;
+ }
+ };
+
+ template<>
+ class From_Ruby<Scalar>
+ {
+ public:
+ Scalar convert(VALUE x)
+ {
+ if (FIXNUM_P(x)) {
+ return torch::Scalar(From_Ruby<int64_t>().convert(x));
} else {
- value = from_ruby<torch::Tensor>(o);
+ return torch::Scalar(From_Ruby<double>().convert(x));
}
}
- OptionalTensor(torch::Tensor o) {
- value = o;
+ };
+
+ template<typename T>
+ struct Type<torch::optional<T>>
+ {
+ static bool verify()
+ {
+ return true;
}
- operator torch::Tensor() const {
- return value;
- }
-};
+ };
-template<>
-inline
-Scalar from_ruby<Scalar>(Object x)
-{
- if (x.rb_type() == T_FIXNUM) {
- return torch::Scalar(from_ruby<int64_t>(x));
- } else {
- return torch::Scalar(from_ruby<double>(x));
- }
-}
-
-template<>
-inline
-OptionalTensor from_ruby<OptionalTensor>(Object x)
-{
- return OptionalTensor(x);
-}
-
-template<>
-inline
-torch::optional<torch::ScalarType> from_ruby<torch::optional<torch::ScalarType>>(Object x)
-{
- if (x.is_nil()) {
- return torch::nullopt;
- } else {
- return torch::optional<torch::ScalarType>{from_ruby<torch::ScalarType>(x)};
- }
-}
-
-template<>
-inline
-torch::optional<int64_t> from_ruby<torch::optional<int64_t>>(Object x)
-{
- if (x.is_nil()) {
- return torch::nullopt;
- } else {
- return torch::optional<int64_t>{from_ruby<int64_t>(x)};
- }
-}
-
-template<>
-inline
-torch::optional<double> from_ruby<torch::optional<double>>(Object x)
-{
- if (x.is_nil()) {
- return torch::nullopt;
- } else {
- return torch::optional<double>{from_ruby<double>(x)};
- }
-}
-
-template<>
-inline
-torch::optional<bool> from_ruby<torch::optional<bool>>(Object x)
-{
- if (x.is_nil()) {
- return torch::nullopt;
- } else {
- return torch::optional<bool>{from_ruby<bool>(x)};
- }
-}
-
-template<>
-inline
-torch::optional<Scalar> from_ruby<torch::optional<Scalar>>(Object x)
-{
- if (x.is_nil()) {
- return torch::nullopt;
- } else {
- return torch::optional<Scalar>{from_ruby<Scalar>(x)};
- }
+ template<typename T>
+ class From_Ruby<torch::optional<T>>
+ {
+ public:
+ torch::optional<T> convert(VALUE x)
+ {
+ if (NIL_P(x)) {
+ return torch::nullopt;
+ } else {
+ return torch::optional<T>{From_Ruby<T>().convert(x)};
+ }
+ }
+ };
}