Sha256: b786188ad0014ea94b8c6e53bf490a0dac08a54dcc4ed4a73d5e054c54213f87

Contents?: true

Size: 1.82 KB

Versions: 3

Compression:

Stored size: 1.82 KB

Contents

#pragma once
#include "LDA.h"

namespace tomoto
{
	// Document type used by the LLDA model interface below (see
	// ILLDAModel::DefaultDocType). It extends DocumentLDA<_tw> with a single
	// extra field, `labelMask`.
	//
	// _tw: term-weighting scheme, forwarded unchanged to DocumentLDA.
	template<TermWeight _tw>
	struct DocumentLLDA : public DocumentLDA<_tw>
	{
		using BaseDocument = DocumentLDA<_tw>;
		// Inherit every constructor of the base document type.
		using DocumentLDA<_tw>::DocumentLDA;
		using WeightType = typename DocumentLDA<_tw>::WeightType;
		// Dynamically sized column vector of int8_t (-1 rows == Eigen::Dynamic).
		// NOTE(review): presumably one entry per topic marking which topics the
		// document's labels allow; the fill logic is not in this header, confirm
		// against the model implementation.
		Eigen::Matrix<int8_t, -1, 1> labelMask;

		// Serializer declarations for `labelMask`, layered on BaseDocument.
		// The macros are defined elsewhere. Going by their names and arguments:
		// an untagged form for version 0, and a tagged form for version 1 with
		// tag 0x00010001. NOTE(review): confirm exact semantics in the
		// serializer header before changing either line.
		DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 0, labelMask);
		DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 1, 0x00010001, labelMask);
	};

	// Abstract interface for LLDA models. Extends ILDAModel with document
	// entry points that take a list of string labels in addition to the words,
	// plus accessors for the label dictionary and the topics-per-label count.
	// All members except `create` are pure virtual; concrete models live
	// outside this header.
	class ILLDAModel : public ILDAModel
	{
	public:
		// Document type assumed when none is specified: unit term weighting.
		using DefaultDocType = DocumentLLDA<TermWeight::one>;

		// Factory for a concrete model; defined outside this header.
		//   _weight   term-weighting scheme for the created model
		//   _K        defaults to 1 (presumably the number of topics, TODO confirm)
		//   alpha     defaults to 0.1 (presumably the document-topic prior, TODO confirm)
		//   eta       defaults to 0.01 (presumably the topic-word prior, TODO confirm)
		//   seed      defaults to a fresh std::random_device draw, so runs are
		//             not reproducible unless a seed is passed explicitly
		//   scalarRng defaults to false (meaning not visible here)
		// Returns a raw pointer. NOTE(review): ownership is not expressed in the
		// type; presumably the caller must delete it, verify at call sites.
		static ILLDAModel* create(TermWeight _weight, size_t _K = 1, 
			Float alpha = 0.1, Float eta = 0.01, size_t seed = std::random_device{}(),
			bool scalarRng = false);

		// Pre-tokenized input: `words` are the document's tokens and `label`
		// its label strings. addDoc returns a size_t (presumably the new
		// document's index, TODO confirm). makeDoc is const and hands back a
		// caller-owned DocumentBase.
		virtual size_t addDoc(const std::vector<std::string>& words, const std::vector<std::string>& label) = 0;
		virtual std::unique_ptr<DocumentBase> makeDoc(const std::vector<std::string>& words, const std::vector<std::string>& label) const = 0;

		// Raw-text input: `rawStr` is passed together with a tokenizer factory.
		virtual size_t addDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer, 
			const std::vector<std::string>& label) = 0;
		virtual std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
			const std::vector<std::string>& label) const = 0;

		// Raw text with already-resolved tokens: `words` holds vocabulary ids,
		// with `pos` and `len` alongside (presumably each token's offset and
		// length within rawStr, parallel to `words`; TODO confirm).
		virtual size_t addDoc(const std::string& rawStr, const std::vector<Vid>& words,
			const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len, 
			const std::vector<std::string>& label) = 0;
		virtual std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const std::vector<Vid>& words,
			const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
			const std::vector<std::string>& label) const = 0;

		// Read-only access to the dictionary of topic labels. The reference is
		// to model-held state; its lifetime relative to the model is not shown
		// here.
		virtual const Dictionary& getTopicLabelDict() const = 0;

		// Number of topics per label, as reported by the implementation.
		virtual size_t getNumTopicsPerLabel() const = 0;
	};
}

Version data entries

3 entries across 3 versions & 1 rubygem

Version Path
tomoto-0.1.2 vendor/tomotopy/src/TopicModel/LLDA.h
tomoto-0.1.1 vendor/tomotopy/src/TopicModel/LLDA.h
tomoto-0.1.0 vendor/tomotopy/src/TopicModel/LLDA.h