ext/torch/templates.hpp in torch-rb-0.1.5 vs ext/torch/templates.hpp in torch-rb-0.1.6

- old
+ new

@@ -246,5 +246,53 @@ inline OptionalTensor from_ruby<OptionalTensor>(Object x) { return OptionalTensor(x); } + +class ScalarType { + Object value; + public: + ScalarType(Object o) { + value = o; + } + operator at::ScalarType() { + throw std::runtime_error("ScalarType arguments not implemented yet"); + } +}; + +template<> +inline +ScalarType from_ruby<ScalarType>(Object x) +{ + return ScalarType(x); +} + +class OptionalScalarType { + Object value; + public: + OptionalScalarType(Object o) { + value = o; + } + operator c10::optional<at::ScalarType>() { + if (value.is_nil()) { + return c10::nullopt; + } + return ScalarType(value); + } +}; + +template<> +inline +OptionalScalarType from_ruby<OptionalScalarType>(Object x) +{ + return OptionalScalarType(x); +} + +typedef torch::Device Device; + +Object wrap(std::tuple<torch::Tensor, torch::Tensor> x); +Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x); +Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x); +Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x); +Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x); +Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x);