Sha256: 4b1d573f865ea25f075c14ef42738e2219fb2bcb69dcb4aac18a0f0a92ec86e6

Contents?: true

Size: 1.85 KB

Versions: 3

Compression:

Stored size: 1.85 KB

Contents

#pragma once
#include "LDA.h"

namespace tomoto
{
    // Document type used by the MGLDA model: everything DocumentLDA<_tw> carries,
    // plus the per-sentence and per-window bookkeeping declared below.
    // NOTE(review): in the size notes, S and T presumably stand for the number of
    // sentences and the number of windows per sentence (cf. MGLDAArgs::t), and
    // "gl."/"loc." for global/local -- confirm against the implementation.
    template<TermWeight _tw>
	struct DocumentMGLDA : public DocumentLDA<_tw>
	{
		using BaseDocument = DocumentLDA<_tw>;
		// Inherit every constructor of the base document type.
		using BaseDocument::BaseDocument;
		using WeightType = typename BaseDocument::WeightType;

		// --- per-word / per-sentence data (marked const by the original author) ---
		// sentence id of each word
		std::vector<uint16_t> sents;
		// number of words in the sentence
		std::vector<WeightType> numBySent;

		// --- assignment state ---
		// A `Zs` vector (gl./loc. and topic assignment) and a `numByTopic` vector
		// (len = K + KL) were left commented out by the original author and are
		// intentionally not declared here.
		// window assignment
		std::vector<uint8_t> Vs;
		// number of words assigned as gl.
		WeightType numGl = 0;
		// len = S * T
		Eigen::Matrix<WeightType, Eigen::Dynamic, Eigen::Dynamic> numBySentWin;
		// number of words assigned as loc. in the window (len = S + T - 1)
		Eigen::Matrix<WeightType, Eigen::Dynamic, 1> numByWinL;
		// number of words in the window (len = S + T - 1)
		Eigen::Matrix<WeightType, Eigen::Dynamic, 1> numByWin;
		// number of words in the loc. topic in the window (len = KL * (S + T - 1))
		Eigen::Matrix<WeightType, Eigen::Dynamic, Eigen::Dynamic> numByWinTopicL;

		// Serializer declarations for format versions 0 and 1.
		DECLARE_SERIALIZER_WITH_VERSION(0);
		DECLARE_SERIALIZER_WITH_VERSION(1);

		// Declared here, defined elsewhere.
		// NOTE(review): `ptr` presumably points at storage supplied by `mdl` for
		// this document's counters -- confirm against the definition.
		template<typename _TopicModel>
		void update(WeightType* ptr, const _TopicModel& mdl);
	};

	// Argument bundle for constructing an MGLDA model (see IMGLDAModel::create);
	// extends LDAArgs with the fields below.
	// NOTE(review): the per-field meanings are inferred from the field names and
	// the IMGLDAModel getters, not from code visible in this header -- confirm
	// against the model implementation.
	struct MGLDAArgs : public LDAArgs
	{
		size_t kL = 1; // presumably the number of local topics (cf. getKL) -- TODO confirm
		size_t t = 3; // presumably the sentence-window size (cf. getT) -- TODO confirm
		std::vector<Float> alphaL = { 0.1 }; // presumably prior(s) for local topics (cf. getAlphaL); one element by default
		Float alphaMG = 0.1; // presumably the "global" counterpart of alphaML (cf. getAlphaM) -- TODO confirm
		Float alphaML = 0.1; // presumably the "local" counterpart of alphaMG (cf. getAlphaML) -- TODO confirm
		Float etaL = 0.01; // presumably the word prior for local topics (cf. getEtaL) -- TODO confirm
		Float gamma = 0.1; // exposed through getGamma; exact role not visible here
	};

	// Abstract interface for MGLDA models. Adds MGLDA-specific read-only
	// accessors on top of ILDAModel; every accessor is pure virtual and must be
	// provided by the concrete model.
	class IMGLDAModel : public ILDAModel
	{
	public:
		// Document type associated with this interface by default.
		using DefaultDocType = DocumentMGLDA<TermWeight::one>;

		// Factory for a concrete model, selected by term-weighting scheme.
		// Returns a raw pointer.
		// NOTE(review): ownership presumably passes to the caller, and the effect
		// of `scalarRng` is not visible in this header -- confirm both against
		// the definition.
		static IMGLDAModel* create(TermWeight _weight, const MGLDAArgs& args,
			bool scalarRng = false);

		// Accessors for model settings. They presumably mirror the MGLDAArgs
		// fields of similar name (kL, t, alphaL, etaL, gamma, alphaMG, alphaML)
		// -- TODO confirm.
		virtual size_t getKL() const = 0;
		virtual size_t getT() const = 0;
		// Returns a single Float even though MGLDAArgs::alphaL is a vector.
		virtual Float getAlphaL() const = 0;
		virtual Float getEtaL() const = 0;
		virtual Float getGamma() const = 0;
		// NOTE(review): named getAlphaM while the args field is alphaMG.
		virtual Float getAlphaM() const = 0;
		virtual Float getAlphaML() const = 0;
	};
}

Version data entries

3 entries across 3 versions & 1 rubygem

Version Path
tomoto-0.5.1 vendor/tomotopy/src/TopicModel/MGLDA.h
tomoto-0.5.0 vendor/tomotopy/src/TopicModel/MGLDA.h
tomoto-0.4.1 vendor/tomotopy/src/TopicModel/MGLDA.h