#pragma once #include "LDA.h" namespace tomoto { template struct DocumentMGLDA : public DocumentLDA<_tw> { using BaseDocument = DocumentLDA<_tw>; using DocumentLDA<_tw>::DocumentLDA; using WeightType = typename DocumentLDA<_tw>::WeightType; std::vector sents; // sentence id of each word (const) std::vector numBySent; // number of words in the sentence (const) //std::vector Zs; // gl./loc. and topic assignment std::vector Vs; // window assignment WeightType numGl = 0; // number of words assigned as gl. //std::vector numByTopic; // len = K + KL Eigen::Matrix numBySentWin; // len = S * T Eigen::Matrix numByWinL; // number of words assigned as loc. in the window (len = S + T - 1) Eigen::Matrix numByWin; // number of words in the window (len = S + T - 1) Eigen::Matrix numByWinTopicL; // number of words in the loc. topic in the window (len = KL * (S + T - 1)) DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 0, sents, Vs, numGl, numBySentWin, numByWinL, numByWin, numByWinTopicL); DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 1, 0x00010001, sents, Vs, numGl, numBySentWin, numByWinL, numByWin, numByWinTopicL); template void update(WeightType* ptr, const _TopicModel& mdl); }; class IMGLDAModel : public ILDAModel { public: using DefaultDocType = DocumentMGLDA; static IMGLDAModel* create(TermWeight _weight, size_t _KG = 1, size_t _KL = 1, size_t _T = 3, Float _alphaG = 0.1, Float _alphaL = 0.1, Float _alphaMG = 0.1, Float _alphaML = 0.1, Float _etaG = 0.01, Float _etaL = 0.01, Float _gamma = 0.1, size_t seed = std::random_device{}(), bool scalarRng = false); virtual size_t addDoc(const std::vector& words, const std::string& delimiter) = 0; virtual std::unique_ptr makeDoc(const std::vector& words, const std::string& delimiter) const = 0; virtual size_t addDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer, const std::string& delimiter) = 0; virtual std::unique_ptr makeDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer, const std::string& delimiter) const = 0; virtual size_t addDoc(const std::string& rawStr, const std::vector& words, const std::vector& pos, const std::vector& len, const std::string& delimiter) = 0; virtual std::unique_ptr makeDoc(const std::string& rawStr, const std::vector& words, const std::vector& pos, const std::vector& len, const std::string& delimiter) const = 0; virtual size_t getKL() const = 0; virtual size_t getT() const = 0; virtual Float getAlphaL() const = 0; virtual Float getEtaL() const = 0; virtual Float getGamma() const = 0; virtual Float getAlphaM() const = 0; virtual Float getAlphaML() const = 0; }; }