vendor/fastText/src/fasttext.cc: fasttext-0.1.2 vs fasttext-0.1.3
--- old (fasttext-0.1.2)
+++ new (fasttext-0.1.3)
@@ -45,11 +45,12 @@
default:
throw std::runtime_error("Unknown loss");
}
}
-FastText::FastText() : quant_(false), wordVectors_(nullptr) {}
+FastText::FastText()
+ : quant_(false), wordVectors_(nullptr), trainException_(nullptr) {}
void FastText::addInputVector(Vector& vec, int32_t ind) const {
vec.addRow(*input_, ind);
}
@@ -67,10 +68,23 @@
}
assert(input_.get());
return std::dynamic_pointer_cast<DenseMatrix>(input_);
}
+void FastText::setMatrices(
+ const std::shared_ptr<DenseMatrix>& inputMatrix,
+ const std::shared_ptr<DenseMatrix>& outputMatrix) {
+ assert(input_->size(1) == output_->size(1));
+
+ input_ = std::dynamic_pointer_cast<Matrix>(inputMatrix);
+ output_ = std::dynamic_pointer_cast<Matrix>(outputMatrix);
+ wordVectors_.reset();
+ args_->dim = input_->size(1);
+
+ buildModel();
+}
+
std::shared_ptr<const DenseMatrix> FastText::getOutputMatrix() const {
if (quant_ && args_->qout) {
throw std::runtime_error("Can't export quantized matrix");
}
assert(output_.get());
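The new setMatrices lets a caller swap in externally built embedding and output matrices and rebuild the model around them. Note that the assert inspects the already-loaded input_/output_, so a model must be in place before calling it. A minimal sketch of the calling pattern, assuming a FastText instance and matrix shapes that are not taken from this diff:

    #include <memory>
    #include "fasttext.h"
    using namespace fasttext;

    // Hypothetical helper: both matrices must share the embedding dimension
    // (size(1)), which setMatrices copies back into args_->dim.
    void injectMatrices(FastText& ft, int64_t rows, int64_t nlabels, int64_t dim) {
      auto in = std::make_shared<DenseMatrix>(rows, dim);
      auto out = std::make_shared<DenseMatrix>(nlabels, dim);
      in->uniform(1.0 / dim, /*thread=*/1, /*seed=*/0); // new seeded signature
      out->zero();
      ft.setMatrices(in, out); // rebuilds model_, drops cached wordVectors_
    }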
@@ -84,10 +98,18 @@
int32_t FastText::getSubwordId(const std::string& subword) const {
int32_t h = dict_->hash(subword) % args_->bucket;
return dict_->nwords() + h;
}
+int32_t FastText::getLabelId(const std::string& label) const {
+ int32_t labelId = dict_->getId(label);
+ if (labelId != -1) {
+ labelId -= dict_->nwords();
+ }
+ return labelId;
+}
+
void FastText::getWordVector(Vector& vec, const std::string& word) const {
const std::vector<int32_t>& ngrams = dict_->getSubwords(word);
vec.zero();
for (int i = 0; i < ngrams.size(); i++) {
addInputVector(vec, ngrams[i]);
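getLabelId maps a label string to its row in the output matrix: dictionary ids place labels after the nwords() words, so subtracting nwords() yields a zero-based label index, with -1 preserved for unknown labels. A short usage sketch (model path and label are assumed):

    fasttext::FastText ft;
    ft.loadModel("model.bin"); // hypothetical supervised model
    int32_t id = ft.getLabelId("__label__positive");
    if (id >= 0) {
      // id indexes the label rows of the output matrix directly
    }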
@@ -95,22 +117,21 @@
if (ngrams.size() > 0) {
vec.mul(1.0 / ngrams.size());
}
}
-void FastText::getVector(Vector& vec, const std::string& word) const {
- getWordVector(vec, word);
-}
-
void FastText::getSubwordVector(Vector& vec, const std::string& subword) const {
vec.zero();
int32_t h = dict_->hash(subword) % args_->bucket;
h = h + dict_->nwords();
addInputVector(vec, h);
}
void FastText::saveVectors(const std::string& filename) {
+ if (!input_ || !output_) {
+ throw std::runtime_error("Model never trained");
+ }
std::ofstream ofs(filename);
if (!ofs.is_open()) {
throw std::invalid_argument(
filename + " cannot be opened for saving vectors!");
}
@@ -122,14 +143,10 @@
ofs << word << " " << vec << std::endl;
}
ofs.close();
}
-void FastText::saveVectors() {
- saveVectors(args_->output + ".vec");
-}
-
void FastText::saveOutput(const std::string& filename) {
std::ofstream ofs(filename);
if (!ofs.is_open()) {
throw std::invalid_argument(
filename + " cannot be opened for saving vectors!");
@@ -150,14 +167,10 @@
ofs << word << " " << vec << std::endl;
}
ofs.close();
}
-void FastText::saveOutput() {
- saveOutput(args_->output + ".output");
-}
-
bool FastText::checkModel(std::istream& in) {
int32_t magic;
in.read((char*)&(magic), sizeof(int32_t));
if (magic != FASTTEXT_FILEFORMAT_MAGIC_INT32) {
return false;
@@ -174,25 +187,18 @@
const int32_t version = FASTTEXT_VERSION;
out.write((char*)&(magic), sizeof(int32_t));
out.write((char*)&(version), sizeof(int32_t));
}
-void FastText::saveModel() {
- std::string fn(args_->output);
- if (quant_) {
- fn += ".ftz";
- } else {
- fn += ".bin";
- }
- saveModel(fn);
-}
-
void FastText::saveModel(const std::string& filename) {
std::ofstream ofs(filename, std::ofstream::binary);
if (!ofs.is_open()) {
throw std::invalid_argument(filename + " cannot be opened for saving!");
}
+ if (!input_ || !output_) {
+ throw std::runtime_error("Model never trained");
+ }
signModel(ofs);
args_->save(ofs);
dict_->save(ofs);
ofs.write((char*)&(quant_), sizeof(bool));
@@ -222,10 +228,16 @@
} else {
return dict_->getCounts(entry_type::word);
}
}
+void FastText::buildModel() {
+ auto loss = createLoss(output_);
+ bool normalizeGradient = (args_->model == model_name::sup);
+ model_ = std::make_shared<Model>(input_, output_, loss, normalizeGradient);
+}
+
void FastText::loadModel(std::istream& in) {
args_ = std::make_shared<Args>();
input_ = std::make_shared<DenseMatrix>();
output_ = std::make_shared<DenseMatrix>();
args_->load(in);
@@ -254,41 +266,41 @@
if (quant_ && args_->qout) {
output_ = std::make_shared<QuantMatrix>();
}
output_->load(in);
- auto loss = createLoss(output_);
- bool normalizeGradient = (args_->model == model_name::sup);
- model_ = std::make_shared<Model>(input_, output_, loss, normalizeGradient);
+ buildModel();
}
-void FastText::printInfo(real progress, real loss, std::ostream& log_stream) {
- std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();
- double t =
- std::chrono::duration_cast<std::chrono::duration<double>>(end - start_)
- .count();
+std::tuple<int64_t, double, double> FastText::progressInfo(real progress) {
+ double t = utils::getDuration(start_, std::chrono::steady_clock::now());
double lr = args_->lr * (1.0 - progress);
double wst = 0;
int64_t eta = 2592000; // Default to one month in seconds (720 * 3600)
if (progress > 0 && t >= 0) {
- progress = progress * 100;
- eta = t * (100 - progress) / progress;
+ eta = t * (1 - progress) / progress;
wst = double(tokenCount_) / t / args_->thread;
}
- int32_t etah = eta / 3600;
- int32_t etam = (eta % 3600) / 60;
+ return std::tuple<double, double, int64_t>(wst, lr, eta);
+}
+
+void FastText::printInfo(real progress, real loss, std::ostream& log_stream) {
+ double wst;
+ double lr;
+ int64_t eta;
+ std::tie<double, double, int64_t>(wst, lr, eta) = progressInfo(progress);
+
log_stream << std::fixed;
log_stream << "Progress: ";
- log_stream << std::setprecision(1) << std::setw(5) << progress << "%";
+ log_stream << std::setprecision(1) << std::setw(5) << (progress * 100) << "%";
log_stream << " words/sec/thread: " << std::setw(7) << int64_t(wst);
log_stream << " lr: " << std::setw(9) << std::setprecision(6) << lr;
- log_stream << " loss: " << std::setw(9) << std::setprecision(6) << loss;
- log_stream << " ETA: " << std::setw(3) << etah;
- log_stream << "h" << std::setw(2) << etam << "m";
+ log_stream << " avg.loss: " << std::setw(9) << std::setprecision(6) << loss;
+ log_stream << " ETA: " << utils::ClockPrint(eta);
log_stream << std::flush;
}
std::vector<int32_t> FastText::selectEmbeddings(int32_t cutoff) const {
std::shared_ptr<DenseMatrix> input =
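The extracted progressInfo works on the raw 0..1 progress scale, so the ETA follows eta = t * (1 - progress) / progress: after 60 s of wall time at progress 0.2, eta = 60 * 0.8 / 0.2 = 240 s. One wrinkle worth knowing: the function is declared to return std::tuple<int64_t, double, double> but constructs std::tuple<double, double, int64_t>(wst, lr, eta). tuple's converting constructor makes this compile, silently truncating wst to an integer, which printInfo's int64_t(wst) cast happens to mask. A minimal illustration of that conversion:

    #include <cstdint>
    #include <tuple>

    std::tuple<int64_t, double, double> r =
        std::tuple<double, double, int64_t>(725.9, 0.05, 240);
    // std::get<0>(r) == 725: the fractional words/sec/thread is lost here.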
@@ -297,17 +309,20 @@
input->l2NormRow(norms);
std::vector<int32_t> idx(input->size(0), 0);
std::iota(idx.begin(), idx.end(), 0);
auto eosid = dict_->getId(Dictionary::EOS);
std::sort(idx.begin(), idx.end(), [&norms, eosid](size_t i1, size_t i2) {
+ if (i1 == eosid && i2 == eosid) { // satisfy strict weak ordering
+ return false;
+ }
return eosid == i1 || (eosid != i2 && norms[i1] > norms[i2]);
});
idx.erase(idx.begin() + cutoff, idx.end());
return idx;
}
-void FastText::quantize(const Args& qargs) {
+void FastText::quantize(const Args& qargs, const TrainCallback& callback) {
if (args_->model != model_name::sup) {
throw std::invalid_argument(
"For now we only support quantization of supervised models");
}
args_->input = qargs.input;
@@ -335,22 +350,20 @@
args_->lr = qargs.lr;
args_->thread = qargs.thread;
args_->verbose = qargs.verbose;
auto loss = createLoss(output_);
model_ = std::make_shared<Model>(input, output, loss, normalizeGradient);
- startThreads();
+ startThreads(callback);
}
}
-
input_ = std::make_shared<QuantMatrix>(
std::move(*(input.get())), qargs.dsub, qargs.qnorm);
if (args_->qout) {
output_ = std::make_shared<QuantMatrix>(
std::move(*(output.get())), 2, qargs.qnorm);
}
-
quant_ = true;
auto loss = createLoss(output_);
model_ = std::make_shared<Model>(input_, output_, loss, normalizeGradient);
}
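quantize now forwards the same TrainCallback into the optional retraining pass, so progress reporting during retrain matches ordinary training. The call in trainThread suggests the callable receives (progress, loss, words/sec/thread, lr, eta). A hedged sketch, given a loaded supervised model ft; the retrain/qnorm knobs are the usual quantization Args fields, and the values are assumed:

    fasttext::Args qargs;
    qargs.input = "data.txt"; // hypothetical; read again if retraining
    qargs.retrain = true;     // takes the startThreads(callback) branch above
    qargs.qnorm = true;
    ft.quantize(qargs, [](float progress, float loss, double wst, double lr, int64_t eta) {
      // invoked from a worker thread roughly every 64 processed lines
    });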
@@ -406,11 +419,11 @@
}
}
std::tuple<int64_t, double, double>
FastText::test(std::istream& in, int32_t k, real threshold) {
- Meter meter;
+ Meter meter(false);
test(in, k, threshold, meter);
return std::tuple<int64_t, double, double>(
meter.nexamples(), meter.precision(), meter.recall());
}
@@ -418,10 +431,13 @@
void FastText::test(std::istream& in, int32_t k, real threshold, Meter& meter)
const {
std::vector<int32_t> line;
std::vector<int32_t> labels;
Predictions predictions;
+ Model::State state(args_->dim, dict_->nlabels(), 0);
+ in.clear();
+ in.seekg(0, std::ios_base::beg);
while (in.peek() != EOF) {
line.clear();
labels.clear();
dict_->getLine(in, line, labels);
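test now rewinds the stream before reading (in.clear(); in.seekg(0)) so repeated evaluations on the same ifstream start from the top, and it sizes its Model::State from dict_->nlabels(). A small evaluation sketch, with the file name assumed and the Meter flag mirroring the Meter(false) construction in the overload above:

    #include <fstream>
    #include <iostream>

    std::ifstream ifs("test.txt"); // hypothetical test set
    fasttext::Meter meter(false);
    ft.test(ifs, /*k=*/1, /*threshold=*/0.0, meter);
    std::cout << meter.nexamples() << " " << meter.precision()
              << " " << meter.recall() << std::endl;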
@@ -519,20 +535,10 @@
result.push_back(std::make_pair(substrings[i], std::move(vec)));
}
return result;
}
-// deprecated. use getNgramVectors instead
-void FastText::ngramVectors(std::string word) {
- std::vector<std::pair<std::string, Vector>> ngramVectors =
- getNgramVectors(word);
-
- for (const auto& ngramVector : ngramVectors) {
- std::cout << ngramVector.first << " " << ngramVector.second << std::endl;
- }
-}
-
void FastText::precomputeWordVectors(DenseMatrix& wordVectors) {
Vector vec(args_->dim);
wordVectors.zero();
for (int32_t i = 0; i < dict_->nwords(); i++) {
std::string word = dict_->getWord(i);
@@ -596,21 +602,10 @@
std::sort_heap(heap.begin(), heap.end(), comparePairs);
return heap;
}
-// deprecated. use getNN instead
-void FastText::findNN(
- const DenseMatrix& wordVectors,
- const Vector& query,
- int32_t k,
- const std::set<std::string>& banSet,
- std::vector<std::pair<real, std::string>>& results) {
- results.clear();
- results = getNN(wordVectors, query, k, banSet);
-}
-
std::vector<std::pair<real, std::string>> FastText::getAnalogies(
int32_t k,
const std::string& wordA,
const std::string& wordB,
const std::string& wordC) {
@@ -628,56 +623,56 @@
lazyComputeWordVectors();
assert(wordVectors_);
return getNN(*wordVectors_, query, k, {wordA, wordB, wordC});
}
-// deprecated, use getAnalogies instead
-void FastText::analogies(int32_t k) {
- std::string prompt("Query triplet (A - B + C)? ");
- std::string wordA, wordB, wordC;
- std::cout << prompt;
- while (true) {
- std::cin >> wordA;
- std::cin >> wordB;
- std::cin >> wordC;
- auto results = getAnalogies(k, wordA, wordB, wordC);
-
- for (auto& pair : results) {
- std::cout << pair.second << " " << pair.first << std::endl;
- }
- std::cout << prompt;
- }
+bool FastText::keepTraining(const int64_t ntokens) const {
+ return tokenCount_ < args_->epoch * ntokens && !trainException_;
}
-void FastText::trainThread(int32_t threadId) {
+void FastText::trainThread(int32_t threadId, const TrainCallback& callback) {
std::ifstream ifs(args_->input);
utils::seek(ifs, threadId * utils::size(ifs) / args_->thread);
- Model::State state(args_->dim, output_->size(0), threadId);
+ Model::State state(args_->dim, output_->size(0), threadId + args_->seed);
const int64_t ntokens = dict_->ntokens();
int64_t localTokenCount = 0;
std::vector<int32_t> line, labels;
- while (tokenCount_ < args_->epoch * ntokens) {
- real progress = real(tokenCount_) / (args_->epoch * ntokens);
- real lr = args_->lr * (1.0 - progress);
- if (args_->model == model_name::sup) {
- localTokenCount += dict_->getLine(ifs, line, labels);
- supervised(state, lr, line, labels);
- } else if (args_->model == model_name::cbow) {
- localTokenCount += dict_->getLine(ifs, line, state.rng);
- cbow(state, lr, line);
- } else if (args_->model == model_name::sg) {
- localTokenCount += dict_->getLine(ifs, line, state.rng);
- skipgram(state, lr, line);
+ uint64_t callbackCounter = 0;
+ try {
+ while (keepTraining(ntokens)) {
+ real progress = real(tokenCount_) / (args_->epoch * ntokens);
+ if (callback && ((callbackCounter++ % 64) == 0)) {
+ double wst;
+ double lr;
+ int64_t eta;
+ std::tie<double, double, int64_t>(wst, lr, eta) =
+ progressInfo(progress);
+ callback(progress, loss_, wst, lr, eta);
+ }
+ real lr = args_->lr * (1.0 - progress);
+ if (args_->model == model_name::sup) {
+ localTokenCount += dict_->getLine(ifs, line, labels);
+ supervised(state, lr, line, labels);
+ } else if (args_->model == model_name::cbow) {
+ localTokenCount += dict_->getLine(ifs, line, state.rng);
+ cbow(state, lr, line);
+ } else if (args_->model == model_name::sg) {
+ localTokenCount += dict_->getLine(ifs, line, state.rng);
+ skipgram(state, lr, line);
+ }
+ if (localTokenCount > args_->lrUpdateRate) {
+ tokenCount_ += localTokenCount;
+ localTokenCount = 0;
+ if (threadId == 0 && args_->verbose > 1) {
+ loss_ = state.getLoss();
+ }
+ }
}
- if (localTokenCount > args_->lrUpdateRate) {
- tokenCount_ += localTokenCount;
- localTokenCount = 0;
- if (threadId == 0 && args_->verbose > 1)
- loss_ = state.getLoss();
- }
+ } catch (DenseMatrix::EncounteredNaNError&) {
+ trainException_ = std::current_exception();
}
if (threadId == 0)
loss_ = state.getLoss();
ifs.close();
}
@@ -711,11 +706,11 @@
dict_->threshold(1, 0);
dict_->init();
std::shared_ptr<DenseMatrix> input = std::make_shared<DenseMatrix>(
dict_->nwords() + args_->bucket, args_->dim);
- input->uniform(1.0 / args_->dim);
+ input->uniform(1.0 / args_->dim, args_->thread, args_->seed);
for (size_t i = 0; i < n; i++) {
int32_t idx = dict_->getId(words[i]);
if (idx < 0 || idx >= dict_->nwords()) {
continue;
@@ -725,18 +720,14 @@
}
}
return input;
}
-void FastText::loadVectors(const std::string& filename) {
- input_ = getInputMatrixFromFile(filename);
-}
-
std::shared_ptr<Matrix> FastText::createRandomMatrix() const {
std::shared_ptr<DenseMatrix> input = std::make_shared<DenseMatrix>(
dict_->nwords() + args_->bucket, args_->dim);
- input->uniform(1.0 / args_->dim);
+ input->uniform(1.0 / args_->dim, args_->thread, args_->seed);
return input;
}
std::shared_ptr<Matrix> FastText::createTrainOutputMatrix() const {
@@ -747,11 +738,11 @@
output->zero();
return output;
}
-void FastText::train(const Args& args) {
+void FastText::train(const Args& args, const TrainCallback& callback) {
args_ = std::make_shared<Args>(args);
dict_ = std::make_shared<Dictionary>(args_);
if (args_->input == "-") {
// manage expectations
throw std::invalid_argument("Cannot use stdin for training!");
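train now accepts a TrainCallback alongside the Args. A minimal end-to-end sketch; the parameter order follows the callback(progress, loss_, wst, lr, eta) call in trainThread, and the input path and hyperparameters are assumed:

    #include <cstdio>
    #include "fasttext.h"

    int main() {
      fasttext::Args args;
      args.input = "data.txt"; // hypothetical training file
      args.model = fasttext::model_name::sup;
      fasttext::FastText ft;
      ft.train(args, [](float progress, float loss, double wst, double lr, int64_t eta) {
        std::fprintf(stderr, "\r%5.1f%% lr=%.6f avg.loss=%.6f", progress * 100, lr, loss);
      });
      ft.saveModel("model.bin");
      return 0;
    }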
@@ -768,35 +759,55 @@
input_ = getInputMatrixFromFile(args_->pretrainedVectors);
} else {
input_ = createRandomMatrix();
}
output_ = createTrainOutputMatrix();
+ quant_ = false;
auto loss = createLoss(output_);
bool normalizeGradient = (args_->model == model_name::sup);
model_ = std::make_shared<Model>(input_, output_, loss, normalizeGradient);
- startThreads();
+ startThreads(callback);
}
-void FastText::startThreads() {
+void FastText::abort() {
+ try {
+ throw AbortError();
+ } catch (AbortError&) {
+ trainException_ = std::current_exception();
+ }
+}
+
+void FastText::startThreads(const TrainCallback& callback) {
start_ = std::chrono::steady_clock::now();
tokenCount_ = 0;
loss_ = -1;
+ trainException_ = nullptr;
std::vector<std::thread> threads;
- for (int32_t i = 0; i < args_->thread; i++) {
- threads.push_back(std::thread([=]() { trainThread(i); }));
+ if (args_->thread > 1) {
+ for (int32_t i = 0; i < args_->thread; i++) {
+ threads.push_back(std::thread([=]() { trainThread(i, callback); }));
+ }
+ } else {
+ // webassembly can't instantiate `std::thread`
+ trainThread(0, callback);
}
const int64_t ntokens = dict_->ntokens();
// Same condition as trainThread
- while (tokenCount_ < args_->epoch * ntokens) {
+ while (keepTraining(ntokens)) {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
if (loss_ >= 0 && args_->verbose > 1) {
real progress = real(tokenCount_) / (args_->epoch * ntokens);
std::cerr << "\r";
printInfo(progress, loss_, std::cerr);
}
}
- for (int32_t i = 0; i < args_->thread; i++) {
+ for (int32_t i = 0; i < threads.size(); i++) {
threads[i].join();
+ }
+ if (trainException_) {
+ std::exception_ptr exception = trainException_;
+ trainException_ = nullptr;
+ std::rethrow_exception(exception);
}
if (args_->verbose > 0) {
std::cerr << "\r";
printInfo(1.0, loss_, std::cerr);
std::cerr << std::endl;
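abort reuses the exception channel for cooperative cancellation: it stores an AbortError in trainException_, keepTraining lets the workers drain out, and startThreads rethrows on the thread that called train. A sketch of cancelling from another thread, assuming AbortError is nested in FastText as the bare throw above suggests:

    #include <chrono>
    #include <thread>

    std::thread trainer([&] {
      try {
        ft.train(args);
      } catch (fasttext::FastText::AbortError&) {
        // expected when abort() is called
      }
    });
    std::this_thread::sleep_for(std::chrono::seconds(5));
    ft.abort();     // any worker notices on its next keepTraining() check
    trainer.join();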