vendor/tomotopy/src/TopicModel/DMRModel.hpp in tomoto-0.1.4 vs vendor/tomotopy/src/TopicModel/DMRModel.hpp in tomoto-0.2.0

- old
+ new

@@ -11,12 +11,26 @@ namespace tomoto
 {
     template<TermWeight _tw>
     struct ModelStateDMR : public ModelStateLDA<_tw>
     {
-        Eigen::Matrix<Float, -1, 1> tmpK;
+        Vector tmpK;
     };
+
+    struct MdHash
+    {
+        size_t operator()(std::pair<uint64_t, Vector> const& p) const
+        {
+            size_t seed = p.first;
+            for (size_t i = 0; i < p.second.size(); ++i)
+            {
+                auto elem = p.second[i];
+                seed ^= std::hash<decltype(elem)>()(elem) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+            }
+            return seed;
+        }
+    };

     template<TermWeight _tw, typename _RandGen,
         size_t _Flags = flags::partitioned_multisampling,
         typename _Interface = IDMRModel,
         typename _Derived = void,

@@ -33,153 +47,170 @@
         friend typename BaseClass::BaseClass;
         using WeightType = typename BaseClass::WeightType;

         static constexpr char TMID[] = "DMR\0";

-        Eigen::Matrix<Float, -1, -1> lambda;
-        Eigen::Matrix<Float, -1, -1> expLambda;
+        Matrix lambda;
+        mutable std::unordered_map<std::pair<uint64_t, Vector>, size_t, MdHash> mdHashMap;
+        mutable Matrix cachedAlphas;
         Float sigma;
-        uint32_t F = 0;
+        uint32_t F = 0, mdVecSize = 1;
         uint32_t optimRepeat = 5;
         Float alphaEps = 1e-10;
-        Float temperatureScale = 0;
         static constexpr Float maxLambda = 10;
         static constexpr size_t maxBFGSIteration = 10;

         Dictionary metadataDict;
+        Dictionary multiMetadataDict;
         LBFGSpp::LBFGSSolver<Float, LBFGSpp::LineSearchBracketing> solver;

-        Float getNegativeLambdaLL(Eigen::Ref<Eigen::Matrix<Float, -1, 1>> x, Eigen::Matrix<Float, -1, 1>& g) const
+        Float getNegativeLambdaLL(Eigen::Ref<Vector> x, Vector& g) const
         {
             g = (x.array() - log(this->alpha)) / pow(sigma, 2);
             return (x.array() - log(this->alpha)).pow(2).sum() / 2 / pow(sigma, 2);
         }

-        Float evaluateLambdaObj(Eigen::Ref<Eigen::Matrix<Float, -1, 1>> x, Eigen::Matrix<Float, -1, 1>& g, ThreadPool& pool, _ModelState* localData) const
+        Float evaluateLambdaObj(Eigen::Ref<Vector> x, Vector& g, ThreadPool& pool, _ModelState* localData) const
         {
             // if one of x is greater than maxLambda, return +inf for preventing searching more
             if ((x.array() > maxLambda).any()) return INFINITY;

             const auto K = this->K;
-            Float fx = - static_cast<const DerivedClass*>(this)->getNegativeLambdaLL(x, g);
-            auto alphas = (x.array().exp() + alphaEps).eval();
+            Float fx = -static_cast<const DerivedClass*>(this)->getNegativeLambdaLL(x, g);
+            Eigen::Map<Matrix> xReshaped{ x.data(), (Eigen::Index)K, (Eigen::Index)(F * mdVecSize) };

-            std::vector<std::future<Eigen::Matrix<Float, -1, 1>>> res;
+            std::vector<std::future<Eigen::Array<Float, -1, 1>>> res;
             const size_t chStride = pool.getNumWorkers() * 8;
             for (size_t ch = 0; ch < chStride; ++ch)
             {
                 res.emplace_back(pool.enqueue([&](size_t threadId)
                 {
                     auto& tmpK = localData[threadId].tmpK;
                     if (!tmpK.size()) tmpK.resize(this->K);
-                    Eigen::Matrix<Float, -1, 1> val = Eigen::Matrix<Float, -1, 1>::Zero(K * F + 1);
+                    Eigen::Array<Float, -1, 1> val = Eigen::Array<Float, -1, 1>::Zero(K * F * mdVecSize + 1);
+                    Eigen::Map<Matrix> grad{ val.data(), (Eigen::Index)K, (Eigen::Index)(F * mdVecSize) };
+                    Float& fx = val[K * F * mdVecSize];
                     for (size_t docId = ch; docId < this->docs.size(); docId += chStride)
                     {
                         const auto& doc = this->docs[docId];
-                        auto alphaDoc = alphas.segment(doc.metadata * K, K);
+                        auto alphaDoc = ((xReshaped.middleCols(doc.metadata * mdVecSize, mdVecSize) * doc.mdVec).array().exp() + alphaEps).matrix().eval();
                         Float alphaSum = alphaDoc.sum();
                         for (Tid k = 0; k < K; ++k)
                         {
-                            val[K * F] -= math::lgammaT(alphaDoc[k]) - math::lgammaT(doc.numByTopic[k] + alphaDoc[k]);
+                            fx -= math::lgammaT(alphaDoc[k]) - math::lgammaT(doc.numByTopic[k] + alphaDoc[k]);
                             if (!std::isfinite(alphaDoc[k]) && alphaDoc[k] > 0) tmpK[k] = 0;
                             else tmpK[k] = -(math::digammaT(alphaDoc[k]) - math::digammaT(doc.numByTopic[k] + alphaDoc[k]));
                         }
-                        //val[K * F] = -(lgammaApprox(alphaDoc.array()) - lgammaApprox(doc.numByTopic.array().cast<Float>() + alphaDoc.array())).sum();
-                        //tmpK = -(digammaApprox(alphaDoc.array()) - digammaApprox(doc.numByTopic.array().cast<Float>() + alphaDoc.array()));
-                        val[K * F] += math::lgammaT(alphaSum) - math::lgammaT(doc.getSumWordWeight() + alphaSum);
+                        fx += math::lgammaT(alphaSum) - math::lgammaT(doc.getSumWordWeight() + alphaSum);
                         Float t = math::digammaT(alphaSum) - math::digammaT(doc.getSumWordWeight() + alphaSum);
                         if (!std::isfinite(alphaSum) && alphaSum > 0)
                         {
-                            val[K * F] = -INFINITY;
+                            fx = -INFINITY;
                             t = 0;
                         }
-                        val.segment(doc.metadata * K, K).array() -= alphaDoc.array() * (tmpK.array() + t);
+                        grad.middleCols(doc.metadata * mdVecSize, mdVecSize) -= (alphaDoc.array() * (tmpK.array() + t)).matrix() * doc.mdVec.transpose();
                     }
                     return val;
                 }));
             }
             for (auto& r : res)
             {
                 auto ret = r.get();
-                fx += ret[K * F];
-                g += ret.head(K * F);
+                fx += ret[K * F * mdVecSize];
+                g += ret.head(K * F * mdVecSize).matrix();
             }

             // positive fx is an error from limited precision of float.
             if (fx > 0) return INFINITY;
             return -fx;
         }

         void initParameters()
         {
-            auto dist = std::normal_distribution<Float>(log(this->alpha), sigma);
-            for (size_t i = 0; i < this->K; ++i) for (size_t j = 0; j < F; ++j)
+            lambda = Eigen::Rand::normalLike(lambda, this->rg, 0, sigma);
+            for (size_t f = 0; f < F; ++f)
             {
-                lambda(i, j) = dist(this->rg);
+                lambda.col(f * mdVecSize) += this->alphas.array().log().matrix();
             }
         }

         void optimizeParameters(ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
         {
-            Eigen::Matrix<Float, -1, -1> bLambda;
+            Matrix bLambda;
             Float fx = 0, bestFx = INFINITY;
             for (size_t i = 0; i < optimRepeat; ++i)
             {
                 static_cast<DerivedClass*>(this)->initParameters();
-                int ret = solver.minimize([this, &pool, localData](Eigen::Ref<Eigen::Matrix<Float, -1, 1>> x, Eigen::Matrix<Float, -1, 1>& g)
+                int ret = solver.minimize([this, &pool, localData](Eigen::Ref<Vector> x, Vector& g)
                 {
                     return static_cast<DerivedClass*>(this)->evaluateLambdaObj(x, g, pool, localData);
-                }, Eigen::Map<Eigen::Matrix<Float, -1, 1>>(lambda.data(), lambda.size()), fx);
+                }, Eigen::Map<Vector>(lambda.data(), lambda.size()), fx);

                 if (fx < bestFx)
                 {
                     bLambda = lambda;
                     bestFx = fx;
                     //printf("\t(%d) %e\n", ret, fx);
                 }
             }
             if (!std::isfinite(bestFx))
             {
-                throw exception::TrainingError{ "optimizing parameters has been failed!" };
+                throw exc::TrainingError{ "optimizing parameters has been failed!" };
             }
             lambda = bLambda;
+            updateCachedAlphas();
             //std::cerr << fx << std::endl;
-            expLambda = lambda.array().exp() + alphaEps;
         }

-        int restoreFromTrainingError(const exception::TrainingError& e, ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
+        int restoreFromTrainingError(const exc::TrainingError& e, ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
         {
             std::cerr << "Failed to optimize! Reset prior and retry!" << std::endl;
             lambda.setZero();
-            expLambda = lambda.array().exp() + alphaEps;
+            updateCachedAlphas();
             return 0;
         }

+        auto getCachedAlpha(const _DocType& doc) const
+        {
+            if (doc.mdHash < cachedAlphas.cols())
+            {
+                return cachedAlphas.col(doc.mdHash);
+            }
+            else
+            {
+                if (!doc.cachedAlpha.size())
+                {
+                    doc.cachedAlpha = (lambda.middleCols(doc.metadata * mdVecSize, mdVecSize) * doc.mdVec).array().exp() + alphaEps;
+                }
+                return doc.cachedAlpha.col(0);
+            }
+        }
+
         template<bool _asymEta>
         Float* getZLikelihoods(_ModelState& ld, const _DocType& doc, size_t docId, size_t vid) const
         {
             const size_t V = this->realV;
             assert(vid < V);
             auto etaHelper = this->template getEtaHelper<_asymEta>();
+            auto alphas = getCachedAlpha(doc);
             auto& zLikelihood = ld.zLikelihood;
-            zLikelihood = (doc.numByTopic.array().template cast<Float>() + this->expLambda.col(doc.metadata).array())
+            zLikelihood = (doc.numByTopic.array().template cast<Float>() + alphas.array())
                 * (ld.numByTopicWord.col(vid).array().template cast<Float>() + etaHelper.getEta(vid))
                 / (ld.numByTopic.array().template cast<Float>() + etaHelper.getEtaSum());
             sample::prefixSum(zLikelihood.data(), this->K);
             return &zLikelihood[0];
         }
-
         double getLLDocTopic(const _DocType& doc) const
         {
             const size_t V = this->realV;
             const auto K = this->K;
-            auto alphaDoc = expLambda.col(doc.metadata);
-
+            auto alphaDoc = getCachedAlpha(doc);
+
             Float ll = 0;
             Float alphaSum = alphaDoc.sum();
             for (Tid k = 0; k < K; ++k)
             {
                 ll += math::lgammaT(doc.numByTopic[k] + alphaDoc[k]);

@@ -197,11 +228,11 @@
             double ll = 0;
             for (; _first != _last; ++_first)
             {
                 auto& doc = *_first;
-                auto alphaDoc = expLambda.col(doc.metadata);
+                auto alphaDoc = getCachedAlpha(doc);
                 Float alphaSum = alphaDoc.sum();
                 for (Tid k = 0; k < K; ++k)
                 {
                     ll += math::lgammaT(doc.numByTopic[k] + alphaDoc[k]) - math::lgammaT(alphaDoc[k]);

@@ -232,79 +263,180 @@
                 }
             }
             return ll;
         }

+        void updateCachedAlphas() const
+        {
+            cachedAlphas.resize(this->K, mdHashMap.size());
+
+            for (auto& p : mdHashMap)
+            {
+                cachedAlphas.col(p.second) = (lambda.middleCols(p.first.first * mdVecSize, mdVecSize) * p.first.second).array().exp() + alphaEps;
+            }
+        }
+
+        void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
+        {
+            BaseClass::prepareDoc(doc, docId, wordSize);
+
+            doc.mdVec = Vector::Zero(mdVecSize);
+            doc.mdVec[0] = 1;
+            for (auto x : doc.multiMetadata)
+            {
+                doc.mdVec[x + 1] = 1;
+            }
+
+            auto p = std::make_pair(doc.metadata, doc.mdVec);
+            auto it = mdHashMap.find(p);
+            if (it == mdHashMap.end())
+            {
+                it = mdHashMap.emplace(p, mdHashMap.size()).first;
+            }
+            doc.mdHash = it->second;
+        }
+
         void initGlobalState(bool initDocs)
         {
             BaseClass::initGlobalState(initDocs);
-            this->globalState.tmpK = Eigen::Matrix<Float, -1, 1>::Zero(this->K);
+            this->globalState.tmpK = Vector::Zero(this->K);
             F = metadataDict.size();
+            mdVecSize = multiMetadataDict.size() + 1;
             if (initDocs)
             {
-                lambda = Eigen::Matrix<Float, -1, -1>::Constant(this->K, F, log(this->alpha));
+                lambda.resize(this->K, F * mdVecSize);
+                for (size_t f = 0; f < F; ++f)
+                {
+                    lambda.col(f * mdVecSize) = this->alphas.array().log();
+                    lambda.middleCols(f * mdVecSize + 1, mdVecSize - 1).setZero();
+                }
             }
+            else
+            {
+                for (auto& doc : this->docs)
+                {
+                    if (doc.mdVec.size() == mdVecSize) continue;
+                    doc.mdVec = Vector::Zero(mdVecSize);
+                    doc.mdVec[0] = 1;
+                    for (auto x : doc.multiMetadata)
+                    {
+                        doc.mdVec[x + 1] = 1;
+                    }
+
+                    auto p = std::make_pair(doc.metadata, doc.mdVec);
+                    auto it = this->mdHashMap.find(p);
+                    if (it == this->mdHashMap.end())
+                    {
+                        it = this->mdHashMap.emplace(p, mdHashMap.size()).first;
+                    }
+                    doc.mdHash = it->second;
+                }
+            }
+
             if (_Flags & flags::continuous_doc_data) this->numByTopicDoc = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K, this->docs.size());
-            expLambda = lambda.array().exp();
             LBFGSpp::LBFGSParam<Float> param;
             param.max_iterations = maxBFGSIteration;
             solver = decltype(solver){ param };
         }

+        void prepareShared()
+        {
+            BaseClass::prepareShared();
+
+            for (auto doc : this->docs)
+            {
+                if (doc.mdHash != (size_t)-1) continue;
+
+                auto p = std::make_pair(doc.metadata, doc.mdVec);
+                auto it = mdHashMap.find(p);
+                if (it == mdHashMap.end())
+                {
+                    it = mdHashMap.emplace(p, mdHashMap.size()).first;
+                }
+                doc.mdHash = it->second;
+            }
+
+            updateCachedAlphas();
+        }
+
     public:
         DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, sigma, alphaEps, metadataDict, lambda);
-        DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, sigma, alphaEps, metadataDict, lambda);
+        DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, sigma, alphaEps, metadataDict, lambda, multiMetadataDict);

-        DMRModel(size_t _K = 1, Float defaultAlpha = 1.0, Float _sigma = 1.0, Float _eta = 0.01,
-            Float _alphaEps = 0, size_t _rg = std::random_device{}())
-            : BaseClass(_K, defaultAlpha, _eta, _rg), sigma(_sigma), alphaEps(_alphaEps)
+        DMRModel(const DMRArgs& args)
+            : BaseClass(args), sigma(args.sigma), alphaEps(args.alphaEps)
         {
-            if (_sigma <= 0) THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong sigma value (sigma = %f)", _sigma));
+            if (sigma <= 0) THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong sigma value (sigma = %f)", sigma));
        }

         template<bool _const = false>
-        _DocType& _updateDoc(_DocType& doc, const std::string& metadata)
+        _DocType& _updateDoc(_DocType& doc, const std::string& metadata, const std::vector<std::string>& mdVec = {})
         {
             Vid xid;
             if (_const)
             {
                 xid = metadataDict.toWid(metadata);
-                if (xid == (Vid)-1) throw std::invalid_argument("unknown metadata");
+                if (xid == (Vid)-1) throw exc::InvalidArgument("unknown metadata '" + metadata + "'");
+
+                for (auto& m : mdVec)
+                {
+                    Vid x = multiMetadataDict.toWid(m);
+                    if (x == (Vid)-1) throw exc::InvalidArgument("unknown multi_metadata '" + m + "'");
+                    doc.multiMetadata.emplace_back(x);
+                }
             }
             else
             {
                 xid = metadataDict.add(metadata);
+
+                for (auto& m : mdVec)
+                {
+                    doc.multiMetadata.emplace_back(multiMetadataDict.add(m));
+                }
             }
             doc.metadata = xid;
             return doc;
         }

         size_t addDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) override
         {
             auto doc = this->template _makeFromRawDoc<false>(rawDoc, tokenizer);
-            return this->_addDoc(_updateDoc(doc, rawDoc.template getMisc<std::string>("metadata")));
+            return this->_addDoc(_updateDoc(doc,
+                rawDoc.template getMisc<std::string>("metadata"),
+                rawDoc.template getMiscDefault<std::vector<std::string>>("multi_metadata")
+            ));
         }

         std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) const override
         {
             auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc, tokenizer);
-            return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMisc<std::string>("metadata")));
+            return std::make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc,
+                rawDoc.template getMisc<std::string>("metadata"),
+                rawDoc.template getMiscDefault<std::vector<std::string>>("multi_metadata")
+            ));
         }

         size_t addDoc(const RawDoc& rawDoc) override
         {
             auto doc = this->_makeFromRawDoc(rawDoc);
-            return this->_addDoc(_updateDoc(doc, rawDoc.template getMisc<std::string>("metadata")));
+            return this->_addDoc(_updateDoc(doc,
+                rawDoc.template getMisc<std::string>("metadata"),
+                rawDoc.template getMiscDefault<std::vector<std::string>>("multi_metadata")
+            ));
         }

         std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc) const override
         {
             auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc);
-            return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMisc<std::string>("metadata")));
+            return std::make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc,
+                rawDoc.template getMisc<std::string>("metadata"),
+                rawDoc.template getMiscDefault<std::vector<std::string>>("multi_metadata")
+            ));
         }

         GETTER(F, size_t, F);
+        GETTER(MdVecSize, size_t, mdVecSize);
         GETTER(Sigma, Float, sigma);
         GETTER(AlphaEps, Float, alphaEps);
         GETTER(OptimRepeat, size_t, optimRepeat);

         void setAlphaEps(Float _alphaEps) override

@@ -315,33 +447,75 @@
         void setOptimRepeat(size_t _optimRepeat) override
         {
             optimRepeat = _optimRepeat;
         }

-        std::vector<Float> getTopicsByDoc(const _DocType& doc) const
+        std::vector<Float> getTopicsByDoc(const _DocType& doc, bool normalize) const
         {
             std::vector<Float> ret(this->K);
-            auto alphaDoc = expLambda.col(doc.metadata);
-            Eigen::Map<Eigen::Matrix<Float, -1, 1>>{ret.data(), this->K}.array() =
-                (doc.numByTopic.array().template cast<Float>() + alphaDoc.array()) / (doc.getSumWordWeight() + alphaDoc.sum());
+            auto alphaDoc = getCachedAlpha(doc);
+            Eigen::Map<Eigen::Array<Float, -1, 1>> m{ ret.data(), this->K };
+            if (normalize)
+            {
+                m = (doc.numByTopic.array().template cast<Float>() + alphaDoc.array()) / (doc.getSumWordWeight() + alphaDoc.sum());
+            }
+            else
+            {
+                m = doc.numByTopic.array().template cast<Float>() + alphaDoc.array();
+            }
             return ret;
         }

         std::vector<Float> getLambdaByMetadata(size_t metadataId) const override
         {
             assert(metadataId < metadataDict.size());
             auto l = lambda.col(metadataId);
-            return { l.data(), l.data() + this->K };
+            return { l.data(), l.data() + l.size() };
         }

         std::vector<Float> getLambdaByTopic(Tid tid) const override
         {
-            assert(tid < this->K);
-            auto l = lambda.row(tid);
-            return { l.data(), l.data() + F };
+            std::vector<Float> ret(F * mdVecSize);
+            if (this->lambda.size())
+            {
+                Eigen::Map<Vector>{ ret.data(), (Eigen::Index)ret.size() } = this->lambda.row(tid);
+            }
+            return ret;
         }

+        std::vector<Float> getTopicPrior(const std::string& metadata,
+            const std::vector<std::string>& mdVec,
+            bool raw = false
+        ) const override
+        {
+            Vid xid = metadataDict.toWid(metadata);
+            if (xid == (Vid)-1) throw exc::InvalidArgument("unknown metadata '" + metadata + "'");
+
+            Vector xs = Vector::Zero(mdVecSize);
+            xs[0] = 1;
+            for (auto& m : mdVec)
+            {
+                Vid x = multiMetadataDict.toWid(m);
+                if (x == (Vid)-1) throw exc::InvalidArgument("unknown multi_metadata '" + m + "'");
+                xs[x + 1] = 1;
+            }
+
+            std::vector<Float> ret(this->K);
+            Eigen::Map<Vector> map{ ret.data(), (Eigen::Index)ret.size() };
+
+            if (raw)
+            {
+                map = lambda.middleCols(xid * mdVecSize, mdVecSize) * xs;
+            }
+            else
+            {
+                map = (lambda.middleCols(xid * mdVecSize, mdVecSize) * xs).array().exp() + alphaEps;
+            }
+            return ret;
+        }
+
         const Dictionary& getMetadataDict() const override { return metadataDict; }
+        const Dictionary& getMultiMetadataDict() const override { return multiMetadataDict; }
     };

     /* This is for preventing 'undefined symbol' problem in compiling by clang. */
     template<TermWeight _tw, typename _RandGen,
         size_t _Flags, typename _Interface,
         typename _Derived, typename _DocType, typename _ModelState>
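
Notes on the new multi-metadata machinery, with small C++ sketches. These are editorial illustrations, not part of the vendored file.

The MdHash functor added at the top makes a (categorical metadata id, metadata vector) pair usable as an std::unordered_map key for mdHashMap. It is the classic boost-style hash_combine: seed with the id, then fold in every vector element using the 32-bit golden-ratio constant 0x9e3779b9 and a pair of shifts to diffuse the bits. A minimal self-contained sketch of the same scheme, assuming std::vector<float> in place of tomoto's Eigen-based Vector; the name PairHash is made up here:

#include <cstdint>
#include <functional>
#include <utility>
#include <vector>

// Boost-style hash_combine over a (metadata id, feature vector) pair,
// mirroring MdHash in the diff above.
struct PairHash
{
    size_t operator()(const std::pair<uint64_t, std::vector<float>>& p) const
    {
        size_t seed = p.first;          // start from the categorical metadata id
        for (float elem : p.second)     // fold in each element of the metadata vector
        {
            // 0x9e3779b9 is the 32-bit golden-ratio constant used by
            // boost::hash_combine; the shifts spread each element's bits.
            seed ^= std::hash<float>()(elem) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
        }
        return seed;
    }
};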
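
The core modeling change is how a document's Dirichlet prior is built. In 0.1.4 it was a single cached column, expLambda.col(doc.metadata). In 0.2.0 each categorical metadata value owns a K x mdVecSize block of lambda, and the prior becomes exp(block * mdVec) + alphaEps, where doc.mdVec is a one-hot vector whose slot 0 is an always-on bias and whose remaining slots flag the document's multi_metadata entries. A minimal sketch of that computation, assuming Eigen's float types and made-up sizes:

#include <Eigen/Dense>
#include <cstddef>
#include <vector>

int main()
{
    const std::size_t K = 4;                     // number of topics
    const std::size_t numMulti = 3;              // size of multiMetadataDict
    const std::size_t mdVecSize = numMulti + 1;  // +1 for the always-on bias slot

    // One-hot encode a document carrying multi_metadata ids {0, 2},
    // the same filling that prepareDoc() does for doc.mdVec in the diff.
    Eigen::VectorXf mdVec = Eigen::VectorXf::Zero(mdVecSize);
    mdVec[0] = 1;                                // bias, set for every document
    for (std::size_t x : std::vector<std::size_t>{ 0, 2 }) mdVec[x + 1] = 1;

    // Stand-in for lambda.middleCols(doc.metadata * mdVecSize, mdVecSize).
    Eigen::MatrixXf lambdaBlock = Eigen::MatrixXf::Random(K, mdVecSize);

    // The document-specific prior: alphaDoc in evaluateLambdaObj() and
    // getCachedAlpha(), using the default alphaEps from the diff.
    const float alphaEps = 1e-10f;
    Eigen::VectorXf alpha = ((lambdaBlock * mdVec).array().exp() + alphaEps).matrix();
    return 0;
}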
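
Since many documents share one (metadata, mdVec) combination, the diff interns each distinct combination: prepareDoc(), initGlobalState() and prepareShared() all look the pair up in mdHashMap, assign the next free index when it is new, and store that index as doc.mdHash. updateCachedAlphas() then writes one prior per distinct combination into a column of cachedAlphas, so a lambda update costs one exponentiation per combination rather than one per document. A sketch of the interning step, reusing PairHash from the first sketch above; internKey is a hypothetical helper, not a tomoto API:

#include <cstdint>
#include <unordered_map>
#include <utility>
#include <vector>

using Key = std::pair<uint64_t, std::vector<float>>;

// Assign a stable, dense index to each distinct (metadata, mdVec) pair,
// mirroring the mdHashMap lookups in the diff. The returned value plays
// the role of doc.mdHash, i.e. a column index into cachedAlphas.
std::size_t internKey(std::unordered_map<Key, std::size_t, PairHash>& map, const Key& p)
{
    auto it = map.find(p);
    if (it == map.end())
    {
        it = map.emplace(p, map.size()).first;   // next free column index
    }
    return it->second;
}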