ext/torch/templates.hpp in torch-rb-0.3.5 vs ext/torch/templates.hpp in torch-rb-0.3.6

- old
+ new

@@ -8,10 +8,11 @@ #include <rice/Object.hpp> using namespace Rice; using torch::Device; +using torch::Scalar; using torch::ScalarType; using torch::Tensor; // need to wrap torch::IntArrayRef() since // it doesn't own underlying data @@ -34,34 +35,10 @@ IntArrayRef from_ruby<IntArrayRef>(Object x) { return IntArrayRef(x); } -// for now -class Scalar { - torch::Scalar value; - public: - Scalar(Object o) { - // TODO cast based on Ruby type - if (o.rb_type() == T_FIXNUM) { - value = torch::Scalar(from_ruby<int64_t>(o)); - } else { - value = torch::Scalar(from_ruby<float>(o)); - } - } - operator torch::Scalar() { - return value; - } -}; - -template<> -inline -Scalar from_ruby<Scalar>(Object x) -{ - return Scalar(x); -} - class TensorList { std::vector<torch::Tensor> vec; public: TensorList(Object o) { Array a = Array(o); @@ -192,10 +169,21 @@ } }; template<> inline +Scalar from_ruby<Scalar>(Object x) +{ + if (x.rb_type() == T_FIXNUM) { + return torch::Scalar(from_ruby<int64_t>(x)); + } else { + return torch::Scalar(from_ruby<double>(x)); + } +} + +template<> +inline OptionalTensor from_ruby<OptionalTensor>(Object x) { return OptionalTensor(x); } @@ -219,11 +207,45 @@ } else { return torch::optional<int64_t>{from_ruby<int64_t>(x)}; } } +template<> +inline +torch::optional<double> from_ruby<torch::optional<double>>(Object x) +{ + if (x.is_nil()) { + return torch::nullopt; + } else { + return torch::optional<double>{from_ruby<double>(x)}; + } +} + +template<> +inline +torch::optional<bool> from_ruby<torch::optional<bool>>(Object x) +{ + if (x.is_nil()) { + return torch::nullopt; + } else { + return torch::optional<bool>{from_ruby<bool>(x)}; + } +} + +template<> +inline +torch::optional<Scalar> from_ruby<torch::optional<Scalar>>(Object x) +{ + if (x.is_nil()) { + return torch::nullopt; + } else { + return torch::optional<Scalar>{from_ruby<Scalar>(x)}; + } +} + Object wrap(std::tuple<torch::Tensor, torch::Tensor> x); Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x); Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x); Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x); Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x); Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x); +Object wrap(std::vector<torch::Tensor> x);