vendor/tomotopy/src/TopicModel/SLDAModel.hpp in tomoto-0.1.2 vs vendor/tomotopy/src/TopicModel/SLDAModel.hpp in tomoto-0.1.3
- old
+ new
@@ -146,11 +146,11 @@
this->regressionCoef = normZZT
.colPivHouseholderQr().solve(selectedNormZ * ys.array().isNaN().select(0, b * (ys.array() - 0.5f)).matrix()
+ Eigen::Matrix<Float, -1, 1>::Constant(selectedNormZ.rows(), mu / nuSq));
RandGen rng;
- for (size_t i = 0; i < omega.size(); ++i)
+ for (size_t i = 0; i < (size_t)omega.size(); ++i)
{
if (std::isnan(ys[i])) continue;
omega[i] = math::drawPolyaGamma(b, (this->regressionCoef.array() * normZ.col(i).array()).sum(), rng);
}
}
@@ -356,67 +356,49 @@
_DocType& _updateDoc(_DocType& doc, const std::vector<Float>& y)
{
if (_const)
{
if (y.size() > F) throw std::runtime_error{ text::format(
- "size of 'y' is greater than the number of vars.\n"
- "size of 'y' : %zd, number of vars: %zd", y.size(), F) };
+ "size of `y` is greater than the number of vars.\n"
+ "size of `y` : %zd, number of vars: %zd", y.size(), F) };
doc.y = y;
while (doc.y.size() < F)
{
doc.y.emplace_back(NAN);
}
}
else
{
if (y.size() != F) throw std::runtime_error{ text::format(
- "size of 'y' must be equal to the number of vars.\n"
- "size of 'y' : %zd, number of vars: %zd", y.size(), F) };
+ "size of `y` must be equal to the number of vars.\n"
+ "size of `y` : %zd, number of vars: %zd", y.size(), F) };
doc.y = y;
}
return doc;
}
- size_t addDoc(const std::vector<std::string>& words, const std::vector<Float>& y) override
+ size_t addDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) override
{
- auto doc = this->_makeDoc(words);
- return this->_addDoc(_updateDoc(doc, y));
+ auto doc = this->template _makeFromRawDoc<false>(rawDoc, tokenizer);
+ return this->_addDoc(_updateDoc(doc, rawDoc.template getMiscDefault<std::vector<Float>>("y")));
}
- std::unique_ptr<DocumentBase> makeDoc(const std::vector<std::string>& words, const std::vector<Float>& y) const override
+ std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) const override
{
- auto doc = as_mutable(this)->template _makeDoc<true>(words);
- return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, y));
+ auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc, tokenizer);
+ return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<Float>>("y")));
}
- size_t addDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
- const std::vector<Float>& y) override
+ size_t addDoc(const RawDoc& rawDoc) override
{
- auto doc = this->template _makeRawDoc<false>(rawStr, tokenizer);
- return this->_addDoc(_updateDoc(doc, y));
+ auto doc = this->_makeFromRawDoc(rawDoc);
+ return this->_addDoc(_updateDoc(doc, rawDoc.template getMiscDefault<std::vector<Float>>("y")));
}
- std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
- const std::vector<Float>& y) const override
+ std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc) const override
{
- auto doc = as_mutable(this)->template _makeRawDoc<true>(rawStr, tokenizer);
- return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, y));
- }
-
- size_t addDoc(const std::string& rawStr, const std::vector<Vid>& words,
- const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
- const std::vector<Float>& y) override
- {
- auto doc = this->_makeRawDoc(rawStr, words, pos, len);
- return this->_addDoc(_updateDoc(doc, y));
- }
-
- std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const std::vector<Vid>& words,
- const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
- const std::vector<Float>& y) const override
- {
- auto doc = this->_makeRawDoc(rawStr, words, pos, len);
- return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, y));
+ auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc);
+ return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<Float>>("y")));
}
std::vector<Float> estimateVars(const DocumentBase* doc) const override
{
std::vector<Float> ret;