ext/torch/templates.hpp in torch-rb-0.1.5 vs ext/torch/templates.hpp in torch-rb-0.1.6
- old
+ new
@@ -246,5 +246,53 @@
inline
OptionalTensor from_ruby<OptionalTensor>(Object x)
{
return OptionalTensor(x);
}
+
+class ScalarType {
+ Object value;
+ public:
+ ScalarType(Object o) {
+ value = o;
+ }
+ operator at::ScalarType() {
+ throw std::runtime_error("ScalarType arguments not implemented yet");
+ }
+};
+
+template<>
+inline
+ScalarType from_ruby<ScalarType>(Object x)
+{
+ return ScalarType(x);
+}
+
+class OptionalScalarType {
+ Object value;
+ public:
+ OptionalScalarType(Object o) {
+ value = o;
+ }
+ operator c10::optional<at::ScalarType>() {
+ if (value.is_nil()) {
+ return c10::nullopt;
+ }
+ return ScalarType(value);
+ }
+};
+
+template<>
+inline
+OptionalScalarType from_ruby<OptionalScalarType>(Object x)
+{
+ return OptionalScalarType(x);
+}
+
+typedef torch::Device Device;
+
+Object wrap(std::tuple<torch::Tensor, torch::Tensor> x);
+Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x);
+Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
+Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
+Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x);
+Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x);