// Copyright (C) 2013 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_FIND_K_NEAREST_NEIGHBOrS_LSH_Hh_ #define DLIB_FIND_K_NEAREST_NEIGHBOrS_LSH_Hh_ #include "find_k_nearest_neighbors_lsh_abstract.h" #include "../threads.h" #include "../lsh/hashes.h" #include <vector> #include <queue> #include "sample_pair.h" #include "edge_list_graphs.h" namespace dlib { // ---------------------------------------------------------------------------------------- namespace impl { struct compare_sample_pair_with_distance { inline bool operator() (const sample_pair& a, const sample_pair& b) const { return a.distance() < b.distance(); } }; template < typename vector_type, typename hash_function_type > class hash_block { public: hash_block( const vector_type& samples_, const hash_function_type& hash_funct_, std::vector<typename hash_function_type::result_type>& hashes_ ) : samples(samples_), hash_funct(hash_funct_), hashes(hashes_) {} void operator() (long i) const { hashes[i] = hash_funct(samples[i]); } const vector_type& samples; const hash_function_type& hash_funct; std::vector<typename hash_function_type::result_type>& hashes; }; template < typename vector_type, typename distance_function_type, typename hash_function_type, typename alloc > class scan_find_k_nearest_neighbors_lsh { public: scan_find_k_nearest_neighbors_lsh ( const vector_type& samples_, const distance_function_type& dist_funct_, const hash_function_type& hash_funct_, const unsigned long k_, std::vector<sample_pair, alloc>& edges_, const unsigned long k_oversample_, const std::vector<typename hash_function_type::result_type>& hashes_ ) : samples(samples_), dist_funct(dist_funct_), hash_funct(hash_funct_), k(k_), edges(edges_), k_oversample(k_oversample_), hashes(hashes_) { edges.clear(); edges.reserve(samples.size()*k/2); } mutex m; const vector_type& samples; const distance_function_type& dist_funct; const hash_function_type& hash_funct; const unsigned long k; std::vector<sample_pair, alloc>& edges; const unsigned long k_oversample; const std::vector<typename hash_function_type::result_type>& hashes; void operator() (unsigned long i) const { const unsigned long k_hash = k*k_oversample; std::priority_queue<std::pair<unsigned long, unsigned long> > best_hashes; std::priority_queue<sample_pair, std::vector<sample_pair>, dlib::impl::compare_sample_pair_with_distance> best_samples; unsigned long worst_distance = std::numeric_limits<unsigned long>::max(); // scan over the hashes and find the best matches for hashes[i] for (unsigned long j = 0; j < hashes.size(); ++j) { if (i == j) continue; const unsigned long dist = hash_funct.distance(hashes[i], hashes[j]); if (dist < worst_distance || best_hashes.size() < k_hash) { if (best_hashes.size() >= k_hash) best_hashes.pop(); best_hashes.push(std::make_pair(dist, j)); worst_distance = best_hashes.top().first; } } // Now figure out which of the best_hashes are actually the k best matches // according to dist_funct() while (best_hashes.size() != 0) { const unsigned long j = best_hashes.top().second; best_hashes.pop(); const double dist = dist_funct(samples[i], samples[j]); if (dist < std::numeric_limits<double>::infinity()) { if (best_samples.size() >= k) best_samples.pop(); best_samples.push(sample_pair(i,j,dist)); } } // Finally, now put the k best matches according to dist_funct() into edges auto_mutex lock(m); while (best_samples.size() != 0) { edges.push_back(best_samples.top()); best_samples.pop(); } } }; } // ---------------------------------------------------------------------------------------- template < typename vector_type, typename hash_function_type > void hash_samples ( const vector_type& samples, const hash_function_type& hash_funct, const unsigned long num_threads, std::vector<typename hash_function_type::result_type>& hashes ) { hashes.resize(samples.size()); typedef impl::hash_block<vector_type,hash_function_type> block_type; block_type temp(samples, hash_funct, hashes); parallel_for(num_threads, 0, samples.size(), temp); } // ---------------------------------------------------------------------------------------- template < typename vector_type, typename distance_function_type, typename hash_function_type, typename alloc > void find_k_nearest_neighbors_lsh ( const vector_type& samples, const distance_function_type& dist_funct, const hash_function_type& hash_funct, const unsigned long k, const unsigned long num_threads, std::vector<sample_pair, alloc>& edges, const unsigned long k_oversample = 20 ) { // make sure requires clause is not broken DLIB_ASSERT(k > 0 && k_oversample > 0, "\t void find_k_nearest_neighbors_lsh()" << "\n\t Invalid inputs were given to this function." << "\n\t samples.size(): " << samples.size() << "\n\t k: " << k << "\n\t k_oversample: " << k_oversample ); edges.clear(); if (samples.size() <= 1) { return; } typedef typename hash_function_type::result_type hash_type; std::vector<hash_type> hashes; hash_samples(samples, hash_funct, num_threads, hashes); typedef impl::scan_find_k_nearest_neighbors_lsh<vector_type, distance_function_type,hash_function_type,alloc> scan_type; scan_type temp(samples, dist_funct, hash_funct, k, edges, k_oversample, hashes); parallel_for(num_threads, 0, hashes.size(), temp); remove_duplicate_edges(edges); } // ---------------------------------------------------------------------------------------- } #endif // DLIB_FIND_K_NEAREST_NEIGHBOrS_LSH_Hh_