#pragma once #include #include using namespace Rice; // need to wrap torch::IntArrayRef() since // it doesn't own underlying data class IntArrayRef { std::vector vec; public: IntArrayRef(Object o) { Array a = Array(o); for (size_t i = 0; i < a.size(); i++) { vec.push_back(from_ruby(a[i])); } } operator torch::IntArrayRef() { return torch::IntArrayRef(vec); } }; template<> inline IntArrayRef from_ruby(Object x) { return IntArrayRef(x); } // for now class Scalar { torch::Scalar value; public: Scalar(Object o) { // TODO cast based on Ruby type if (o.rb_type() == T_FIXNUM) { value = torch::Scalar(from_ruby(o)); } else { value = torch::Scalar(from_ruby(o)); } } operator torch::Scalar() { return value; } }; template<> inline Scalar from_ruby(Object x) { return Scalar(x); } class TensorList { std::vector vec; public: TensorList(Object o) { Array a = Array(o); for (size_t i = 0; i < a.size(); i++) { vec.push_back(from_ruby(a[i])); } } operator torch::TensorList() { return torch::TensorList(vec); } }; template<> inline TensorList from_ruby(Object x) { return TensorList(x); } class FanModeType { std::string s; public: FanModeType(Object o) { s = String(o).str(); } // TODO switch NonlinearityType after LibTorch 1.4 release operator torch::nn::init::FanMode() { if (s == "fan_in") { return torch::nn::init::FanMode::FanIn; } else if (s == "fan_out") { return torch::nn::init::FanMode::FanOut; } else { throw std::runtime_error("Unsupported nonlinearity type: " + s); } } }; template<> inline FanModeType from_ruby(Object x) { return FanModeType(x); } class NonlinearityType { std::string s; public: NonlinearityType(Object o) { s = String(o).str(); } // TODO switch NonlinearityType after LibTorch 1.4 release operator torch::nn::init::Nonlinearity() { if (s == "linear") { return torch::nn::init::Nonlinearity::Linear; } else if (s == "conv1d") { return torch::nn::init::Nonlinearity::Conv1D; } else if (s == "conv2d") { return torch::nn::init::Nonlinearity::Conv2D; } else if (s == "conv3d") { return torch::nn::init::Nonlinearity::Conv3D; } else if (s == "conv_transpose1d") { return torch::nn::init::Nonlinearity::ConvTranspose1D; } else if (s == "conv_transpose2d") { return torch::nn::init::Nonlinearity::ConvTranspose2D; } else if (s == "conv_transpose3d") { return torch::nn::init::Nonlinearity::ConvTranspose3D; } else if (s == "sigmoid") { return torch::nn::init::Nonlinearity::Sigmoid; } else if (s == "tanh") { return torch::nn::init::Nonlinearity::Tanh; } else if (s == "relu") { return torch::nn::init::Nonlinearity::ReLU; } else if (s == "leaky_relu") { return torch::nn::init::Nonlinearity::LeakyReLU; } else { throw std::runtime_error("Unsupported nonlinearity type: " + s); } } }; template<> inline NonlinearityType from_ruby(Object x) { return NonlinearityType(x); } class MyReduction { Object value; public: MyReduction(Object o) { value = o; } operator int64_t() { if (value.is_nil()) { return Reduction::None; } std::string s = String(value).str(); if (s == "mean") { return Reduction::Mean; } else if (s == "sum") { return Reduction::Sum; } else { throw std::runtime_error("Unsupported reduction: " + s); } } }; template<> inline MyReduction from_ruby(Object x) { return MyReduction(x); } typedef torch::Tensor Tensor; class OptionalTensor { Object value; public: OptionalTensor(Object o) { value = o; } operator torch::Tensor() { if (value.is_nil()) { return {}; } return from_ruby(value); } }; template<> inline OptionalTensor from_ruby(Object x) { return OptionalTensor(x); } class ScalarType { Object value; public: ScalarType(Object o) { value = o; } operator at::ScalarType() { throw std::runtime_error("ScalarType arguments not implemented yet"); } }; template<> inline ScalarType from_ruby(Object x) { return ScalarType(x); } class OptionalScalarType { Object value; public: OptionalScalarType(Object o) { value = o; } operator c10::optional() { if (value.is_nil()) { return c10::nullopt; } return ScalarType(value); } }; template<> inline OptionalScalarType from_ruby(Object x) { return OptionalScalarType(x); } typedef torch::Device Device; Object wrap(std::tuple x); Object wrap(std::tuple x); Object wrap(std::tuple x); Object wrap(std::tuple x); Object wrap(std::tuple x); Object wrap(std::tuple x);