Unified diff of vendor/tomotopy/src/TopicModel/PAModel.hpp between tomoto 0.1.3 and tomoto 0.1.4.
- old (lines removed in 0.1.3)
+ new (lines added in 0.1.4)
@@ -142,11 +142,11 @@
std::vector<std::future<void>> res = pool.enqueueToAll([&](size_t partitionId)
{
size_t b = partitionId ? edd.vChunkOffset[partitionId - 1] : 0,
e = edd.vChunkOffset[partitionId];
- localData[partitionId].numByTopicWord = globalState.numByTopicWord.block(0, b, globalState.numByTopicWord.rows(), e - b);
+ localData[partitionId].numByTopicWord.matrix() = globalState.numByTopicWord.block(0, b, globalState.numByTopicWord.rows(), e - b);
localData[partitionId].numByTopic = globalState.numByTopic;
localData[partitionId].numByTopic1_2 = globalState.numByTopic1_2;
localData[partitionId].numByTopic2 = globalState.numByTopic2;
if (!localData[partitionId].zLikelihood.size()) localData[partitionId].zLikelihood = globalState.zLikelihood;
});
@@ -155,12 +155,10 @@
}
template<ParallelScheme _ps, typename _ExtraDocData>
void mergeState(ThreadPool& pool, _ModelState& globalState, _ModelState& tState, _ModelState* localData, _RandGen*, const _ExtraDocData& edd) const
{
- std::vector<std::future<void>> res;
-
if (_ps == ParallelScheme::copy_merge)
{
tState = globalState;
globalState = localData[0];
for (size_t i = 1; i < pool.getNumWorkers(); ++i)
@@ -175,31 +173,23 @@
if (_tw != TermWeight::one)
{
globalState.numByTopic = globalState.numByTopic.cwiseMax(0);
globalState.numByTopic1_2 = globalState.numByTopic1_2.cwiseMax(0);
globalState.numByTopic2 = globalState.numByTopic2.cwiseMax(0);
- globalState.numByTopicWord = globalState.numByTopicWord.cwiseMax(0);
+ globalState.numByTopicWord.matrix() = globalState.numByTopicWord.cwiseMax(0);
}
-
- for (size_t i = 0; i < pool.getNumWorkers(); ++i)
- {
- res.emplace_back(pool.enqueue([&, this, i](size_t threadId)
- {
- localData[i] = globalState;
- }));
- }
}
else if (_ps == ParallelScheme::partition)
{
+ std::vector<std::future<void>> res;
res = pool.enqueueToAll([&](size_t partitionId)
{
size_t b = partitionId ? edd.vChunkOffset[partitionId - 1] : 0,
e = edd.vChunkOffset[partitionId];
globalState.numByTopicWord.block(0, b, globalState.numByTopicWord.rows(), e - b) = localData[partitionId].numByTopicWord;
});
for (auto& r : res) r.get();
- res.clear();
tState.numByTopic1_2 = globalState.numByTopic1_2;
globalState.numByTopic1_2 = localData[0].numByTopic1_2;
for (size_t i = 1; i < pool.getNumWorkers(); ++i)
{
@@ -207,23 +197,42 @@
}
// make all count being positive
if (_tw != TermWeight::one)
{
- globalState.numByTopicWord = globalState.numByTopicWord.cwiseMax(0);
+ globalState.numByTopicWord.matrix() = globalState.numByTopicWord.cwiseMax(0);
}
globalState.numByTopic = globalState.numByTopic1_2.rowwise().sum();
globalState.numByTopic2 = globalState.numByTopicWord.rowwise().sum();
+ }
+ }
+
+
+ template<ParallelScheme _ps>
+ void distributeMergedState(ThreadPool& pool, _ModelState& globalState, _ModelState* localData) const
+ {
+ std::vector<std::future<void>> res;
+ if (_ps == ParallelScheme::copy_merge)
+ {
+ for (size_t i = 0; i < pool.getNumWorkers(); ++i)
+ {
+ res.emplace_back(pool.enqueue([&, i](size_t)
+ {
+ localData[i] = globalState;
+ }));
+ }
+ }
+ else if (_ps == ParallelScheme::partition)
+ {
res = pool.enqueueToAll([&](size_t threadId)
{
localData[threadId].numByTopic = globalState.numByTopic;
localData[threadId].numByTopic1_2 = globalState.numByTopic1_2;
localData[threadId].numByTopic2 = globalState.numByTopic2;
});
}
-
for (auto& r : res) r.get();
}
template<typename _DocIter>
double getLLDocs(_DocIter _first, _DocIter _last) const
@@ -302,10 +311,11 @@
if (initDocs)
{
this->globalState.numByTopic = Eigen::Matrix<WeightType, -1, 1>::Zero(this->K);
this->globalState.numByTopic2 = Eigen::Matrix<WeightType, -1, 1>::Zero(K2);
this->globalState.numByTopic1_2 = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K, K2);
- this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(K2, V);
+ //this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(K2, V);
+ this->globalState.numByTopicWord.init(nullptr, K2, V);
}
}
struct Generator
{
\ No newline at end of file