ext/torch/templates.h in torch-rb-0.6.0 vs ext/torch/templates.h in torch-rb-0.7.0

- old
+ new

@@ -2,12 +2,11 @@
 
 #ifdef isfinite
 #undef isfinite
 #endif
 
-#include <rice/Array.hpp>
-#include <rice/Object.hpp>
+#include <rice/rice.hpp>
 
 using namespace Rice;
 
 using torch::Device;
 using torch::Scalar;
@@ -21,10 +20,13 @@
 using torch::IntArrayRef;
 using torch::ArrayRef;
 using torch::TensorList;
 using torch::Storage;
+using torch::nn::init::FanModeType;
+using torch::nn::init::NonlinearityType;
+
 #define HANDLE_TH_ERRORS \
   try {
 
 #define END_HANDLE_TH_ERRORS \
   } catch (const torch::Error& ex) { \
@@ -36,65 +38,72 @@
 }
 
 #define RETURN_NIL \
   return Qnil;
 
-template<>
-inline
-std::vector<int64_t> from_ruby<std::vector<int64_t>>(Object x)
-{
-  Array a = Array(x);
-  std::vector<int64_t> vec(a.size());
-  for (long i = 0; i < a.size(); i++) {
-    vec[i] = from_ruby<int64_t>(a[i]);
-  }
-  return vec;
-}
+class OptionalTensor {
+  torch::Tensor value;
+  public:
+    OptionalTensor(Object o) {
+      if (o.is_nil()) {
+        value = {};
+      } else {
+        value = Rice::detail::From_Ruby<torch::Tensor>().convert(o.value());
+      }
+    }
+    OptionalTensor(torch::Tensor o) {
+      value = o;
+    }
+    operator torch::Tensor() const {
+      return value;
+    }
+};
 
-template<>
-inline
-std::vector<Tensor> from_ruby<std::vector<Tensor>>(Object x)
+namespace Rice::detail
 {
-  Array a = Array(x);
-  std::vector<Tensor> vec(a.size());
-  for (long i = 0; i < a.size(); i++) {
-    vec[i] = from_ruby<Tensor>(a[i]);
-  }
-  return vec;
-}
+  template<>
+  struct Type<FanModeType>
+  {
+    static bool verify()
+    {
+      return true;
+    }
+  };
 
-class FanModeType {
-  std::string s;
+  template<>
+  class From_Ruby<FanModeType>
+  {
   public:
-    FanModeType(Object o) {
-      s = String(o).str();
-    }
-    operator torch::nn::init::FanModeType() {
+    FanModeType convert(VALUE x)
+    {
+      auto s = String(x).str();
       if (s == "fan_in") {
        return torch::kFanIn;
      } else if (s == "fan_out") {
        return torch::kFanOut;
      } else {
        throw std::runtime_error("Unsupported nonlinearity type: " + s);
      }
    }
-};
+  };
 
-template<>
-inline
-FanModeType from_ruby<FanModeType>(Object x)
-{
-  return FanModeType(x);
-}
+  template<>
+  struct Type<NonlinearityType>
+  {
+    static bool verify()
+    {
+      return true;
+    }
+  };
 
-class NonlinearityType {
-  std::string s;
+  template<>
+  class From_Ruby<NonlinearityType>
+  {
   public:
-    NonlinearityType(Object o) {
-      s = String(o).str();
-    }
-    operator torch::nn::init::NonlinearityType() {
+    NonlinearityType convert(VALUE x)
+    {
+      auto s = String(x).str();
       if (s == "linear") {
        return torch::kLinear;
      } else if (s == "conv1d") {
        return torch::kConv1D;
      } else if (s == "conv2d") {
@@ -117,104 +126,72 @@
        return torch::kLeakyReLU;
      } else {
        throw std::runtime_error("Unsupported nonlinearity type: " + s);
      }
    }
-};
+  };
 
-template<>
-inline
-NonlinearityType from_ruby<NonlinearityType>(Object x)
-{
-  return NonlinearityType(x);
-}
+  template<>
+  struct Type<OptionalTensor>
+  {
+    static bool verify()
+    {
+      return true;
+    }
+  };
 
-class OptionalTensor {
-  torch::Tensor value;
+  template<>
+  class From_Ruby<OptionalTensor>
+  {
   public:
-    OptionalTensor(Object o) {
-      if (o.is_nil()) {
-        value = {};
+    OptionalTensor convert(VALUE x)
+    {
+      return OptionalTensor(x);
+    }
+  };
+
+  template<>
+  struct Type<Scalar>
+  {
+    static bool verify()
+    {
+      return true;
+    }
+  };
+
+  template<>
+  class From_Ruby<Scalar>
+  {
+  public:
+    Scalar convert(VALUE x)
+    {
+      if (FIXNUM_P(x)) {
+        return torch::Scalar(From_Ruby<int64_t>().convert(x));
      } else {
-        value = from_ruby<torch::Tensor>(o);
+        return torch::Scalar(From_Ruby<double>().convert(x));
      }
    }
-    OptionalTensor(torch::Tensor o) {
-      value = o;
+  };
+
+  template<typename T>
+  struct Type<torch::optional<T>>
+  {
+    static bool verify()
+    {
+      return true;
    }
-    operator torch::Tensor() const {
-      return value;
-    }
-};
+  };
 
-template<>
-inline
-Scalar from_ruby<Scalar>(Object x)
-{
-  if (x.rb_type() == T_FIXNUM) {
-    return torch::Scalar(from_ruby<int64_t>(x));
-  } else {
-    return torch::Scalar(from_ruby<double>(x));
-  }
-}
-
-template<>
-inline
-OptionalTensor from_ruby<OptionalTensor>(Object x)
-{
-  return OptionalTensor(x);
-}
-
-template<>
-inline
-torch::optional<torch::ScalarType> from_ruby<torch::optional<torch::ScalarType>>(Object x)
-{
-  if (x.is_nil()) {
-    return torch::nullopt;
-  } else {
-    return torch::optional<torch::ScalarType>{from_ruby<torch::ScalarType>(x)};
-  }
-}
-
-template<>
-inline
-torch::optional<int64_t> from_ruby<torch::optional<int64_t>>(Object x)
-{
-  if (x.is_nil()) {
-    return torch::nullopt;
-  } else {
-    return torch::optional<int64_t>{from_ruby<int64_t>(x)};
-  }
-}
-
-template<>
-inline
-torch::optional<double> from_ruby<torch::optional<double>>(Object x)
-{
-  if (x.is_nil()) {
-    return torch::nullopt;
-  } else {
-    return torch::optional<double>{from_ruby<double>(x)};
-  }
-}
-
-template<>
-inline
-torch::optional<bool> from_ruby<torch::optional<bool>>(Object x)
-{
-  if (x.is_nil()) {
-    return torch::nullopt;
-  } else {
-    return torch::optional<bool>{from_ruby<bool>(x)};
-  }
-}
-
-template<>
-inline
-torch::optional<Scalar> from_ruby<torch::optional<Scalar>>(Object x)
-{
-  if (x.is_nil()) {
-    return torch::nullopt;
-  } else {
-    return torch::optional<Scalar>{from_ruby<Scalar>(x)};
-  }
+  template<typename T>
+  class From_Ruby<torch::optional<T>>
+  {
+  public:
+    torch::optional<T> convert(VALUE x)
+    {
+      if (NIL_P(x)) {
+        return torch::nullopt;
+      } else {
+        return torch::optional<T>{From_Ruby<T>().convert(x)};
+      }
+    }
+  };
 }
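
The structural change in this diff is the move from the old `from_ruby<T>(Object)` free-function specializations to what appears to be the Rice 4 conversion API (single `rice/rice.hpp` header): each convertible type supplies a `Rice::detail::Type<T>` trait with a static `verify()` plus a `Rice::detail::From_Ruby<T>` functor with a `convert(VALUE)` method. A minimal self-contained sketch of that pattern, using a hypothetical `Celsius` type that is not part of torch-rb:

#include <rice/rice.hpp>

// Hypothetical example type; anything Rice should build from a Ruby value.
struct Celsius {
  double degrees;
};

namespace Rice::detail
{
  // Type<T>::verify() lets Rice confirm the type is usable in a binding;
  // returning true unconditionally mirrors the specializations above.
  template<>
  struct Type<Celsius>
  {
    static bool verify()
    {
      return true;
    }
  };

  // From_Ruby<T>::convert(VALUE) replaces the old from_ruby<T>(Object)
  // template. Here it delegates to the built-in double converter, the
  // same way From_Ruby<Scalar> above delegates to int64_t/double.
  template<>
  class From_Ruby<Celsius>
  {
  public:
    Celsius convert(VALUE x)
    {
      return Celsius{From_Ruby<double>().convert(x)};
    }
  };
}

// Any function bound with Rice now accepts a Ruby Float where Celsius
// is expected, e.g.:
//   Rice::define_global_function("boiling?", &boiling);
//   boiling?(100.0)  # => true
bool boiling(Celsius c) { return c.degrees >= 100.0; }

Two side effects of the migration are visible in the diff itself: the hand-rolled std::vector<int64_t> and std::vector<Tensor> conversions disappear (presumably because Rice 4 provides its own container conversions), and the five near-identical torch::optional<...> specializations collapse into the single templated From_Ruby<torch::optional<T>> at the end of the file.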