vendor/kmeans/InitializeKmeansPP.hpp in umappp-0.1.6 vs vendor/kmeans/InitializeKmeansPP.hpp in umappp-0.2.0

- old
+ new

@@ -42,10 +42,15 @@ struct Defaults { /** * See `set_seed()` for more details. */ static constexpr uint64_t seed = 6523u; + + /** + * See `set_num_threads()` for more details. + */ + static constexpr int num_threads = 1; }; /** * @param Random seed to use to construct the PRNG prior to sampling. * @@ -53,12 +58,24 @@ */ InitializeKmeansPP& set_seed(uint64_t s = Defaults::seed) { seed = s; return *this; } + + /** + * @param n Number of threads to use. + * + * @return A reference to this `InitializeKmeansPP` object. + */ + InitializeKmeansPP& set_num_threads(int n = Defaults::num_threads) { + nthreads = n; + return *this; + } + private: uint64_t seed = Defaults::seed; + int nthreads = Defaults::num_threads; public: /** * @cond */ @@ -72,12 +89,17 @@ for (CLUSTER_t cen = 0; cen < ncenters; ++cen) { INDEX_t counter = 0; if (!sofar.empty()) { auto last = sofar.back(); - #pragma omp parallel for +#ifndef KMEANS_CUSTOM_PARALLEL + #pragma omp parallel for num_threads(nthreads) for (INDEX_t obs = 0; obs < nobs; ++obs) { +#else + KMEANS_CUSTOM_PARALLEL(nobs, [&](INDEX_t first, INDEX_t end) -> void { + for (INDEX_t obs = first; obs < end; ++obs) { +#endif if (mindist[obs]) { const DATA_t* acopy = data + obs * ndim; const DATA_t* scopy = data + last * ndim; DATA_t r2 = 0; for (int dim = 0; dim < ndim; ++dim, ++acopy, ++scopy) { @@ -86,10 +108,16 @@ if (cen == 1 || r2 < mindist[obs]) { mindist[obs] = r2; } } +#ifndef KMEANS_CUSTOM_PARALLEL } +#else + } + }, nthreads); +#endif + } else { counter = nobs; } cumulative[0] = mindist[0];