vendor/kmeans/MiniBatch.hpp in umappp-0.1.6 vs vendor/kmeans/MiniBatch.hpp in umappp-0.2.0

- old
+ new

@@ -51,33 +51,38 @@ /** * @brief Default parameter values for `MiniBatch`. */ struct Defaults { /** - * See `MiniBatch::set_max_iterations()`. + * See `set_max_iterations()` for more details. */ static constexpr int max_iterations = 100; /** - * See `MiniBatch::set_batch_size()`. + * See `set_batch_size()` for more details. */ static constexpr INDEX_t batch_size = 500; /** - * See `MiniBatch::set_max_change_proportion()`. + * See `set_max_change_proportion()` for more details. */ static constexpr double max_change_proportion = 0.01; /** - * See `MiniBatch::set_convergence_history()`. + * See `set_convergence_history()` for more details. */ static constexpr int convergence_history = 10; /** - * See `MiniBatch::set_seed()`. + * See `set_seed()` for more details. */ static constexpr uint64_t seed = 1234567890; + + /** + * See `set_num_threads()` for more details. + */ + static constexpr int num_threads = 1; }; private: int maxiter = Defaults::max_iterations; @@ -86,10 +91,12 @@ int history = Defaults::convergence_history; double max_change = Defaults::max_change_proportion; uint64_t seed = Defaults::seed; + + int nthreads = Defaults::num_threads; public: /** * @param i Maximum number of iterations. * More iterations increase the opportunity for convergence at the cost of more computational time. * @@ -141,10 +148,20 @@ MiniBatch& set_seed(uint64_t s = Defaults::seed) { seed = s; return *this; } + /** + * @param n Number of threads to use. + * + * @return A reference to this `MiniBatch` object. + */ + MiniBatch& set_num_threads(int n = Defaults::num_threads) { + nthreads = n; + return *this; + } + public: /** * @param ndim Number of dimensions. * @param nobs Number of observations. * @param[in] data Pointer to a `ndim`-by-`nobs` array where columns are observations and rows are dimensions. @@ -181,14 +198,26 @@ previous[o] = clusters[o]; } } QuickSearch<DATA_t, CLUSTER_t> index(ndim, ncenters, centers); - #pragma omp parallel for - for (size_t i = 0; i < chosen.size(); ++i) { + size_t nchosen = chosen.size(); + +#ifndef KMEANS_CUSTOM_PARALLEL + #pragma omp parallel for num_threads(nthreads) + for (size_t i = 0; i < nchosen; ++i) { +#else + KMEANS_CUSTOM_PARALLEL(nchosen, [&](size_t first, size_t last) -> void { + for (size_t i = first; i < last; ++i) { +#endif clusters[chosen[i]] = index.find(data + chosen[i] * ndim); +#ifndef KMEANS_CUSTOM_PARALLEL } +#else + } + }, nthreads); +#endif // Updating the means for each cluster. for (auto o : chosen) { const auto c = clusters[o]; auto& n = total_sampled[c]; @@ -234,13 +263,24 @@ status = 2; } // Run through all observations to make sure they have the latest cluster assignments. QuickSearch<DATA_t, CLUSTER_t> index(ndim, ncenters, centers); - #pragma omp parallel for + +#ifndef KMEANS_CUSTOM_PARALLEL + #pragma omp parallel for num_threads(nthreads) for (INDEX_t o = 0; o < nobs; ++o) { +#else + KMEANS_CUSTOM_PARALLEL(nobs, [&](INDEX_t first, INDEX_t last) -> void { + for (INDEX_t o = first; o < last; ++o) { +#endif clusters[o] = index.find(data + o * ndim); +#ifndef KMEANS_CUSTOM_PARALLEL } +#else + } + }, nthreads); +#endif std::fill(total_sampled.begin(), total_sampled.end(), 0); for (INDEX_t o = 0; o < nobs; ++o) { ++total_sampled[clusters[o]]; }