ext/torch/templates.hpp in torch-rb-0.1.7 vs ext/torch/templates.hpp in torch-rb-0.1.8
- old
+ new
@@ -1,7 +1,11 @@
#pragma once
+#ifdef isfinite
+#undef isfinite
+#endif
+
#include <rice/Array.hpp>
#include <rice/Object.hpp>
using namespace Rice;
@@ -77,16 +81,15 @@
std::string s;
public:
FanModeType(Object o) {
s = String(o).str();
}
- // TODO switch NonlinearityType after LibTorch 1.4 release
- operator torch::nn::init::FanMode() {
+ operator torch::nn::init::FanModeType() {
if (s == "fan_in") {
- return torch::nn::init::FanMode::FanIn;
+ return torch::kFanIn;
} else if (s == "fan_out") {
- return torch::nn::init::FanMode::FanOut;
+ return torch::kFanOut;
} else {
throw std::runtime_error("Unsupported nonlinearity type: " + s);
}
}
};
@@ -102,34 +105,33 @@
std::string s;
public:
NonlinearityType(Object o) {
s = String(o).str();
}
- // TODO switch NonlinearityType after LibTorch 1.4 release
- operator torch::nn::init::Nonlinearity() {
+ operator torch::nn::init::NonlinearityType() {
if (s == "linear") {
- return torch::nn::init::Nonlinearity::Linear;
+ return torch::kLinear;
} else if (s == "conv1d") {
- return torch::nn::init::Nonlinearity::Conv1D;
+ return torch::kConv1D;
} else if (s == "conv2d") {
- return torch::nn::init::Nonlinearity::Conv2D;
+ return torch::kConv2D;
} else if (s == "conv3d") {
- return torch::nn::init::Nonlinearity::Conv3D;
+ return torch::kConv3D;
} else if (s == "conv_transpose1d") {
- return torch::nn::init::Nonlinearity::ConvTranspose1D;
+ return torch::kConvTranspose1D;
} else if (s == "conv_transpose2d") {
- return torch::nn::init::Nonlinearity::ConvTranspose2D;
+ return torch::kConvTranspose2D;
} else if (s == "conv_transpose3d") {
- return torch::nn::init::Nonlinearity::ConvTranspose3D;
+ return torch::kConvTranspose3D;
} else if (s == "sigmoid") {
- return torch::nn::init::Nonlinearity::Sigmoid;
+ return torch::kSigmoid;
} else if (s == "tanh") {
- return torch::nn::init::Nonlinearity::Tanh;
+ return torch::kTanh;
} else if (s == "relu") {
- return torch::nn::init::Nonlinearity::ReLU;
+ return torch::kReLU;
} else if (s == "leaky_relu") {
- return torch::nn::init::Nonlinearity::LeakyReLU;
+ return torch::kLeakyReLU;
} else {
throw std::runtime_error("Unsupported nonlinearity type: " + s);
}
}
};
@@ -147,17 +149,17 @@
MyReduction(Object o) {
value = o;
}
operator int64_t() {
if (value.is_nil()) {
- return Reduction::None;
+ return torch::Reduction::None;
}
std::string s = String(value).str();
if (s == "mean") {
- return Reduction::Mean;
+ return torch::Reduction::Mean;
} else if (s == "sum") {
- return Reduction::Sum;
+ return torch::Reduction::Sum;
} else {
throw std::runtime_error("Unsupported reduction: " + s);
}
}
};