#pragma once #include #include #include "../Utils/Utils.hpp" #include "../Utils/Dictionary.h" #include "../Utils/tvector.hpp" #include "../Utils/ThreadPool.hpp" #include "../Utils/serializer.hpp" #include "../Utils/exception.h" #include namespace tomoto { using RandGen = Eigen::Rand::P8_mt19937_64; using ScalarRandGen = Eigen::Rand::UniversalRandomEngine; class DocumentBase { public: Float weight = 1; tvector words; // word id of each word std::vector wOrder; // original word order (optional) std::string docUid; std::string rawStr; std::vector origWordPos; std::vector origWordLen; DocumentBase(Float _weight = 1) : weight(_weight) {} virtual ~DocumentBase() {} DEFINE_SERIALIZER_WITH_VERSION(0, serializer::to_key("Docu"), weight, words, wOrder); DEFINE_TAGGED_SERIALIZER_WITH_VERSION(1, 0x00010001, weight, words, wOrder, rawStr, origWordPos, origWordLen, docUid ); }; enum class ParallelScheme { default_, none, copy_merge, partition, size }; inline const char* toString(ParallelScheme ps) { switch (ps) { case ParallelScheme::default_: return "default"; case ParallelScheme::none: return "none"; case ParallelScheme::copy_merge: return "copy_merge"; case ParallelScheme::partition: return "partition"; default: return "unknown"; } } class RawDocTokenizer { public: using Token = std::tuple; using Factory = std::function; private: std::function fnNext; public: class Iterator { RawDocTokenizer* p = nullptr; bool end = true; std::tuple value; public: Iterator() { } Iterator(RawDocTokenizer* _p) : p{ _p }, end{ false } { operator++(); } std::tuple& operator*() { return value; } Iterator& operator++() { auto v = p->fnNext(); if (std::get<3>(v)) { end = true; } else { value = std::make_tuple(std::get<0>(v), std::get<1>(v), std::get<2>(v)); } return *this; } bool operator==(const Iterator& o) const { return o.end && end; } bool operator!=(const Iterator& o) const { return !operator==(o); } }; template RawDocTokenizer(_Fn&& fn) : fnNext{ std::forward<_Fn>(fn) } { } Iterator begin() { return Iterator{ this }; } Iterator end() { return Iterator{}; } }; class ITopicModel { public: virtual void saveModel(std::ostream& writer, bool fullModel, const std::vector* extra_data = nullptr) const = 0; virtual void loadModel(std::istream& reader, std::vector* extra_data = nullptr) = 0; virtual const DocumentBase* getDoc(size_t docId) const = 0; virtual void updateVocab(const std::vector& words) = 0; virtual double getLLPerWord() const = 0; virtual double getPerplexity() const = 0; virtual uint64_t getV() const = 0; virtual uint64_t getN() const = 0; virtual size_t getNumDocs() const = 0; virtual const Dictionary& getVocabDict() const = 0; virtual const std::vector& getVocabCf() const = 0; virtual const std::vector& getVocabDf() const = 0; virtual int train(size_t iteration, size_t numWorkers, ParallelScheme ps = ParallelScheme::default_) = 0; virtual size_t getGlobalStep() const = 0; virtual void prepare(bool initDocs = true, size_t minWordCnt = 0, size_t minWordDf = 0, size_t removeTopN = 0) = 0; virtual size_t getK() const = 0; virtual std::vector getWidsByTopic(size_t tid) const = 0; virtual std::vector> getWordsByTopicSorted(size_t tid, size_t topN) const = 0; virtual std::vector> getWordsByDocSorted(const DocumentBase* doc, size_t topN) const = 0; virtual std::vector getTopicsByDoc(const DocumentBase* doc) const = 0; virtual std::vector> getTopicsByDocSorted(const DocumentBase* doc, size_t topN) const = 0; virtual std::vector infer(const std::vector& docs, size_t maxIter, Float tolerance, size_t numWorkers, ParallelScheme ps, bool together) const = 0; virtual ~ITopicModel() {} }; template static std::vector> extractTopN(const std::vector<_TyValue>& vec, size_t topN) { typedef std::pair<_TyKey, _TyValue> pair_t; std::vector ret; _TyKey k = 0; for (auto& t : vec) { ret.emplace_back(std::make_pair(k++, t)); } std::sort(ret.begin(), ret.end(), [](const pair_t& a, const pair_t& b) { return a.second > b.second; }); if (topN < ret.size()) ret.erase(ret.begin() + topN, ret.end()); return ret; } namespace flags { enum { continuous_doc_data = 1 << 0, shared_state = 1 << 1, partitioned_multisampling = 1 << 2, end_flag_of_TopicModel = 1 << 3, }; } template class TopicModel : public _Interface { friend class Document; public: using DocType = _DocType; protected: _RandGen rg; std::vector<_RandGen> localRG; std::vector words; std::vector wOffsetByDoc; std::vector docs; std::vector vocabCf; std::vector vocabDf; size_t globalStep = 0; _ModelState globalState, tState; Dictionary dict; uint64_t realV = 0; // vocab size after removing stopwords uint64_t realN = 0; // total word size after removing stopwords size_t maxThreads[(size_t)ParallelScheme::size] = { 0, }; size_t minWordCf = 0, minWordDf = 0, removeTopN = 0; std::unique_ptr cachedPool; void _saveModel(std::ostream& writer, bool fullModel, const std::vector* extra_data) const { serializer::writeMany(writer, serializer::to_keyz(static_cast(this)->TMID), serializer::to_keyz(static_cast(this)->TWID)); serializer::writeTaggedMany(writer, 0x00010001, serializer::to_keyz("dict"), dict, serializer::to_keyz("vocabCf"), vocabCf, serializer::to_keyz("vocabDf"), vocabDf, serializer::to_keyz("realV"), realV, serializer::to_keyz("globalStep"), globalStep, serializer::to_keyz("extra"), extra_data ? *extra_data : std::vector(0)); serializer::writeMany(writer, *static_cast(this)); globalState.serializerWrite(writer); if (fullModel) { serializer::writeMany(writer, docs); } else { serializer::writeMany(writer, std::vector{}); } } void _loadModel(std::istream& reader, std::vector* extra_data) { auto start_pos = reader.tellg(); try { std::vector extra; serializer::readMany(reader, serializer::to_keyz(static_cast<_Derived*>(this)->TMID), serializer::to_keyz(static_cast<_Derived*>(this)->TWID)); serializer::readTaggedMany(reader, 0x00010001, serializer::to_keyz("dict"), dict, serializer::to_keyz("vocabCf"), vocabCf, serializer::to_keyz("vocabDf"), vocabDf, serializer::to_keyz("realV"), realV, serializer::to_keyz("globalStep"), globalStep, serializer::to_keyz("extra"), extra); if (extra_data) *extra_data = std::move(extra); } catch (const std::ios_base::failure&) { reader.seekg(start_pos); serializer::readMany(reader, serializer::to_key(static_cast<_Derived*>(this)->TMID), serializer::to_key(static_cast<_Derived*>(this)->TWID), dict, vocabCf, realV); } serializer::readMany(reader, *static_cast<_Derived*>(this)); globalState.serializerRead(reader); serializer::readMany(reader, docs); realN = countRealN(); } template typename std::enable_if::type>::type >::value, size_t>::type _addDoc(_DocTy&& doc) { if (doc.words.empty()) return -1; size_t maxWid = *std::max_element(doc.words.begin(), doc.words.end()); if (vocabCf.size() <= maxWid) { vocabCf.resize(maxWid + 1); vocabDf.resize(maxWid + 1); } for (auto w : doc.words) ++vocabCf[w]; std::unordered_set uniq{ doc.words.begin(), doc.words.end() }; for (auto w : uniq) ++vocabDf[w]; docs.emplace_back(std::forward<_DocTy>(doc)); return docs.size() - 1; } template DocType _makeDoc(const std::vector& words, Float weight = 1) { DocType doc{ weight }; for (auto& w : words) { Vid id; if (_const) { id = dict.toWid(w); if (id == (Vid)-1) continue; } else { id = dict.add(w); } doc.words.emplace_back(id); } return doc; } DocType _makeRawDoc(const std::string& rawStr, const std::vector& words, const std::vector& pos, const std::vector& len, Float weight = 1) const { DocType doc{ weight }; doc.rawStr = rawStr; for (auto& w : words) doc.words.emplace_back(w); doc.origWordPos = pos; doc.origWordLen = len; return doc; } template DocType _makeRawDoc(const std::string& rawStr, _FnTokenizer&& tokenizer, Float weight = 1) { DocType doc{ weight }; doc.rawStr = rawStr; for (auto& p : tokenizer(doc.rawStr)) { Vid wid; if (_const) { wid = dict.toWid(std::get<0>(p)); if (wid == (Vid)-1) continue; } else { wid = dict.add(std::get<0>(p)); } auto pos = std::get<1>(p); auto len = std::get<2>(p); doc.words.emplace_back(wid); doc.origWordPos.emplace_back(pos); doc.origWordLen.emplace_back(len); } return doc; } const DocType& _getDoc(size_t docId) const { return docs[docId]; } void updateWeakArray() { wOffsetByDoc.emplace_back(0); for (auto& doc : docs) { wOffsetByDoc.emplace_back(wOffsetByDoc.back() + doc.words.size()); } auto tx = [](_DocType& doc) { return &doc.words; }; tvector::trade(words, makeTransformIter(docs.begin(), tx), makeTransformIter(docs.end(), tx)); } size_t countRealN() const { size_t n = 0; for (auto& doc : docs) { for (auto& w : doc.words) { if (w < realV) ++n; } } return n; } void removeStopwords(size_t minWordCnt, size_t minWordDf, size_t removeTopN) { if (minWordCnt <= 1 && minWordDf <= 1 && removeTopN == 0) realV = dict.size(); this->minWordCf = minWordCnt; this->minWordDf = minWordDf; this->removeTopN = removeTopN; std::vector> vocabCfDf; for (size_t i = 0; i < vocabCf.size(); ++i) { vocabCfDf.emplace_back(vocabCf[i], vocabDf[i]); } std::vector order; sortAndWriteOrder(vocabCfDf, order, removeTopN, [&](const std::pair& a, const std::pair& b) { if (a.first < minWordCnt || a.second < minWordDf) { if (b.first < minWordCnt || b.second < minWordDf) { return a > b; } return false; } if (b.first < minWordCnt || b.second < minWordDf) { return true; } return a > b; }); realV = std::find_if(vocabCfDf.begin(), vocabCfDf.end() - std::min(removeTopN, vocabCfDf.size()), [&](const std::pair& a) { return a.first < minWordCnt || a.second < minWordDf; }) - vocabCfDf.begin(); for (size_t i = 0; i < vocabCfDf.size(); ++i) { vocabCf[i] = vocabCfDf[i].first; vocabDf[i] = vocabCfDf[i].second; } dict.reorder(order); realN = 0; for (auto& doc : docs) { for (auto& w : doc.words) { w = order[w]; if (w < realV) ++realN; } } } int restoreFromTrainingError(const exception::TrainingError& e, ThreadPool& pool, _ModelState* localData, _RandGen* rgs) { throw e; } public: TopicModel(size_t _rg) : rg(_rg) { } size_t getNumDocs() const override { return docs.size(); } uint64_t getN() const override { return realN; } uint64_t getV() const override { return realV; } void updateVocab(const std::vector& words) override { if(dict.size()) THROW_ERROR_WITH_INFO(exception::InvalidArgument, "updateVocab after addDoc"); for(auto& w : words) dict.add(w); } void prepare(bool initDocs = true, size_t minWordCnt = 0, size_t minWordDf = 0, size_t removeTopN = 0) override { maxThreads[(size_t)ParallelScheme::default_] = -1; maxThreads[(size_t)ParallelScheme::none] = -1; maxThreads[(size_t)ParallelScheme::copy_merge] = static_cast<_Derived*>(this)->template estimateMaxThreads(); maxThreads[(size_t)ParallelScheme::partition] = static_cast<_Derived*>(this)->template estimateMaxThreads(); } static ParallelScheme getRealScheme(ParallelScheme ps) { switch (ps) { case ParallelScheme::default_: if ((_Flags & flags::partitioned_multisampling)) return ParallelScheme::partition; if ((_Flags & flags::shared_state)) return ParallelScheme::none; return ParallelScheme::copy_merge; case ParallelScheme::copy_merge: if ((_Flags & flags::shared_state)) THROW_ERROR_WITH_INFO(exception::InvalidArgument, std::string{ "This model doesn't provide ParallelScheme::" } + toString(ps)); break; case ParallelScheme::partition: if (!(_Flags & flags::partitioned_multisampling)) THROW_ERROR_WITH_INFO(exception::InvalidArgument, std::string{ "This model doesn't provide ParallelScheme::" } + toString(ps)); break; } return ps; } int train(size_t iteration, size_t numWorkers, ParallelScheme ps) override { if (!numWorkers) numWorkers = std::thread::hardware_concurrency(); ps = getRealScheme(ps); numWorkers = std::min(numWorkers, maxThreads[(size_t)ps]); if (numWorkers == 1 || (_Flags & flags::shared_state)) ps = ParallelScheme::none; if (!cachedPool || cachedPool->getNumWorkers() != numWorkers) { cachedPool = make_unique(numWorkers); } std::vector<_ModelState> localData; while(localRG.size() < numWorkers) { localRG.emplace_back(rg()); } for (size_t i = 0; i < numWorkers; ++i) { if(ps == ParallelScheme::copy_merge) localData.emplace_back(static_cast<_Derived*>(this)->globalState); } if (ps == ParallelScheme::partition) { localData.resize(numWorkers); static_cast<_Derived*>(this)->updatePartition(*cachedPool, globalState, localData.data(), docs.begin(), docs.end(), static_cast<_Derived*>(this)->eddTrain); } auto state = ps == ParallelScheme::none ? &globalState : localData.data(); for (size_t i = 0; i < iteration; ++i) { while (1) { try { switch (ps) { case ParallelScheme::none: static_cast<_Derived*>(this)->template trainOne( *cachedPool, state, localRG.data()); break; case ParallelScheme::copy_merge: static_cast<_Derived*>(this)->template trainOne( *cachedPool, state, localRG.data()); break; case ParallelScheme::partition: static_cast<_Derived*>(this)->template trainOne( *cachedPool, state, localRG.data()); break; } break; } catch (const exception::TrainingError& e) { std::cerr << e.what() << std::endl; int ret = static_cast<_Derived*>(this)->restoreFromTrainingError( e, *cachedPool, state, localRG.data()); if(ret < 0) return ret; } } ++globalStep; } return 0; } double getLLPerWord() const override { return words.empty() ? 0 : static_cast(this)->getLL() / realN; } double getPerplexity() const override { return exp(-getLLPerWord()); } size_t getK() const override { return 0; } std::vector getWidsByTopic(size_t tid) const override { return static_cast(this)->_getWidsByTopic(tid); } std::vector> getWidsByTopicSorted(size_t tid, size_t topN) const { return extractTopN(static_cast(this)->_getWidsByTopic(tid), topN); } std::vector> vid2String(const std::vector>& vids) const { std::vector> ret(vids.size()); for (size_t i = 0; i < vids.size(); ++i) { ret[i] = std::make_pair(dict.toWord(vids[i].first), vids[i].second); } return ret; } std::vector> getWordsByTopicSorted(size_t tid, size_t topN) const override { return vid2String(getWidsByTopicSorted(tid, topN)); } std::vector> getWidsByDocSorted(const DocumentBase* doc, size_t topN) const { std::vector cnt(dict.size()); for (auto w : doc->words) cnt[w] += 1; for (auto& c : cnt) c /= doc->words.size(); return extractTopN(cnt, topN); } std::vector> getWordsByDocSorted(const DocumentBase* doc, size_t topN) const override { return vid2String(getWidsByDocSorted(doc, topN)); } std::vector infer(const std::vector& docs, size_t maxIter, Float tolerance, size_t numWorkers, ParallelScheme ps, bool together) const override { if (!numWorkers) numWorkers = std::thread::hardware_concurrency(); ps = getRealScheme(ps); if (numWorkers == 1) ps = ParallelScheme::none; auto tx = [](DocumentBase* p)->DocType& { return *static_cast(p); }; auto b = makeTransformIter(docs.begin(), tx), e = makeTransformIter(docs.end(), tx); if (together) { switch (ps) { case ParallelScheme::none: return static_cast(this)->template _infer(b, e, maxIter, tolerance, numWorkers); case ParallelScheme::copy_merge: return static_cast(this)->template _infer(b, e, maxIter, tolerance, numWorkers); case ParallelScheme::partition: return static_cast(this)->template _infer(b, e, maxIter, tolerance, numWorkers); } } else { switch (ps) { case ParallelScheme::none: return static_cast(this)->template _infer(b, e, maxIter, tolerance, numWorkers); case ParallelScheme::copy_merge: return static_cast(this)->template _infer(b, e, maxIter, tolerance, numWorkers); case ParallelScheme::partition: return static_cast(this)->template _infer(b, e, maxIter, tolerance, numWorkers); } } THROW_ERROR_WITH_INFO(exception::InvalidArgument, "invalid ParallelScheme"); } std::vector getTopicsByDoc(const DocumentBase* doc) const override { return static_cast(this)->getTopicsByDoc(*static_cast(doc)); } std::vector> getTopicsByDocSorted(const DocumentBase* doc, size_t topN) const override { return extractTopN(getTopicsByDoc(doc), topN); } const DocumentBase* getDoc(size_t docId) const override { return &_getDoc(docId); } size_t getGlobalStep() const override { return globalStep; } const Dictionary& getVocabDict() const override { return dict; } const std::vector& getVocabCf() const override { return vocabCf; } const std::vector& getVocabDf() const override { return vocabDf; } void saveModel(std::ostream& writer, bool fullModel, const std::vector* extra_data) const override { static_cast(this)->_saveModel(writer, fullModel, extra_data); } void loadModel(std::istream& reader, std::vector* extra_data) override { static_cast<_Derived*>(this)->_loadModel(reader, extra_data); static_cast<_Derived*>(this)->prepare(false); } }; }