ext/torch/templates.hpp in torch-rb-0.1.7 vs ext/torch/templates.hpp in torch-rb-0.1.8

- old
+ new

@@ -1,7 +1,11 @@ #pragma once +#ifdef isfinite +#undef isfinite +#endif + #include <rice/Array.hpp> #include <rice/Object.hpp> using namespace Rice; @@ -77,16 +81,15 @@ std::string s; public: FanModeType(Object o) { s = String(o).str(); } - // TODO switch NonlinearityType after LibTorch 1.4 release - operator torch::nn::init::FanMode() { + operator torch::nn::init::FanModeType() { if (s == "fan_in") { - return torch::nn::init::FanMode::FanIn; + return torch::kFanIn; } else if (s == "fan_out") { - return torch::nn::init::FanMode::FanOut; + return torch::kFanOut; } else { throw std::runtime_error("Unsupported nonlinearity type: " + s); } } }; @@ -102,34 +105,33 @@ std::string s; public: NonlinearityType(Object o) { s = String(o).str(); } - // TODO switch NonlinearityType after LibTorch 1.4 release - operator torch::nn::init::Nonlinearity() { + operator torch::nn::init::NonlinearityType() { if (s == "linear") { - return torch::nn::init::Nonlinearity::Linear; + return torch::kLinear; } else if (s == "conv1d") { - return torch::nn::init::Nonlinearity::Conv1D; + return torch::kConv1D; } else if (s == "conv2d") { - return torch::nn::init::Nonlinearity::Conv2D; + return torch::kConv2D; } else if (s == "conv3d") { - return torch::nn::init::Nonlinearity::Conv3D; + return torch::kConv3D; } else if (s == "conv_transpose1d") { - return torch::nn::init::Nonlinearity::ConvTranspose1D; + return torch::kConvTranspose1D; } else if (s == "conv_transpose2d") { - return torch::nn::init::Nonlinearity::ConvTranspose2D; + return torch::kConvTranspose2D; } else if (s == "conv_transpose3d") { - return torch::nn::init::Nonlinearity::ConvTranspose3D; + return torch::kConvTranspose3D; } else if (s == "sigmoid") { - return torch::nn::init::Nonlinearity::Sigmoid; + return torch::kSigmoid; } else if (s == "tanh") { - return torch::nn::init::Nonlinearity::Tanh; + return torch::kTanh; } else if (s == "relu") { - return torch::nn::init::Nonlinearity::ReLU; + return torch::kReLU; } else if (s == "leaky_relu") { - return torch::nn::init::Nonlinearity::LeakyReLU; + return torch::kLeakyReLU; } else { throw std::runtime_error("Unsupported nonlinearity type: " + s); } } }; @@ -147,17 +149,17 @@ MyReduction(Object o) { value = o; } operator int64_t() { if (value.is_nil()) { - return Reduction::None; + return torch::Reduction::None; } std::string s = String(value).str(); if (s == "mean") { - return Reduction::Mean; + return torch::Reduction::Mean; } else if (s == "sum") { - return Reduction::Sum; + return torch::Reduction::Sum; } else { throw std::runtime_error("Unsupported reduction: " + s); } } };