vendor/fastText/src/args.cc: fasttext-0.1.2 vs fasttext-0.1.3
- old (fasttext-0.1.2)
+ new (fasttext-0.1.3)
@@ -10,10 +10,12 @@
#include <stdlib.h>
#include <iostream>
#include <stdexcept>
+#include <string>
+#include <unordered_map>
namespace fasttext {
Args::Args() {
lr = 0.05;
@@ -34,16 +36,23 @@
t = 1e-4;
label = "__label__";
verbose = 2;
pretrainedVectors = "";
saveOutput = false;
+ seed = 0;
qout = false;
retrain = false;
qnorm = false;
cutoff = 0;
dsub = 2;
+
+ autotuneValidationFile = "";
+ autotuneMetric = "f1";
+ autotunePredictions = 1;
+ autotuneDuration = 60 * 5; // 5 minutes
+ autotuneModelSize = "";
}
std::string Args::lossToString(loss_name ln) const {
switch (ln) {
case loss_name::hs:
@@ -76,10 +85,28 @@
return "sup";
}
return "Unknown model name!"; // should never happen
}
+std::string Args::metricToString(metric_name mn) const {
+ switch (mn) {
+ case metric_name::f1score:
+ return "f1score";
+ case metric_name::f1scoreLabel:
+ return "f1scoreLabel";
+ case metric_name::precisionAtRecall:
+ return "precisionAtRecall";
+ case metric_name::precisionAtRecallLabel:
+ return "precisionAtRecallLabel";
+ case metric_name::recallAtPrecision:
+ return "recallAtPrecision";
+ case metric_name::recallAtPrecisionLabel:
+ return "recallAtPrecisionLabel";
+ }
+ return "Unknown metric name!"; // should never happen
+}
+
void Args::parseArgs(const std::vector<std::string>& args) {
std::string command(args[1]);
if (command == "supervised") {
model = model_name::sup;
loss = loss_name::softmax;
@@ -95,10 +122,12 @@
std::cerr << "Provided argument without a dash! Usage:" << std::endl;
printHelp();
exit(EXIT_FAILURE);
}
try {
+ setManual(args[ai].substr(1));
+
if (args[ai] == "-h") {
std::cerr << "Here is the help! Usage:" << std::endl;
printHelp();
exit(EXIT_FAILURE);
} else if (args[ai] == "-input") {
@@ -155,10 +184,12 @@
} else if (args[ai] == "-pretrainedVectors") {
pretrainedVectors = std::string(args.at(ai + 1));
} else if (args[ai] == "-saveOutput") {
saveOutput = true;
ai--;
+ } else if (args[ai] == "-seed") {
+ seed = std::stoi(args.at(ai + 1));
} else if (args[ai] == "-qnorm") {
qnorm = true;
ai--;
} else if (args[ai] == "-retrain") {
retrain = true;
@@ -168,10 +199,22 @@
ai--;
} else if (args[ai] == "-cutoff") {
cutoff = std::stoi(args.at(ai + 1));
} else if (args[ai] == "-dsub") {
dsub = std::stoi(args.at(ai + 1));
+ } else if (args[ai] == "-autotune-validation") {
+ autotuneValidationFile = std::string(args.at(ai + 1));
+ } else if (args[ai] == "-autotune-metric") {
+ autotuneMetric = std::string(args.at(ai + 1));
+ getAutotuneMetric(); // throws exception if not able to parse
+ getAutotuneMetricLabel(); // throws exception if not able to parse
+ } else if (args[ai] == "-autotune-predictions") {
+ autotunePredictions = std::stoi(args.at(ai + 1));
+ } else if (args[ai] == "-autotune-duration") {
+ autotuneDuration = std::stoi(args.at(ai + 1));
+ } else if (args[ai] == "-autotune-modelsize") {
+ autotuneModelSize = std::string(args.at(ai + 1));
} else {
std::cerr << "Unknown argument: " << args[ai] << std::endl;
printHelp();
exit(EXIT_FAILURE);
}
@@ -184,19 +227,20 @@
if (input.empty() || output.empty()) {
std::cerr << "Empty input or output path." << std::endl;
printHelp();
exit(EXIT_FAILURE);
}
- if (wordNgrams <= 1 && maxn == 0) {
+ if (wordNgrams <= 1 && maxn == 0 && !hasAutotune()) {
bucket = 0;
}
}
void Args::printHelp() {
printBasicHelp();
printDictionaryHelp();
printTrainingHelp();
+ printAutotuneHelp();
printQuantizationHelp();
}
void Args::printBasicHelp() {
std::cerr << "\nThe following arguments are mandatory:\n"
@@ -225,31 +269,53 @@
void Args::printTrainingHelp() {
std::cerr
<< "\nThe following arguments for training are optional:\n"
<< " -lr learning rate [" << lr << "]\n"
- << " -lrUpdateRate change the rate of updates for the learning rate ["
+ << " -lrUpdateRate change the rate of updates for the learning "
+ "rate ["
<< lrUpdateRate << "]\n"
<< " -dim size of word vectors [" << dim << "]\n"
<< " -ws size of the context window [" << ws << "]\n"
<< " -epoch number of epochs [" << epoch << "]\n"
<< " -neg number of negatives sampled [" << neg << "]\n"
<< " -loss loss function {ns, hs, softmax, one-vs-all} ["
<< lossToString(loss) << "]\n"
- << " -thread number of threads [" << thread << "]\n"
- << " -pretrainedVectors pretrained word vectors for supervised learning ["
+ << " -thread number of threads (set to 1 to ensure "
+ "reproducible results) ["
+ << thread << "]\n"
+ << " -pretrainedVectors pretrained word vectors for supervised "
+ "learning ["
<< pretrainedVectors << "]\n"
<< " -saveOutput whether output params should be saved ["
- << boolToString(saveOutput) << "]\n";
+ << boolToString(saveOutput) << "]\n"
+ << " -seed random generator seed [" << seed << "]\n";
}
+void Args::printAutotuneHelp() {
+ std::cerr << "\nThe following arguments are for autotune:\n"
+ << " -autotune-validation validation file to be used "
+ "for evaluation\n"
+ << " -autotune-metric metric objective {f1, "
+ "f1:labelname} ["
+ << autotuneMetric << "]\n"
+ << " -autotune-predictions number of predictions used "
+ "for evaluation ["
+ << autotunePredictions << "]\n"
+ << " -autotune-duration maximum duration in seconds ["
+ << autotuneDuration << "]\n"
+ << " -autotune-modelsize constraint model file size ["
+ << autotuneModelSize << "] (empty = do not quantize)\n";
+}
+
void Args::printQuantizationHelp() {
std::cerr
<< "\nThe following arguments for quantization are optional:\n"
<< " -cutoff number of words and ngrams to retain ["
<< cutoff << "]\n"
- << " -retrain whether embeddings are finetuned if a cutoff is applied ["
+ << " -retrain whether embeddings are finetuned if a cutoff "
+ "is applied ["
<< boolToString(retrain) << "]\n"
<< " -qnorm whether the norm is quantized separately ["
<< boolToString(qnorm) << "]\n"
<< " -qout whether the classifier is quantized ["
<< boolToString(qout) << "]\n"
@@ -313,8 +379,115 @@
<< " " << maxn << std::endl;
out << "lrUpdateRate"
<< " " << lrUpdateRate << std::endl;
out << "t"
<< " " << t << std::endl;
+}
+
+bool Args::hasAutotune() const {
+ return !autotuneValidationFile.empty();
+}
+
+bool Args::isManual(const std::string& argName) const {
+ return (manualArgs_.count(argName) != 0);
+}
+
+void Args::setManual(const std::string& argName) {
+ manualArgs_.emplace(argName);
+}
+
+metric_name Args::getAutotuneMetric() const {
+ if (autotuneMetric.substr(0, 3) == "f1:") {
+ return metric_name::f1scoreLabel;
+ } else if (autotuneMetric == "f1") {
+ return metric_name::f1score;
+ } else if (autotuneMetric.substr(0, 18) == "precisionAtRecall:") {
+ size_t semicolon = autotuneMetric.find(":", 18);
+ if (semicolon != std::string::npos) {
+ return metric_name::precisionAtRecallLabel;
+ }
+ return metric_name::precisionAtRecall;
+ } else if (autotuneMetric.substr(0, 18) == "recallAtPrecision:") {
+ size_t semicolon = autotuneMetric.find(":", 18);
+ if (semicolon != std::string::npos) {
+ return metric_name::recallAtPrecisionLabel;
+ }
+ return metric_name::recallAtPrecision;
+ }
+ throw std::runtime_error("Unknown metric : " + autotuneMetric);
+}
+
+std::string Args::getAutotuneMetricLabel() const {
+ metric_name metric = getAutotuneMetric();
+ std::string label;
+ if (metric == metric_name::f1scoreLabel) {
+ label = autotuneMetric.substr(3);
+ } else if (
+ metric == metric_name::precisionAtRecallLabel ||
+ metric == metric_name::recallAtPrecisionLabel) {
+ size_t semicolon = autotuneMetric.find(":", 18);
+ label = autotuneMetric.substr(semicolon + 1);
+ } else {
+ return label;
+ }
+
+ if (label.empty()) {
+ throw std::runtime_error("Empty metric label : " + autotuneMetric);
+ }
+ return label;
+}
+
+double Args::getAutotuneMetricValue() const {
+ metric_name metric = getAutotuneMetric();
+ double value = 0.0;
+ if (metric == metric_name::precisionAtRecallLabel ||
+ metric == metric_name::precisionAtRecall ||
+ metric == metric_name::recallAtPrecisionLabel ||
+ metric == metric_name::recallAtPrecision) {
+ size_t firstSemicolon = 18; // semicolon position in "precisionAtRecall:"
+ size_t secondSemicolon = autotuneMetric.find(":", firstSemicolon);
+ const std::string valueStr =
+ autotuneMetric.substr(firstSemicolon, secondSemicolon - firstSemicolon);
+ value = std::stof(valueStr) / 100.0;
+ }
+ return value;
+}
+
+int64_t Args::getAutotuneModelSize() const {
+ std::string modelSize = autotuneModelSize;
+ if (modelSize.empty()) {
+ return Args::kUnlimitedModelSize;
+ }
+ std::unordered_map<char, int> units = {
+ {'k', 1000},
+ {'K', 1000},
+ {'m', 1000000},
+ {'M', 1000000},
+ {'g', 1000000000},
+ {'G', 1000000000},
+ };
+ uint64_t multiplier = 1;
+ char lastCharacter = modelSize.back();
+ if (units.count(lastCharacter)) {
+ multiplier = units[lastCharacter];
+ modelSize = modelSize.substr(0, modelSize.size() - 1);
+ }
+ uint64_t size = 0;
+ size_t nonNumericCharacter = 0;
+ bool parseError = false;
+ try {
+ size = std::stol(modelSize, &nonNumericCharacter);
+ } catch (std::invalid_argument&) {
+ parseError = true;
+ }
+ if (!parseError && nonNumericCharacter != modelSize.size()) {
+ parseError = true;
+ }
+ if (parseError) {
+ throw std::invalid_argument(
+ "Unable to parse model size " + autotuneModelSize);
+ }
+
+ return size * multiplier;
}
} // namespace fasttext
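
For illustration only, not part of the diff above: a minimal, self-contained sketch of how the new -autotune-metric strings are interpreted by the getAutotuneMetric(), getAutotuneMetricLabel() and getAutotuneMetricValue() methods added in this version. The ParsedMetric struct and parseMetric function are made-up names for this sketch, not fastText API.

// Standalone sketch (not part of the diff): mirrors the interpretation of
// -autotune-metric strings such as "f1", "f1:__label__x" and
// "recallAtPrecision:30:__label__x". Names here are illustrative only.
#include <iostream>
#include <stdexcept>
#include <string>

struct ParsedMetric {
  std::string kind;   // e.g. "f1score", "recallAtPrecisionLabel"
  std::string label;  // non-empty only for the *Label variants
  double value;       // target precision/recall as a fraction
};

ParsedMetric parseMetric(const std::string& m) {
  if (m == "f1") {
    return {"f1score", "", 0.0};
  }
  if (m.substr(0, 3) == "f1:") {
    return {"f1scoreLabel", m.substr(3), 0.0};
  }
  // "precisionAtRecall:" and "recallAtPrecision:" are both 18 characters,
  // so the numeric target starts at index 18 in either case.
  if (m.substr(0, 18) == "precisionAtRecall:" ||
      m.substr(0, 18) == "recallAtPrecision:") {
    std::string kind =
        (m[0] == 'p') ? "precisionAtRecall" : "recallAtPrecision";
    size_t secondColon = m.find(':', 18);
    std::string valueStr = m.substr(18, secondColon - 18);
    double value = std::stod(valueStr) / 100.0;  // "30" -> 0.30
    std::string label =
        (secondColon == std::string::npos) ? "" : m.substr(secondColon + 1);
    if (!label.empty()) {
      kind += "Label";
    }
    return {kind, label, value};
  }
  throw std::runtime_error("Unknown metric : " + m);
}

int main() {
  ParsedMetric p = parseMetric("recallAtPrecision:30:__label__spam");
  // Prints: recallAtPrecisionLabel __label__spam 0.3
  std::cout << p.kind << " " << p.label << " " << p.value << "\n";
}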
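
Similarly, a minimal sketch of the unit-suffix handling that Args::getAutotuneModelSize() adds for -autotune-modelsize; parseModelSize is a made-up name for this sketch, not fastText API.

// Standalone sketch (not part of the diff): mirrors the size-suffix parsing
// added in Args::getAutotuneModelSize(), e.g. for "-autotune-modelsize 2M".
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>

int64_t parseModelSize(std::string s) {
  static const std::unordered_map<char, int64_t> units = {
      {'k', 1000},    {'K', 1000},       {'m', 1000000},
      {'M', 1000000}, {'g', 1000000000}, {'G', 1000000000}};
  int64_t multiplier = 1;
  if (!s.empty() && units.count(s.back())) {
    multiplier = units.at(s.back());
    s.pop_back();  // strip the unit suffix before parsing the number
  }
  size_t parsed = 0;
  int64_t size = std::stoll(s, &parsed);  // throws on non-numeric input
  if (parsed != s.size()) {
    throw std::invalid_argument("Unable to parse model size " + s);
  }
  return size * multiplier;
}

int main() {
  std::cout << parseModelSize("2M") << "\n";   // 2000000
  std::cout << parseModelSize("350K") << "\n"; // 350000
}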