vendor/fastText/src/fasttext.h in fasttext-0.1.2 vs vendor/fastText/src/fasttext.h in fasttext-0.1.3
- old
+ new
@@ -10,10 +10,11 @@
#include <time.h>
#include <atomic>
#include <chrono>
+#include <functional>
#include <iostream>
#include <memory>
#include <queue>
#include <set>
#include <tuple>
@@ -29,28 +30,33 @@
#include "vector.h"
namespace fasttext {
class FastText {
+ public:
+ using TrainCallback =  // Callback type taken by train(), quantize(), startThreads() and trainThread().
+ std::function<void(float, float, double, double, int64_t)>;  // NOTE(review): the meanings of the five arguments are not visible in this header -- confirm in fasttext.cc.
+
protected:
std::shared_ptr<Args> args_;
std::shared_ptr<Dictionary> dict_;
-
std::shared_ptr<Matrix> input_;
std::shared_ptr<Matrix> output_;
-
std::shared_ptr<Model> model_;
-
std::atomic<int64_t> tokenCount_{};
std::atomic<real> loss_{};
-
std::chrono::steady_clock::time_point start_;
+ bool quant_;  // Moved up from further down the class. NOTE(review): no in-class initializer; presumably the flag isQuant() reports -- confirm.
+ int32_t version;  // Moved up as well. NOTE(review): lacks the trailing underscore used by the other data members, and has no in-class initializer.
+ std::unique_ptr<DenseMatrix> wordVectors_;  // NOTE(review): presumably a cache filled via precomputeWordVectors() -- confirm.
+ std::exception_ptr trainException_;  // Newly added member. NOTE(review): presumably carries an exception out of a training thread -- confirm.
+
void signModel(std::ostream&);
bool checkModel(std::istream&);
- void startThreads();
+ void startThreads(const TrainCallback& callback = {});  // The default argument is an empty std::function.
void addInputVector(Vector&, int32_t) const;
- void trainThread(int32_t);
+ void trainThread(int32_t, const TrainCallback& callback);  // No default here: the callback must always be passed explicitly.
std::vector<std::pair<real, std::string>> getNN(
const DenseMatrix& wordVectors,
const Vector& queryVec,
int32_t k,
const std::set<std::string>& banSet);
@@ -66,22 +72,25 @@
real lr,
const std::vector<int32_t>& line,
const std::vector<int32_t>& labels);
void cbow(Model::State& state, real lr, const std::vector<int32_t>& line);
void skipgram(Model::State& state, real lr, const std::vector<int32_t>& line);
+ std::vector<int32_t> selectEmbeddings(int32_t cutoff) const;  // Previously declared as a FASTTEXT_DEPRECATED public method.
+ void precomputeWordVectors(DenseMatrix& wordVectors);  // Previously declared as a FASTTEXT_DEPRECATED public method.
+ bool keepTraining(const int64_t ntokens) const;  // NOTE(review): presumably the training-loop continuation test -- confirm in fasttext.cc.
+ void buildModel();  // NOTE(review): what is built, and from which members, is not visible in this header.
+ std::tuple<int64_t, double, double> progressInfo(real progress);  // NOTE(review): meanings of the tuple elements are not visible here -- confirm.
- bool quant_;
- int32_t version;
- std::unique_ptr<DenseMatrix> wordVectors_;
-
public:
FastText();
int32_t getWordId(const std::string& word) const;
int32_t getSubwordId(const std::string& subword) const;
+ int32_t getLabelId(const std::string& label) const;
+
void getWordVector(Vector& vec, const std::string& word) const;
void getSubwordVector(Vector& vec, const std::string& subword) const;
inline void getInputVector(Vector& vec, int32_t ind) {
@@ -93,10 +102,14 @@
std::shared_ptr<const Dictionary> getDictionary() const;
std::shared_ptr<const DenseMatrix> getInputMatrix() const;
+ void setMatrices(  // NOTE(review): presumably installs these as the model's input/output matrices -- confirm in fasttext.cc.
+ const std::shared_ptr<DenseMatrix>& inputMatrix,
+ const std::shared_ptr<DenseMatrix>& outputMatrix);
+
std::shared_ptr<const DenseMatrix> getOutputMatrix() const;
void saveVectors(const std::string& filename);
void saveModel(const std::string& filename);
@@ -107,11 +120,11 @@
void loadModel(const std::string& filename);
void getSentenceVector(std::istream& in, Vector& vec);
- void quantize(const Args& qargs);
+ void quantize(const Args& qargs, const TrainCallback& callback = {});  // The defaulted (empty) callback keeps existing quantize(qargs) calls compiling.
std::tuple<int64_t, double, double>
test(std::istream& in, int32_t k, real threshold = 0.0);
void test(std::istream& in, int32_t k, real threshold, Meter& meter) const;
@@ -139,53 +152,19 @@
int32_t k,
const std::string& wordA,
const std::string& wordB,
const std::string& wordC);
- void train(const Args& args);
+ void train(const Args& args, const TrainCallback& callback = {});  // The defaulted (empty) callback keeps existing train(args) calls compiling.
+ void abort();  // NOTE(review): presumably requests that a running training stop (see AbortError) -- confirm in fasttext.cc.
+
int getDimension() const;
bool isQuant() const;
- FASTTEXT_DEPRECATED("loadVectors is being deprecated.")
- void loadVectors(const std::string& filename);
-
- FASTTEXT_DEPRECATED(
- "getVector is being deprecated and replaced by getWordVector.")
- void getVector(Vector& vec, const std::string& word) const;
-
- FASTTEXT_DEPRECATED(
- "ngramVectors is being deprecated and replaced by getNgramVectors.")
- void ngramVectors(std::string word);
-
- FASTTEXT_DEPRECATED(
- "analogies is being deprecated and replaced by getAnalogies.")
- void analogies(int32_t k);
-
- FASTTEXT_DEPRECATED("selectEmbeddings is being deprecated.")
- std::vector<int32_t> selectEmbeddings(int32_t cutoff) const;
-
- FASTTEXT_DEPRECATED(
- "saveVectors is being deprecated, please use the other signature.")
- void saveVectors();
-
- FASTTEXT_DEPRECATED(
- "saveOutput is being deprecated, please use the other signature.")
- void saveOutput();
-
- FASTTEXT_DEPRECATED(
- "saveModel is being deprecated, please use the other signature.")
- void saveModel();
-
- FASTTEXT_DEPRECATED("precomputeWordVectors is being deprecated.")
- void precomputeWordVectors(DenseMatrix& wordVectors);
-
- FASTTEXT_DEPRECATED("findNN is being deprecated and replaced by getNN.")
- void findNN(
- const DenseMatrix& wordVectors,
- const Vector& query,
- int32_t k,
- const std::set<std::string>& banSet,
- std::vector<std::pair<real, std::string>>& results);
+ class AbortError : public std::runtime_error {  // Exception type nested in FastText. NOTE(review): std::runtime_error needs <stdexcept>, which is not among the includes visible here.
+ public:
+ AbortError() : std::runtime_error("Aborted.") {}  // what() always yields the fixed message "Aborted."
+ };  // NOTE(review): presumably thrown when training is interrupted through abort() -- confirm in fasttext.cc.
};
} // namespace fasttext