vendor/fastText/src/args.cc in fasttext-0.1.2 vs vendor/fastText/src/args.cc in fasttext-0.1.3

- old
+ new

@@ -10,10 +10,12 @@
 
 #include <stdlib.h>
 
 #include <iostream>
 #include <stdexcept>
+#include <string>
+#include <unordered_map>
 
 namespace fasttext {
 
 Args::Args() {
   lr = 0.05;
@@ -34,16 +36,23 @@
   t = 1e-4;
   label = "__label__";
   verbose = 2;
   pretrainedVectors = "";
   saveOutput = false;
+  seed = 0;
   qout = false;
   retrain = false;
   qnorm = false;
   cutoff = 0;
   dsub = 2;
+
+  autotuneValidationFile = "";
+  autotuneMetric = "f1";
+  autotunePredictions = 1;
+  autotuneDuration = 60 * 5; // 5 minutes
+  autotuneModelSize = "";
 }
 
 std::string Args::lossToString(loss_name ln) const {
   switch (ln) {
     case loss_name::hs:
@@ -76,10 +85,28 @@
       return "sup";
   }
   return "Unknown model name!"; // should never happen
 }
 
+std::string Args::metricToString(metric_name mn) const {
+  switch (mn) {
+    case metric_name::f1score:
+      return "f1score";
+    case metric_name::f1scoreLabel:
+      return "f1scoreLabel";
+    case metric_name::precisionAtRecall:
+      return "precisionAtRecall";
+    case metric_name::precisionAtRecallLabel:
+      return "precisionAtRecallLabel";
+    case metric_name::recallAtPrecision:
+      return "recallAtPrecision";
+    case metric_name::recallAtPrecisionLabel:
+      return "recallAtPrecisionLabel";
+  }
+  return "Unknown metric name!"; // should never happen
+}
+
 void Args::parseArgs(const std::vector<std::string>& args) {
   std::string command(args[1]);
   if (command == "supervised") {
     model = model_name::sup;
     loss = loss_name::softmax;
@@ -95,10 +122,12 @@
       std::cerr << "Provided argument without a dash! Usage:" << std::endl;
       printHelp();
       exit(EXIT_FAILURE);
     }
     try {
+      setManual(args[ai].substr(1));
+
       if (args[ai] == "-h") {
         std::cerr << "Here is the help! Usage:" << std::endl;
         printHelp();
         exit(EXIT_FAILURE);
       } else if (args[ai] == "-input") {
@@ -155,10 +184,12 @@
       } else if (args[ai] == "-pretrainedVectors") {
         pretrainedVectors = std::string(args.at(ai + 1));
       } else if (args[ai] == "-saveOutput") {
         saveOutput = true;
         ai--;
+      } else if (args[ai] == "-seed") {
+        seed = std::stoi(args.at(ai + 1));
       } else if (args[ai] == "-qnorm") {
         qnorm = true;
         ai--;
       } else if (args[ai] == "-retrain") {
         retrain = true;
@@ -168,10 +199,22 @@
         ai--;
       } else if (args[ai] == "-cutoff") {
         cutoff = std::stoi(args.at(ai + 1));
       } else if (args[ai] == "-dsub") {
         dsub = std::stoi(args.at(ai + 1));
+      } else if (args[ai] == "-autotune-validation") {
+        autotuneValidationFile = std::string(args.at(ai + 1));
+      } else if (args[ai] == "-autotune-metric") {
+        autotuneMetric = std::string(args.at(ai + 1));
+        getAutotuneMetric(); // throws exception if not able to parse
+        getAutotuneMetricLabel(); // throws exception if not able to parse
+      } else if (args[ai] == "-autotune-predictions") {
+        autotunePredictions = std::stoi(args.at(ai + 1));
+      } else if (args[ai] == "-autotune-duration") {
+        autotuneDuration = std::stoi(args.at(ai + 1));
+      } else if (args[ai] == "-autotune-modelsize") {
+        autotuneModelSize = std::string(args.at(ai + 1));
       } else {
         std::cerr << "Unknown argument: " << args[ai] << std::endl;
         printHelp();
         exit(EXIT_FAILURE);
       }
@@ -184,19 +227,20 @@
   if (input.empty() || output.empty()) {
     std::cerr << "Empty input or output path." << std::endl;
     printHelp();
     exit(EXIT_FAILURE);
   }
-  if (wordNgrams <= 1 && maxn == 0) {
+  if (wordNgrams <= 1 && maxn == 0 && !hasAutotune()) {
     bucket = 0;
   }
 }
 
 void Args::printHelp() {
   printBasicHelp();
   printDictionaryHelp();
   printTrainingHelp();
+  printAutotuneHelp();
   printQuantizationHelp();
 }
 
 void Args::printBasicHelp() {
   std::cerr << "\nThe following arguments are mandatory:\n"
@@ -225,31 +269,53 @@
 void Args::printTrainingHelp() {
   std::cerr << "\nThe following arguments for training are optional:\n"
             << " -lr learning rate [" << lr << "]\n"
-            << " -lrUpdateRate change the rate of updates for the learning rate ["
+            << " -lrUpdateRate change the rate of updates for the learning "
+               "rate ["
            << lrUpdateRate << "]\n"
             << " -dim size of word vectors [" << dim << "]\n"
             << " -ws size of the context window [" << ws << "]\n"
             << " -epoch number of epochs [" << epoch << "]\n"
             << " -neg number of negatives sampled [" << neg << "]\n"
             << " -loss loss function {ns, hs, softmax, one-vs-all} ["
             << lossToString(loss) << "]\n"
-            << " -thread number of threads [" << thread << "]\n"
-            << " -pretrainedVectors pretrained word vectors for supervised learning ["
+            << " -thread number of threads (set to 1 to ensure "
+               "reproducible results) ["
+            << thread << "]\n"
+            << " -pretrainedVectors pretrained word vectors for supervised "
+               "learning ["
             << pretrainedVectors << "]\n"
             << " -saveOutput whether output params should be saved ["
-            << boolToString(saveOutput) << "]\n";
+            << boolToString(saveOutput) << "]\n"
+            << " -seed random generator seed [" << seed << "]\n";
 }
 
+void Args::printAutotuneHelp() {
+  std::cerr << "\nThe following arguments are for autotune:\n"
+            << " -autotune-validation validation file to be used "
+               "for evaluation\n"
+            << " -autotune-metric metric objective {f1, "
+               "f1:labelname} ["
+            << autotuneMetric << "]\n"
+            << " -autotune-predictions number of predictions used "
+               "for evaluation ["
+            << autotunePredictions << "]\n"
+            << " -autotune-duration maximum duration in seconds ["
+            << autotuneDuration << "]\n"
+            << " -autotune-modelsize constraint model file size ["
+            << autotuneModelSize << "] (empty = do not quantize)\n";
+}
+
 void Args::printQuantizationHelp() {
   std::cerr << "\nThe following arguments for quantization are optional:\n"
             << " -cutoff number of words and ngrams to retain ["
             << cutoff << "]\n"
-            << " -retrain whether embeddings are finetuned if a cutoff is applied ["
+            << " -retrain whether embeddings are finetuned if a cutoff "
+               "is applied ["
            << boolToString(retrain) << "]\n"
             << " -qnorm whether the norm is quantized separately ["
             << boolToString(qnorm) << "]\n"
             << " -qout whether the classifier is quantized ["
             << boolToString(qout) << "]\n"
@@ -313,8 +379,115 @@
       << " " << maxn << std::endl;
   out << "lrUpdateRate"
       << " " << lrUpdateRate << std::endl;
   out << "t"
       << " " << t << std::endl;
+}
+
+bool Args::hasAutotune() const {
+  return !autotuneValidationFile.empty();
+}
+
+bool Args::isManual(const std::string& argName) const {
+  return (manualArgs_.count(argName) != 0);
+}
+
+void Args::setManual(const std::string& argName) {
+  manualArgs_.emplace(argName);
+}
+
+metric_name Args::getAutotuneMetric() const {
+  if (autotuneMetric.substr(0, 3) == "f1:") {
+    return metric_name::f1scoreLabel;
+  } else if (autotuneMetric == "f1") {
+    return metric_name::f1score;
+  } else if (autotuneMetric.substr(0, 18) == "precisionAtRecall:") {
+    size_t semicolon = autotuneMetric.find(":", 18);
+    if (semicolon != std::string::npos) {
+      return metric_name::precisionAtRecallLabel;
+    }
+    return metric_name::precisionAtRecall;
+  } else if (autotuneMetric.substr(0, 18) == "recallAtPrecision:") {
+    size_t semicolon = autotuneMetric.find(":", 18);
+    if (semicolon != std::string::npos) {
+      return metric_name::recallAtPrecisionLabel;
+    }
+    return metric_name::recallAtPrecision;
+  }
+  throw std::runtime_error("Unknown metric : " + autotuneMetric);
+}
+
+std::string Args::getAutotuneMetricLabel() const {
+  metric_name metric = getAutotuneMetric();
+  std::string label;
+  if (metric == metric_name::f1scoreLabel) {
+    label = autotuneMetric.substr(3);
+  } else if (
+      metric == metric_name::precisionAtRecallLabel ||
+      metric == metric_name::recallAtPrecisionLabel) {
+    size_t semicolon = autotuneMetric.find(":", 18);
+    label = autotuneMetric.substr(semicolon + 1);
+  } else {
+    return label;
+  }
+
+  if (label.empty()) {
+    throw std::runtime_error("Empty metric label : " + autotuneMetric);
+  }
+  return label;
+}
+
+double Args::getAutotuneMetricValue() const {
+  metric_name metric = getAutotuneMetric();
+  double value = 0.0;
+  if (metric == metric_name::precisionAtRecallLabel ||
+      metric == metric_name::precisionAtRecall ||
+      metric == metric_name::recallAtPrecisionLabel ||
+      metric == metric_name::recallAtPrecision) {
+    size_t firstSemicolon = 18; // semicolon position in "precisionAtRecall:"
+    size_t secondSemicolon = autotuneMetric.find(":", firstSemicolon);
+    const std::string valueStr =
+        autotuneMetric.substr(firstSemicolon, secondSemicolon - firstSemicolon);
+    value = std::stof(valueStr) / 100.0;
+  }
+  return value;
+}
+
+int64_t Args::getAutotuneModelSize() const {
+  std::string modelSize = autotuneModelSize;
+  if (modelSize.empty()) {
+    return Args::kUnlimitedModelSize;
+  }
+  std::unordered_map<char, int> units = {
+      {'k', 1000},
+      {'K', 1000},
+      {'m', 1000000},
+      {'M', 1000000},
+      {'g', 1000000000},
+      {'G', 1000000000},
+  };
+  uint64_t multiplier = 1;
+  char lastCharacter = modelSize.back();
+  if (units.count(lastCharacter)) {
+    multiplier = units[lastCharacter];
+    modelSize = modelSize.substr(0, modelSize.size() - 1);
+  }
+  uint64_t size = 0;
+  size_t nonNumericCharacter = 0;
+  bool parseError = false;
+  try {
+    size = std::stol(modelSize, &nonNumericCharacter);
+  } catch (std::invalid_argument&) {
+    parseError = true;
+  }
+  if (!parseError && nonNumericCharacter != modelSize.size()) {
+    parseError = true;
+  }
+  if (parseError) {
+    throw std::invalid_argument(
+        "Unable to parse model size " + autotuneModelSize);
+  }
+
+  return size * multiplier;
 }
 
 } // namespace fasttext
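
Note on the new -autotune-metric format: the option packs a metric kind, an optional target percentage and an optional label into one colon-separated string. The standalone sketch below is illustration only (ParsedMetric and parseMetric are invented names, not fastText API); it mirrors the substr/find slicing done by getAutotuneMetric, getAutotuneMetricValue and getAutotuneMetricLabel in the diff, including the division by 100 that turns "precisionAtRecall:30" into a target recall of 0.3.

// Standalone illustration only -- it mirrors the parsing logic in the diff
// above and is not part of fastText. Build with: g++ -std=c++11 metric_demo.cc
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

struct ParsedMetric {
  std::string kind;  // "f1", "precisionAtRecall" or "recallAtPrecision"
  double value;      // target recall/precision as a fraction, 0 when unused
  std::string label; // optional __label__ restriction, empty when unused
};

// "precisionAtRecall:" and "recallAtPrecision:" are both 18 characters long,
// which is why the methods above hard-code substr(0, 18) and find(":", 18).
ParsedMetric parseMetric(const std::string& metric) {
  if (metric == "f1") {
    return {"f1", 0.0, ""};
  }
  if (metric.substr(0, 3) == "f1:") {
    return {"f1", 0.0, metric.substr(3)};
  }
  if (metric.substr(0, 18) == "precisionAtRecall:" ||
      metric.substr(0, 18) == "recallAtPrecision:") {
    size_t secondColon = metric.find(":", 18);
    // substr with an npos-derived length runs to the end of the string,
    // exactly as in Args::getAutotuneMetricValue.
    std::string valueStr = metric.substr(18, secondColon - 18);
    std::string label =
        secondColon == std::string::npos ? "" : metric.substr(secondColon + 1);
    return {metric.substr(0, 17), std::stof(valueStr) / 100.0, label};
  }
  throw std::runtime_error("Unknown metric : " + metric);
}

int main() {
  std::vector<std::string> examples = {
      "f1",
      "f1:__label__baking",
      "precisionAtRecall:30",
      "recallAtPrecision:90:__label__baking"};
  for (const std::string& m : examples) {
    ParsedMetric p = parseMetric(m);
    std::cout << m << " -> kind=" << p.kind << ", value=" << p.value
              << ", label=" << (p.label.empty() ? "(none)" : p.label) << "\n";
  }
  return 0;
}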
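
Similarly, -autotune-modelsize accepts a plain number or a single k/K/m/M/g/G suffix, and an empty value disables the size constraint. This standalone sketch (parseModelSize is an invented name; -1 stands in for Args::kUnlimitedModelSize) mirrors the suffix handling of getAutotuneModelSize above and shows which inputs are rejected.

// Standalone illustration only, not fastText code.
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>

int64_t parseModelSize(std::string modelSize) {
  if (modelSize.empty()) {
    return -1; // empty means "no size constraint, do not quantize"
  }
  std::unordered_map<char, int> units = {{'k', 1000},       {'K', 1000},
                                         {'m', 1000000},    {'M', 1000000},
                                         {'g', 1000000000}, {'G', 1000000000}};
  uint64_t multiplier = 1;
  if (units.count(modelSize.back())) {
    multiplier = units[modelSize.back()];
    modelSize.pop_back();
  }
  size_t consumed = 0;
  uint64_t size = 0;
  try {
    size = std::stol(modelSize, &consumed);
  } catch (std::invalid_argument&) {
    throw std::invalid_argument("Unable to parse model size " + modelSize);
  }
  // Only one optional suffix is accepted: leftover characters such as the
  // "B" in "2MB" are treated as a parse error, just like in the diff.
  if (consumed != modelSize.size()) {
    throw std::invalid_argument("Unable to parse model size " + modelSize);
  }
  return size * multiplier;
}

int main() {
  std::cout << parseModelSize("2M") << "\n";   // 2000000
  std::cout << parseModelSize("500K") << "\n"; // 500000
  std::cout << parseModelSize("") << "\n";     // -1 (unlimited)
  try {
    parseModelSize("2MB");
  } catch (const std::invalid_argument& e) {
    std::cout << e.what() << "\n"; // Unable to parse model size 2MB
  }
  return 0;
}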
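
Finally, a rough end-to-end usage sketch, assuming the methods defined above are exposed publicly through args.h; train.txt, valid.txt and model are placeholder paths.

#include <iostream>
#include <string>
#include <vector>
#include "args.h"

int main() {
  fasttext::Args a;
  // Same vector shape the fasttext CLI hands to parseArgs: element 1 is the
  // command, followed by flag/value pairs.
  std::vector<std::string> argv = {
      "fasttext", "supervised",
      "-input", "train.txt", "-output", "model",
      "-autotune-validation", "valid.txt",
      "-autotune-metric", "recallAtPrecision:90:__label__baking",
      "-autotune-duration", "600",
      "-autotune-modelsize", "2M"};
  a.parseArgs(argv);

  std::cout << a.hasAutotune() << "\n";            // 1: validation file given
  std::cout << a.getAutotuneMetricValue() << "\n"; // 0.9: target precision
  std::cout << a.getAutotuneMetricLabel() << "\n"; // __label__baking
  std::cout << a.getAutotuneModelSize() << "\n";   // 2000000 ("2M" expanded)
  return 0;
}

parseArgs also records each explicitly passed flag name (minus the leading dash) via setManual(), so, assuming isManual() is likewise reachable from args.h, isManual("autotune-duration") would return true after the call above while isManual("epoch") would not.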