vendor/tomotopy/src/TopicModel/MGLDAModel.hpp in tomoto-0.1.4 vs vendor/tomotopy/src/TopicModel/MGLDAModel.hpp in tomoto-0.2.0
- old
+ new
@@ -287,11 +287,11 @@
doc.sents[doc.wOrder[i]] = tmp[i];
}
const size_t S = doc.numBySent.size();
std::fill(doc.numBySent.begin(), doc.numBySent.end(), 0);
- doc.Zs = tvector<Tid>(wordSize);
+ doc.Zs = tvector<Tid>(wordSize, non_topic_id);
doc.Vs.resize(wordSize);
if (_tw != TermWeight::one) doc.wordWeights.resize(wordSize);
doc.numByTopic.init(nullptr, this->K + KL, 1);
doc.numBySentWin = Eigen::Matrix<WeightType, -1, -1>::Zero(S, T);
doc.numByWin = Eigen::Matrix<WeightType, -1, 1>::Zero(S + T - 1);
@@ -300,11 +300,11 @@
}
void initGlobalState(bool initDocs)
{
const size_t V = this->realV;
- this->globalState.zLikelihood = Eigen::Matrix<Float, -1, 1>::Zero(T * (this->K + KL));
+ this->globalState.zLikelihood = Vector::Zero(T * (this->K + KL));
if (initDocs)
{
this->globalState.numByTopic = Eigen::Matrix<WeightType, -1, 1>::Zero(this->K + KL);
//this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K + KL, V);
this->globalState.numByTopicWord.init(nullptr, this->K + KL, V);
@@ -369,21 +369,37 @@
public:
DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, alphaL, alphaM, alphaML, etaL, gamma, KL, T);
DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, alphaL, alphaM, alphaML, etaL, gamma, KL, T);
- MGLDAModel(size_t _KG = 1, size_t _KL = 1, size_t _T = 3,
- Float _alphaG = 0.1, Float _alphaL = 0.1, Float _alphaMG = 0.1, Float _alphaML = 0.1,
- Float _etaG = 0.01, Float _etaL = 0.01, Float _gamma = 0.1, size_t _rg = std::random_device{}())
- : BaseClass(_KG, _alphaG, _etaG, _rg), KL(_KL), T(_T),
- alphaL(_alphaL), alphaM(_KG ? _alphaMG : 0), alphaML(_alphaML),
- etaL(_etaL), gamma(_gamma)
+ MGLDAModel(const MGLDAArgs& args)
+ : BaseClass(args), KL(args.kL), T(args.t),
+ alphaL(args.alphaL[0]), alphaM(args.k ? args.alphaMG : 0), alphaML(args.alphaML),
+ etaL(args.etaL), gamma(args.gamma)
{
- if (_KL == 0 || _KL >= 0x80000000) THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong KL value (KL = %zd)", _KL));
- if (_T == 0 || _T >= 0x80000000) THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong T value (T = %zd)", _T));
- if (_alphaL <= 0) THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong alphaL value (alphaL = %f)", _alphaL));
- if (_etaL <= 0) THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong etaL value (etaL = %f)", _etaL));
+ if (KL == 0 || KL >= 0x80000000) THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong KL value (KL = %zd)", KL));
+ if (T == 0 || T >= 0x80000000) THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong T value (T = %zd)", T));
+
+ if (args.alpha.size() != 1)
+ {
+ THROW_ERROR_WITH_INFO(exc::Unimplemented, "An asymmetric alpha prior is not supported yet at MGLDA.");
+ }
+
+ if (args.alphaL.size() == 1)
+ {
+ }
+ else if (args.alphaL.size() == args.kL)
+ {
+ THROW_ERROR_WITH_INFO(exc::Unimplemented, "An asymmetric alphaL prior is not supported yet at MGLDA.");
+ }
+ else
+ {
+ THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong alphaL value (len = %zd)", args.alphaL.size()));
+ }
+
+ if (alphaL <= 0) THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong alphaL value (alphaL = %f)", alphaL));
+ if (etaL <= 0) THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong etaL value (etaL = %f)", etaL));
}
template<bool _const, typename _FnTokenizer>
_DocType _makeFromRawDoc(const RawDoc& rawDoc, _FnTokenizer&& tokenizer, const std::string& delimiter)
{
@@ -424,11 +440,11 @@
return this->_addDoc(_makeFromRawDoc<false>(rawDoc, tokenizer, rawDoc.template getMisc<std::string>("delimiter")));
}
std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) const
{
- return make_unique<_DocType>(as_mutable(this)->template _makeFromRawDoc<true>(rawDoc, tokenizer, rawDoc.template getMisc<std::string>("delimiter")));
+ return std::make_unique<_DocType>(as_mutable(this)->template _makeFromRawDoc<true>(rawDoc, tokenizer, rawDoc.template getMisc<std::string>("delimiter")));
}
template<bool _const = false>
_DocType _makeFromRawDoc(const RawDoc& rawDoc)
{
@@ -495,28 +511,35 @@
return this->_addDoc(_makeFromRawDoc(rawDoc));
}
std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc) const
{
- return make_unique<_DocType>(as_mutable(this)->template _makeFromRawDoc<true>(rawDoc));
+ return std::make_unique<_DocType>(as_mutable(this)->template _makeFromRawDoc<true>(rawDoc));
}
void setWordPrior(const std::string& word, const std::vector<Float>& priors) override
{
- if (priors.size() != this->K + KL) THROW_ERROR_WITH_INFO(exception::InvalidArgument, "priors.size() must be equal to K.");
+ if (priors.size() != this->K + KL) THROW_ERROR_WITH_INFO(exc::InvalidArgument, "priors.size() must be equal to K.");
for (auto p : priors)
{
- if (p < 0) THROW_ERROR_WITH_INFO(exception::InvalidArgument, "priors must not be less than 0.");
+ if (p < 0) THROW_ERROR_WITH_INFO(exc::InvalidArgument, "priors must not be less than 0.");
}
this->dict.add(word);
this->etaByWord.emplace(word, priors);
}
- std::vector<Float> getTopicsByDoc(const _DocType& doc) const
+ std::vector<Float> getTopicsByDoc(const _DocType& doc, bool normalize) const
{
std::vector<Float> ret(this->K + KL);
- Eigen::Map<Eigen::Matrix<Float, -1, 1>> { ret.data(), this->K + KL }.array() =
- doc.numByTopic.array().template cast<Float>() / doc.getSumWordWeight();
+ Eigen::Map<Eigen::Array<Float, -1, 1>> m{ ret.data(), this->K + KL };
+ if (normalize)
+ {
+ m = doc.numByTopic.array().template cast<Float>() / doc.getSumWordWeight();
+ }
+ else
+ {
+ m = doc.numByTopic.array().template cast<Float>();
+ }
return ret;
}
GETTER(KL, size_t, KL);
GETTER(T, size_t, T);