vendor/knncolle/utils/find_nearest_neighbors.hpp in umappp-0.1.6 vs vendor/knncolle/utils/find_nearest_neighbors.hpp in umappp-0.2.0
- old
+ new
@@ -34,21 +34,22 @@
* @tparam InputDISTANCE_t Floating point type for the distances in the input index.
* @tparam QUERY_t Floating point type for the query data in the input index.
*
* @param ptr Pointer to a `Base` index.
* @param k Number of nearest neighbors.
+ * @param nthreads Number of threads to use.
*
* @return A `NeighborList` of length equal to the number of observations in `ptr->nobs()`.
* Each entry contains the `k` nearest neighbors for each observation, sorted by increasing distance.
*/
template<typename INDEX_t = int, typename DISTANCE_t = double, typename InputINDEX_t, typename InputDISTANCE_t, typename InputQUERY_t>
-NeighborList<INDEX_t, DISTANCE_t> find_nearest_neighbors(const Base<InputINDEX_t, InputDISTANCE_t, InputQUERY_t>* ptr, int k) {
+NeighborList<INDEX_t, DISTANCE_t> find_nearest_neighbors(const Base<InputINDEX_t, InputDISTANCE_t, InputQUERY_t>* ptr, int k, int nthreads) {
auto n = ptr->nobs();
NeighborList<INDEX_t, DISTANCE_t> output(n);
#ifndef KNNCOLLE_CUSTOM_PARALLEL
- #pragma omp parallel for
+ #pragma omp parallel for num_threads(nthreads)
for (size_t i = 0; i < n; ++i) {
#else
KNNCOLLE_CUSTOM_PARALLEL(n, [&](size_t first, size_t last) -> void {
for (size_t i = first; i < last; ++i) {
#endif
@@ -60,11 +61,11 @@
output[i].emplace_back(x.first, x.second);
}
}
}
#ifdef KNNCOLLE_CUSTOM_PARALLEL
- });
+ }, nthreads);
#endif
return output;
}
@@ -77,21 +78,22 @@
* @tparam InputDISTANCE_t Floating point type for the distances in the input index.
* @tparam QUERY_t Floating point type for the query data in the input index.
*
* @param ptr Pointer to a `Base` index.
* @param k Number of nearest neighbors.
+ * @param nthreads Number of threads to use.
*
* @return A vector of vectors of length equal to the number of observations in `ptr->nobs()`.
* Each vector contains the indices of the `k` nearest neighbors for each observation, sorted by increasing distance.
*/
template<typename INDEX_t = int, typename InputINDEX_t, typename InputDISTANCE_t, typename InputQUERY_t>
-std::vector<std::vector<INDEX_t> > find_nearest_neighbors_index_only(const Base<InputINDEX_t, InputDISTANCE_t, InputQUERY_t>* ptr, int k) {
+std::vector<std::vector<INDEX_t> > find_nearest_neighbors_index_only(const Base<InputINDEX_t, InputDISTANCE_t, InputQUERY_t>* ptr, int k, int nthreads) {
auto n = ptr->nobs();
std::vector<std::vector<INDEX_t> > output(n);
#ifndef KNNCOLLE_CUSTOM_PARALLEL
- #pragma omp parallel for
+ #pragma omp parallel for num_threads(nthreads)
for (size_t i = 0; i < n; ++i) {
#else
KNNCOLLE_CUSTOM_PARALLEL(n, [&](size_t first, size_t last) -> void {
for (size_t i = first; i < last; ++i) {
#endif
@@ -99,10 +101,10 @@
for (const auto& x : current) {
output[i].push_back(x.first);
}
}
#ifdef KNNCOLLE_CUSTOM_PARALLEL
- });
+ }, nthreads);
#endif
return output;
}