vendor/kmeans/InitializeKmeansPP.hpp in umappp-0.1.6 vs vendor/kmeans/InitializeKmeansPP.hpp in umappp-0.2.0
- old
+ new
@@ -42,10 +42,15 @@
struct Defaults {
/**
* See `set_seed()` for more details.
*/
static constexpr uint64_t seed = 6523u;
+
+ /**
+ * See `set_num_threads()` for more details.
+ */
+ static constexpr int num_threads = 1;
};
/**
* @param Random seed to use to construct the PRNG prior to sampling.
*
@@ -53,12 +58,24 @@
*/
InitializeKmeansPP& set_seed(uint64_t s = Defaults::seed) {
seed = s;
return *this;
}
+
+ /**
+ * @param n Number of threads to use.
+ *
+ * @return A reference to this `InitializeKmeansPP` object.
+ */
+ InitializeKmeansPP& set_num_threads(int n = Defaults::num_threads) {
+ nthreads = n;
+ return *this;
+ }
+
private:
uint64_t seed = Defaults::seed;
+ int nthreads = Defaults::num_threads;
public:
/**
* @cond
*/
@@ -72,12 +89,17 @@
for (CLUSTER_t cen = 0; cen < ncenters; ++cen) {
INDEX_t counter = 0;
if (!sofar.empty()) {
auto last = sofar.back();
- #pragma omp parallel for
+#ifndef KMEANS_CUSTOM_PARALLEL
+ #pragma omp parallel for num_threads(nthreads)
for (INDEX_t obs = 0; obs < nobs; ++obs) {
+#else
+ KMEANS_CUSTOM_PARALLEL(nobs, [&](INDEX_t first, INDEX_t end) -> void {
+ for (INDEX_t obs = first; obs < end; ++obs) {
+#endif
if (mindist[obs]) {
const DATA_t* acopy = data + obs * ndim;
const DATA_t* scopy = data + last * ndim;
DATA_t r2 = 0;
for (int dim = 0; dim < ndim; ++dim, ++acopy, ++scopy) {
@@ -86,10 +108,16 @@
if (cen == 1 || r2 < mindist[obs]) {
mindist[obs] = r2;
}
}
+#ifndef KMEANS_CUSTOM_PARALLEL
}
+#else
+ }
+ }, nthreads);
+#endif
+
} else {
counter = nobs;
}
cumulative[0] = mindist[0];