vendor/tomotopy/src/TopicModel/SLDA.h in tomoto-0.1.4 vs vendor/tomotopy/src/TopicModel/SLDA.h in tomoto-0.2.0

- old
+ new

@@ -7,33 +7,45 @@ struct DocumentSLDA : public DocumentLDA<_tw> { using BaseDocument = DocumentLDA<_tw>; using DocumentLDA<_tw>::DocumentLDA; std::vector<Float> y; + + RawDoc::MiscType makeMisc(const ITopicModel* tm) const override + { + RawDoc::MiscType ret = DocumentLDA<_tw>::makeMisc(tm); + ret["y"] = y; + return ret; + } DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 0, y); DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 1, 0x00010001, y); }; + struct SLDAArgs; + class ISLDAModel : public ILDAModel { public: enum class GLM { linear = 0, binary_logistic = 1, }; using DefaultDocType = DocumentSLDA<TermWeight::one>; - static ISLDAModel* create(TermWeight _weight, size_t _K = 1, - const std::vector<ISLDAModel::GLM>& vars = {}, - Float alpha = 0.1, Float _eta = 0.01, - const std::vector<Float>& _mu = {}, const std::vector<Float>& _nuSq = {}, - const std::vector<Float>& _glmParam = {}, - size_t seed = std::random_device{}(), + static ISLDAModel* create(TermWeight _weight, const SLDAArgs& args, bool scalarRng = false); virtual size_t getF() const = 0; virtual std::vector<Float> getRegressionCoef(size_t f) const = 0; virtual GLM getTypeOfVar(size_t f) const = 0; virtual std::vector<Float> estimateVars(const DocumentBase* doc) const = 0; + }; + + struct SLDAArgs : public LDAArgs + { + std::vector<ISLDAModel::GLM> vars; + std::vector<Float> mu; + std::vector<Float> nuSq; + std::vector<Float> glmParam; }; } \ No newline at end of file