vendor/fastText/src/fasttext.h in fasttext-0.1.2 vs vendor/fastText/src/fasttext.h in fasttext-0.1.3

- old
+ new

@@ -10,10 +10,11 @@ #include <time.h> #include <atomic> #include <chrono> +#include <functional> #include <iostream> #include <memory> #include <queue> #include <set> #include <tuple> @@ -29,28 +30,33 @@ #include "vector.h" namespace fasttext { class FastText { + public: + using TrainCallback = + std::function<void(float, float, double, double, int64_t)>; + protected: std::shared_ptr<Args> args_; std::shared_ptr<Dictionary> dict_; - std::shared_ptr<Matrix> input_; std::shared_ptr<Matrix> output_; - std::shared_ptr<Model> model_; - std::atomic<int64_t> tokenCount_{}; std::atomic<real> loss_{}; - std::chrono::steady_clock::time_point start_; + bool quant_; + int32_t version; + std::unique_ptr<DenseMatrix> wordVectors_; + std::exception_ptr trainException_; + void signModel(std::ostream&); bool checkModel(std::istream&); - void startThreads(); + void startThreads(const TrainCallback& callback = {}); void addInputVector(Vector&, int32_t) const; - void trainThread(int32_t); + void trainThread(int32_t, const TrainCallback& callback); std::vector<std::pair<real, std::string>> getNN( const DenseMatrix& wordVectors, const Vector& queryVec, int32_t k, const std::set<std::string>& banSet); @@ -66,22 +72,25 @@ real lr, const std::vector<int32_t>& line, const std::vector<int32_t>& labels); void cbow(Model::State& state, real lr, const std::vector<int32_t>& line); void skipgram(Model::State& state, real lr, const std::vector<int32_t>& line); + std::vector<int32_t> selectEmbeddings(int32_t cutoff) const; + void precomputeWordVectors(DenseMatrix& wordVectors); + bool keepTraining(const int64_t ntokens) const; + void buildModel(); + std::tuple<int64_t, double, double> progressInfo(real progress); - bool quant_; - int32_t version; - std::unique_ptr<DenseMatrix> wordVectors_; - public: FastText(); int32_t getWordId(const std::string& word) const; int32_t getSubwordId(const std::string& subword) const; + int32_t getLabelId(const std::string& label) const; + void getWordVector(Vector& vec, const std::string& word) const; void getSubwordVector(Vector& vec, const std::string& subword) const; inline void getInputVector(Vector& vec, int32_t ind) { @@ -93,10 +102,14 @@ std::shared_ptr<const Dictionary> getDictionary() const; std::shared_ptr<const DenseMatrix> getInputMatrix() const; + void setMatrices( + const std::shared_ptr<DenseMatrix>& inputMatrix, + const std::shared_ptr<DenseMatrix>& outputMatrix); + std::shared_ptr<const DenseMatrix> getOutputMatrix() const; void saveVectors(const std::string& filename); void saveModel(const std::string& filename); @@ -107,11 +120,11 @@ void loadModel(const std::string& filename); void getSentenceVector(std::istream& in, Vector& vec); - void quantize(const Args& qargs); + void quantize(const Args& qargs, const TrainCallback& callback = {}); std::tuple<int64_t, double, double> test(std::istream& in, int32_t k, real threshold = 0.0); void test(std::istream& in, int32_t k, real threshold, Meter& meter) const; @@ -139,53 +152,19 @@ int32_t k, const std::string& wordA, const std::string& wordB, const std::string& wordC); - void train(const Args& args); + void train(const Args& args, const TrainCallback& callback = {}); + void abort(); + int getDimension() const; bool isQuant() const; - FASTTEXT_DEPRECATED("loadVectors is being deprecated.") - void loadVectors(const std::string& filename); - - FASTTEXT_DEPRECATED( - "getVector is being deprecated and replaced by getWordVector.") - void getVector(Vector& vec, const std::string& word) const; - - FASTTEXT_DEPRECATED( - "ngramVectors is being deprecated and replaced by getNgramVectors.") - void ngramVectors(std::string word); - - FASTTEXT_DEPRECATED( - "analogies is being deprecated and replaced by getAnalogies.") - void analogies(int32_t k); - - FASTTEXT_DEPRECATED("selectEmbeddings is being deprecated.") - std::vector<int32_t> selectEmbeddings(int32_t cutoff) const; - - FASTTEXT_DEPRECATED( - "saveVectors is being deprecated, please use the other signature.") - void saveVectors(); - - FASTTEXT_DEPRECATED( - "saveOutput is being deprecated, please use the other signature.") - void saveOutput(); - - FASTTEXT_DEPRECATED( - "saveModel is being deprecated, please use the other signature.") - void saveModel(); - - FASTTEXT_DEPRECATED("precomputeWordVectors is being deprecated.") - void precomputeWordVectors(DenseMatrix& wordVectors); - - FASTTEXT_DEPRECATED("findNN is being deprecated and replaced by getNN.") - void findNN( - const DenseMatrix& wordVectors, - const Vector& query, - int32_t k, - const std::set<std::string>& banSet, - std::vector<std::pair<real, std::string>>& results); + class AbortError : public std::runtime_error { + public: + AbortError() : std::runtime_error("Aborted.") {} + }; }; } // namespace fasttext