ext/torch/templates.hpp in torch-rb-0.3.2 vs ext/torch/templates.hpp in torch-rb-0.3.3

- old
+ new

@@ -7,10 +7,14 @@ #include <rice/Array.hpp> #include <rice/Object.hpp> using namespace Rice; +using torch::Device; +using torch::ScalarType; +using torch::Tensor; + // need to wrap torch::IntArrayRef() since // it doesn't own underlying data class IntArrayRef { std::vector<int64_t> vec; public: @@ -172,12 +176,10 @@ MyReduction from_ruby<MyReduction>(Object x) { return MyReduction(x); } -typedef torch::Tensor Tensor; - class OptionalTensor { Object value; public: OptionalTensor(Object o) { value = o; @@ -195,49 +197,30 @@ OptionalTensor from_ruby<OptionalTensor>(Object x) { return OptionalTensor(x); } -class ScalarType { - Object value; - public: - ScalarType(Object o) { - value = o; - } - operator at::ScalarType() { - throw std::runtime_error("ScalarType arguments not implemented yet"); - } -}; - template<> inline -ScalarType from_ruby<ScalarType>(Object x) +torch::optional<torch::ScalarType> from_ruby<torch::optional<torch::ScalarType>>(Object x) { - return ScalarType(x); + if (x.is_nil()) { + return torch::nullopt; + } else { + return torch::optional<torch::ScalarType>{from_ruby<torch::ScalarType>(x)}; + } } -class OptionalScalarType { - Object value; - public: - OptionalScalarType(Object o) { - value = o; - } - operator c10::optional<at::ScalarType>() { - if (value.is_nil()) { - return c10::nullopt; - } - return ScalarType(value); - } -}; - template<> inline -OptionalScalarType from_ruby<OptionalScalarType>(Object x) +torch::optional<int64_t> from_ruby<torch::optional<int64_t>>(Object x) { - return OptionalScalarType(x); + if (x.is_nil()) { + return torch::nullopt; + } else { + return torch::optional<int64_t>{from_ruby<int64_t>(x)}; + } } - -typedef torch::Device Device; Object wrap(std::tuple<torch::Tensor, torch::Tensor> x); Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x); Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x); Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);