vendor/fastText/src/main.cc in fasttext-0.1.2 vs vendor/fastText/src/main.cc in fasttext-0.1.3

- old
+ new

@@ -9,32 +9,39 @@ #include <iomanip> #include <iostream> #include <queue> #include <stdexcept> #include "args.h" +#include "autotune.h" #include "fasttext.h" using namespace fasttext; void printUsage() { std::cerr << "usage: fasttext <command> <args>\n\n" << "The commands supported by fasttext are:\n\n" << " supervised train a supervised classifier\n" - << " quantize quantize a model to reduce the memory usage\n" + << " quantize quantize a model to reduce the memory " + "usage\n" << " test evaluate a supervised classifier\n" - << " test-label print labels with precision and recall scores\n" + << " test-label print labels with precision and recall " + "scores\n" << " predict predict most likely labels\n" - << " predict-prob predict most likely labels with probabilities\n" + << " predict-prob predict most likely labels with " + "probabilities\n" << " skipgram train a skipgram model\n" << " cbow train a cbow model\n" << " print-word-vectors print word vectors given a trained model\n" - << " print-sentence-vectors print sentence vectors given a trained model\n" - << " print-ngrams print ngrams given a trained model and word\n" + << " print-sentence-vectors print sentence vectors given a trained " + "model\n" + << " print-ngrams print ngrams given a trained model and " + "word\n" << " nn query for nearest neighbors\n" << " analogies query for analogies\n" - << " dump dump arguments,dictionary,input/output vectors\n" + << " dump dump arguments,dictionary,input/output " + "vectors\n" << std::endl; } void printQuantizeUsage() { std::cerr << "usage: fasttext quantize <args>" << std::endl; @@ -139,11 +146,11 @@ real threshold = args.size() > 5 ? std::stof(args[5]) : 0.0; FastText fasttext; fasttext.loadModel(model); - Meter meter; + Meter meter(false); if (input == "-") { fasttext.test(std::cin, k, threshold, meter); } else { std::ifstream ifs(input); @@ -349,22 +356,34 @@ } void train(const std::vector<std::string> args) { Args a = Args(); a.parseArgs(args); - FastText fasttext; - std::string outputFileName(a.output + ".bin"); + std::shared_ptr<FastText> fasttext = std::make_shared<FastText>(); + std::string outputFileName; + + if (a.hasAutotune() && + a.getAutotuneModelSize() != Args::kUnlimitedModelSize) { + outputFileName = a.output + ".ftz"; + } else { + outputFileName = a.output + ".bin"; + } std::ofstream ofs(outputFileName); if (!ofs.is_open()) { throw std::invalid_argument( outputFileName + " cannot be opened for saving."); } ofs.close(); - fasttext.train(a); - fasttext.saveModel(outputFileName); - fasttext.saveVectors(a.output + ".vec"); + if (a.hasAutotune()) { + Autotune autotune(fasttext); + autotune.train(a); + } else { + fasttext->train(a); + } + fasttext->saveModel(outputFileName); + fasttext->saveVectors(a.output + ".vec"); if (a.saveOutput) { - fasttext.saveOutput(a.output + ".output"); + fasttext->saveOutput(a.output + ".output"); } } void dump(const std::vector<std::string>& args) { if (args.size() < 4) {