vendor/tomotopy/src/TopicModel/PAModel.hpp in tomoto-0.1.3 vs vendor/tomotopy/src/TopicModel/PAModel.hpp in tomoto-0.1.4

- old
+ new

@@ -142,11 +142,11 @@ std::vector<std::future<void>> res = pool.enqueueToAll([&](size_t partitionId) { size_t b = partitionId ? edd.vChunkOffset[partitionId - 1] : 0, e = edd.vChunkOffset[partitionId]; - localData[partitionId].numByTopicWord = globalState.numByTopicWord.block(0, b, globalState.numByTopicWord.rows(), e - b); + localData[partitionId].numByTopicWord.matrix() = globalState.numByTopicWord.block(0, b, globalState.numByTopicWord.rows(), e - b); localData[partitionId].numByTopic = globalState.numByTopic; localData[partitionId].numByTopic1_2 = globalState.numByTopic1_2; localData[partitionId].numByTopic2 = globalState.numByTopic2; if (!localData[partitionId].zLikelihood.size()) localData[partitionId].zLikelihood = globalState.zLikelihood; }); @@ -155,12 +155,10 @@ } template<ParallelScheme _ps, typename _ExtraDocData> void mergeState(ThreadPool& pool, _ModelState& globalState, _ModelState& tState, _ModelState* localData, _RandGen*, const _ExtraDocData& edd) const { - std::vector<std::future<void>> res; - if (_ps == ParallelScheme::copy_merge) { tState = globalState; globalState = localData[0]; for (size_t i = 1; i < pool.getNumWorkers(); ++i) @@ -175,31 +173,23 @@ if (_tw != TermWeight::one) { globalState.numByTopic = globalState.numByTopic.cwiseMax(0); globalState.numByTopic1_2 = globalState.numByTopic1_2.cwiseMax(0); globalState.numByTopic2 = globalState.numByTopic2.cwiseMax(0); - globalState.numByTopicWord = globalState.numByTopicWord.cwiseMax(0); + globalState.numByTopicWord.matrix() = globalState.numByTopicWord.cwiseMax(0); } - - for (size_t i = 0; i < pool.getNumWorkers(); ++i) - { - res.emplace_back(pool.enqueue([&, this, i](size_t threadId) - { - localData[i] = globalState; - })); - } } else if (_ps == ParallelScheme::partition) { + std::vector<std::future<void>> res; res = pool.enqueueToAll([&](size_t partitionId) { size_t b = partitionId ? edd.vChunkOffset[partitionId - 1] : 0, e = edd.vChunkOffset[partitionId]; globalState.numByTopicWord.block(0, b, globalState.numByTopicWord.rows(), e - b) = localData[partitionId].numByTopicWord; }); for (auto& r : res) r.get(); - res.clear(); tState.numByTopic1_2 = globalState.numByTopic1_2; globalState.numByTopic1_2 = localData[0].numByTopic1_2; for (size_t i = 1; i < pool.getNumWorkers(); ++i) { @@ -207,23 +197,42 @@ } // make all count being positive if (_tw != TermWeight::one) { - globalState.numByTopicWord = globalState.numByTopicWord.cwiseMax(0); + globalState.numByTopicWord.matrix() = globalState.numByTopicWord.cwiseMax(0); } globalState.numByTopic = globalState.numByTopic1_2.rowwise().sum(); globalState.numByTopic2 = globalState.numByTopicWord.rowwise().sum(); + } + } + + + template<ParallelScheme _ps> + void distributeMergedState(ThreadPool& pool, _ModelState& globalState, _ModelState* localData) const + { + std::vector<std::future<void>> res; + if (_ps == ParallelScheme::copy_merge) + { + for (size_t i = 0; i < pool.getNumWorkers(); ++i) + { + res.emplace_back(pool.enqueue([&, i](size_t) + { + localData[i] = globalState; + })); + } + } + else if (_ps == ParallelScheme::partition) + { res = pool.enqueueToAll([&](size_t threadId) { localData[threadId].numByTopic = globalState.numByTopic; localData[threadId].numByTopic1_2 = globalState.numByTopic1_2; localData[threadId].numByTopic2 = globalState.numByTopic2; }); } - for (auto& r : res) r.get(); } template<typename _DocIter> double getLLDocs(_DocIter _first, _DocIter _last) const @@ -302,10 +311,11 @@ if (initDocs) { this->globalState.numByTopic = Eigen::Matrix<WeightType, -1, 1>::Zero(this->K); this->globalState.numByTopic2 = Eigen::Matrix<WeightType, -1, 1>::Zero(K2); this->globalState.numByTopic1_2 = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K, K2); - this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(K2, V); + //this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(K2, V); + this->globalState.numByTopicWord.init(nullptr, K2, V); } } struct Generator { \ No newline at end of file