vendor/tomotopy/src/TopicModel/MGLDAModel.hpp in tomoto-0.1.4 vs vendor/tomotopy/src/TopicModel/MGLDAModel.hpp in tomoto-0.2.0

- old
+ new

@@ -287,11 +287,11 @@ doc.sents[doc.wOrder[i]] = tmp[i]; } const size_t S = doc.numBySent.size(); std::fill(doc.numBySent.begin(), doc.numBySent.end(), 0); - doc.Zs = tvector<Tid>(wordSize); + doc.Zs = tvector<Tid>(wordSize, non_topic_id); doc.Vs.resize(wordSize); if (_tw != TermWeight::one) doc.wordWeights.resize(wordSize); doc.numByTopic.init(nullptr, this->K + KL, 1); doc.numBySentWin = Eigen::Matrix<WeightType, -1, -1>::Zero(S, T); doc.numByWin = Eigen::Matrix<WeightType, -1, 1>::Zero(S + T - 1); @@ -300,11 +300,11 @@ } void initGlobalState(bool initDocs) { const size_t V = this->realV; - this->globalState.zLikelihood = Eigen::Matrix<Float, -1, 1>::Zero(T * (this->K + KL)); + this->globalState.zLikelihood = Vector::Zero(T * (this->K + KL)); if (initDocs) { this->globalState.numByTopic = Eigen::Matrix<WeightType, -1, 1>::Zero(this->K + KL); //this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K + KL, V); this->globalState.numByTopicWord.init(nullptr, this->K + KL, V); @@ -369,21 +369,37 @@ public: DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, alphaL, alphaM, alphaML, etaL, gamma, KL, T); DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, alphaL, alphaM, alphaML, etaL, gamma, KL, T); - MGLDAModel(size_t _KG = 1, size_t _KL = 1, size_t _T = 3, - Float _alphaG = 0.1, Float _alphaL = 0.1, Float _alphaMG = 0.1, Float _alphaML = 0.1, - Float _etaG = 0.01, Float _etaL = 0.01, Float _gamma = 0.1, size_t _rg = std::random_device{}()) - : BaseClass(_KG, _alphaG, _etaG, _rg), KL(_KL), T(_T), - alphaL(_alphaL), alphaM(_KG ? _alphaMG : 0), alphaML(_alphaML), - etaL(_etaL), gamma(_gamma) + MGLDAModel(const MGLDAArgs& args) + : BaseClass(args), KL(args.kL), T(args.t), + alphaL(args.alphaL[0]), alphaM(args.k ? args.alphaMG : 0), alphaML(args.alphaML), + etaL(args.etaL), gamma(args.gamma) { - if (_KL == 0 || _KL >= 0x80000000) THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong KL value (KL = %zd)", _KL)); - if (_T == 0 || _T >= 0x80000000) THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong T value (T = %zd)", _T)); - if (_alphaL <= 0) THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong alphaL value (alphaL = %f)", _alphaL)); - if (_etaL <= 0) THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong etaL value (etaL = %f)", _etaL)); + if (KL == 0 || KL >= 0x80000000) THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong KL value (KL = %zd)", KL)); + if (T == 0 || T >= 0x80000000) THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong T value (T = %zd)", T)); + + if (args.alpha.size() != 1) + { + THROW_ERROR_WITH_INFO(exc::Unimplemented, "An asymmetric alpha prior is not supported yet at MGLDA."); + } + + if (args.alphaL.size() == 1) + { + } + else if (args.alphaL.size() == args.kL) + { + THROW_ERROR_WITH_INFO(exc::Unimplemented, "An asymmetric alphaL prior is not supported yet at MGLDA."); + } + else + { + THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong alphaL value (len = %zd)", args.alphaL.size())); + } + + if (alphaL <= 0) THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong alphaL value (alphaL = %f)", alphaL)); + if (etaL <= 0) THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong etaL value (etaL = %f)", etaL)); } template<bool _const, typename _FnTokenizer> _DocType _makeFromRawDoc(const RawDoc& rawDoc, _FnTokenizer&& tokenizer, const std::string& delimiter) { @@ -424,11 +440,11 @@ return this->_addDoc(_makeFromRawDoc<false>(rawDoc, tokenizer, rawDoc.template getMisc<std::string>("delimiter"))); } std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) const { - return make_unique<_DocType>(as_mutable(this)->template _makeFromRawDoc<true>(rawDoc, tokenizer, rawDoc.template getMisc<std::string>("delimiter"))); + return std::make_unique<_DocType>(as_mutable(this)->template _makeFromRawDoc<true>(rawDoc, tokenizer, rawDoc.template getMisc<std::string>("delimiter"))); } template<bool _const = false> _DocType _makeFromRawDoc(const RawDoc& rawDoc) { @@ -495,28 +511,35 @@ return this->_addDoc(_makeFromRawDoc(rawDoc)); } std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc) const { - return make_unique<_DocType>(as_mutable(this)->template _makeFromRawDoc<true>(rawDoc)); + return std::make_unique<_DocType>(as_mutable(this)->template _makeFromRawDoc<true>(rawDoc)); } void setWordPrior(const std::string& word, const std::vector<Float>& priors) override { - if (priors.size() != this->K + KL) THROW_ERROR_WITH_INFO(exception::InvalidArgument, "priors.size() must be equal to K."); + if (priors.size() != this->K + KL) THROW_ERROR_WITH_INFO(exc::InvalidArgument, "priors.size() must be equal to K."); for (auto p : priors) { - if (p < 0) THROW_ERROR_WITH_INFO(exception::InvalidArgument, "priors must not be less than 0."); + if (p < 0) THROW_ERROR_WITH_INFO(exc::InvalidArgument, "priors must not be less than 0."); } this->dict.add(word); this->etaByWord.emplace(word, priors); } - std::vector<Float> getTopicsByDoc(const _DocType& doc) const + std::vector<Float> getTopicsByDoc(const _DocType& doc, bool normalize) const { std::vector<Float> ret(this->K + KL); - Eigen::Map<Eigen::Matrix<Float, -1, 1>> { ret.data(), this->K + KL }.array() = - doc.numByTopic.array().template cast<Float>() / doc.getSumWordWeight(); + Eigen::Map<Eigen::Array<Float, -1, 1>> m{ ret.data(), this->K + KL }; + if (normalize) + { + m = doc.numByTopic.array().template cast<Float>() / doc.getSumWordWeight(); + } + else + { + m = doc.numByTopic.array().template cast<Float>(); + } return ret; } GETTER(KL, size_t, KL); GETTER(T, size_t, T);