vendor/tomotopy/src/TopicModel/SLDAModel.hpp in tomoto-0.1.2 vs vendor/tomotopy/src/TopicModel/SLDAModel.hpp in tomoto-0.1.3

- old
+ new

@@ -146,11 +146,11 @@ this->regressionCoef = normZZT .colPivHouseholderQr().solve(selectedNormZ * ys.array().isNaN().select(0, b * (ys.array() - 0.5f)).matrix() + Eigen::Matrix<Float, -1, 1>::Constant(selectedNormZ.rows(), mu / nuSq)); RandGen rng; - for (size_t i = 0; i < omega.size(); ++i) + for (size_t i = 0; i < (size_t)omega.size(); ++i) { if (std::isnan(ys[i])) continue; omega[i] = math::drawPolyaGamma(b, (this->regressionCoef.array() * normZ.col(i).array()).sum(), rng); } } @@ -356,67 +356,49 @@ _DocType& _updateDoc(_DocType& doc, const std::vector<Float>& y) { if (_const) { if (y.size() > F) throw std::runtime_error{ text::format( - "size of 'y' is greater than the number of vars.\n" - "size of 'y' : %zd, number of vars: %zd", y.size(), F) }; + "size of `y` is greater than the number of vars.\n" + "size of `y` : %zd, number of vars: %zd", y.size(), F) }; doc.y = y; while (doc.y.size() < F) { doc.y.emplace_back(NAN); } } else { if (y.size() != F) throw std::runtime_error{ text::format( - "size of 'y' must be equal to the number of vars.\n" - "size of 'y' : %zd, number of vars: %zd", y.size(), F) }; + "size of `y` must be equal to the number of vars.\n" + "size of `y` : %zd, number of vars: %zd", y.size(), F) }; doc.y = y; } return doc; } - size_t addDoc(const std::vector<std::string>& words, const std::vector<Float>& y) override + size_t addDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) override { - auto doc = this->_makeDoc(words); - return this->_addDoc(_updateDoc(doc, y)); + auto doc = this->template _makeFromRawDoc<false>(rawDoc, tokenizer); + return this->_addDoc(_updateDoc(doc, rawDoc.template getMiscDefault<std::vector<Float>>("y"))); } - std::unique_ptr<DocumentBase> makeDoc(const std::vector<std::string>& words, const std::vector<Float>& y) const override + std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) const override { - auto doc = as_mutable(this)->template _makeDoc<true>(words); - return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, y)); + auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc, tokenizer); + return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<Float>>("y"))); } - size_t addDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer, - const std::vector<Float>& y) override + size_t addDoc(const RawDoc& rawDoc) override { - auto doc = this->template _makeRawDoc<false>(rawStr, tokenizer); - return this->_addDoc(_updateDoc(doc, y)); + auto doc = this->_makeFromRawDoc(rawDoc); + return this->_addDoc(_updateDoc(doc, rawDoc.template getMiscDefault<std::vector<Float>>("y"))); } - std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer, - const std::vector<Float>& y) const override + std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc) const override { - auto doc = as_mutable(this)->template _makeRawDoc<true>(rawStr, tokenizer); - return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, y)); - } - - size_t addDoc(const std::string& rawStr, const std::vector<Vid>& words, - const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len, - const std::vector<Float>& y) override - { - auto doc = this->_makeRawDoc(rawStr, words, pos, len); - return this->_addDoc(_updateDoc(doc, y)); - } - - std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const std::vector<Vid>& words, - const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len, - const std::vector<Float>& y) const override - { - auto doc = this->_makeRawDoc(rawStr, words, pos, len); - return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, y)); + auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc); + return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<Float>>("y"))); } std::vector<Float> estimateVars(const DocumentBase* doc) const override { std::vector<Float> ret;