#pragma once #include "LDAModel.hpp" #include "LDA.h" namespace tomoto { template struct DocumentDTM : public DocumentLDA<_tw> { using BaseDocument = DocumentLDA<_tw>; using DocumentLDA<_tw>::DocumentLDA; size_t timepoint = 0; ShareableVector eta; sample::AliasMethod<> aliasTable; DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 0, timepoint); DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 1, 0x00010001, timepoint); }; class IDTModel : public ILDAModel { public: using DefaultDocType = DocumentDTM; static IDTModel* create(TermWeight _weight, size_t _K = 1, size_t _T = 1, Float _alphaVar = 1.0, Float _etaVar = 1.0, Float _phiVar = 1.0, Float _shapeA = 0.03, Float _shapeB = 0.1, Float _shapeC = 0.55, Float _etaRegL2 = 0, size_t seed = std::random_device{}(), bool scalarRng = false); virtual size_t addDoc(const std::vector& words, size_t timepoint) = 0; virtual std::unique_ptr makeDoc(const std::vector& words, size_t timepoint) const = 0; virtual size_t addDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer, size_t timepoint) = 0; virtual std::unique_ptr makeDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer, size_t timepoint) const = 0; virtual size_t addDoc(const std::string& rawStr, const std::vector& words, const std::vector& pos, const std::vector& len, size_t timepoint) = 0; virtual std::unique_ptr makeDoc(const std::string& rawStr, const std::vector& words, const std::vector& pos, const std::vector& len, size_t timepoint) const = 0; virtual size_t getT() const = 0; virtual std::vector getNumDocsByT() const = 0; virtual Float getAlphaVar() const = 0; virtual Float getEtaVar() const = 0; virtual Float getPhiVar() const = 0; virtual Float getShapeA() const = 0; virtual Float getShapeB() const = 0; virtual Float getShapeC() const = 0; virtual void setShapeA(Float a) = 0; virtual void setShapeB(Float a) = 0; virtual void setShapeC(Float a) = 0; virtual Float getAlpha(size_t k, size_t t) const = 0; virtual std::vector getPhi(size_t k, size_t t) const = 0; }; }