vendor/kmeans/MiniBatch.hpp in umappp-0.1.6 vs vendor/kmeans/MiniBatch.hpp in umappp-0.2.0
- old
+ new
@@ -51,33 +51,38 @@
/**
* @brief Default parameter values for `MiniBatch`.
*/
struct Defaults {
/**
- * See `MiniBatch::set_max_iterations()`.
+ * Default maximum number of iterations; see `set_max_iterations()` for more details.
*/
static constexpr int max_iterations = 100;
/**
- * See `MiniBatch::set_batch_size()`.
+ * Default batch size; see `set_batch_size()` for more details.
*/
static constexpr INDEX_t batch_size = 500;
/**
- * See `MiniBatch::set_max_change_proportion()`.
+ * Default maximum change proportion; see `set_max_change_proportion()` for more details.
*/
static constexpr double max_change_proportion = 0.01;
/**
- * See `MiniBatch::set_convergence_history()`.
+ * Default value of the convergence history setting; see `set_convergence_history()` for more details.
*/
static constexpr int convergence_history = 10;
/**
- * See `MiniBatch::set_seed()`.
+ * Default seed value; see `set_seed()` for more details.
*/
static constexpr uint64_t seed = 1234567890;
+
+ /**
+ * Default number of threads; see `set_num_threads()` for more details.
+ */
+ static constexpr int num_threads = 1;
};
private:
int maxiter = Defaults::max_iterations;
@@ -86,10 +91,12 @@
int history = Defaults::convergence_history;
double max_change = Defaults::max_change_proportion;
uint64_t seed = Defaults::seed;
+
+ int nthreads = Defaults::num_threads;
public:
/**
* @param i Maximum number of iterations.
* More iterations increase the opportunity for convergence at the cost of more computational time.
*
@@ -141,10 +148,20 @@
MiniBatch& set_seed(uint64_t s = Defaults::seed) {
seed = s;
return *this;
}
+ /**
+ * @param n Number of threads to use when assigning observations to clusters, via the OpenMP `num_threads` clause or `KMEANS_CUSTOM_PARALLEL` if that macro is defined; stored as-is without validation.
+ *
+ * @return A reference to this `MiniBatch` object.
+ */
+ MiniBatch& set_num_threads(int n = Defaults::num_threads) {
+ nthreads = n;
+ return *this;
+ }
+
public:
/**
* @param ndim Number of dimensions.
* @param nobs Number of observations.
* @param[in] data Pointer to a `ndim`-by-`nobs` array where columns are observations and rows are dimensions.
@@ -181,14 +198,26 @@
previous[o] = clusters[o];
}
}
QuickSearch<DATA_t, CLUSTER_t> index(ndim, ncenters, centers);
- #pragma omp parallel for
- for (size_t i = 0; i < chosen.size(); ++i) {
+ size_t nchosen = chosen.size();
+
+#ifndef KMEANS_CUSTOM_PARALLEL
+ #pragma omp parallel for num_threads(nthreads)
+ for (size_t i = 0; i < nchosen; ++i) {
+#else
+ KMEANS_CUSTOM_PARALLEL(nchosen, [&](size_t first, size_t last) -> void {
+ for (size_t i = first; i < last; ++i) {
+#endif
clusters[chosen[i]] = index.find(data + chosen[i] * ndim);
+#ifndef KMEANS_CUSTOM_PARALLEL
}
+#else
+ }
+ }, nthreads);
+#endif
// Updating the means for each cluster.
for (auto o : chosen) {
const auto c = clusters[o];
auto& n = total_sampled[c];
@@ -234,13 +263,24 @@
status = 2;
}
// Run through all observations to make sure they have the latest cluster assignments.
QuickSearch<DATA_t, CLUSTER_t> index(ndim, ncenters, centers);
- #pragma omp parallel for
+
+#ifndef KMEANS_CUSTOM_PARALLEL
+ #pragma omp parallel for num_threads(nthreads)
for (INDEX_t o = 0; o < nobs; ++o) {
+#else
+ KMEANS_CUSTOM_PARALLEL(nobs, [&](INDEX_t first, INDEX_t last) -> void {
+ for (INDEX_t o = first; o < last; ++o) {
+#endif
clusters[o] = index.find(data + o * ndim);
+#ifndef KMEANS_CUSTOM_PARALLEL
}
+#else
+ }
+ }, nthreads);
+#endif
std::fill(total_sampled.begin(), total_sampled.end(), 0);
for (INDEX_t o = 0; o < nobs; ++o) {
++total_sampled[clusters[o]];
}