ext/hnswlib/src/hnswalg.h in hnswlib-0.8.1 vs ext/hnswlib/src/hnswalg.h in hnswlib-0.9.0

- old
+ new

@@ -6,10 +6,11 @@ #include <random> #include <stdlib.h> #include <assert.h> #include <unordered_set> #include <list> +#include <memory> namespace hnswlib { typedef unsigned int tableint; typedef unsigned int linklistsizeint; @@ -31,11 +32,11 @@ size_t ef_{ 0 }; double mult_{0.0}, revSize_{0.0}; int maxlevel_{0}; - VisitedListPool *visited_list_pool_{nullptr}; + std::unique_ptr<VisitedListPool> visited_list_pool_{nullptr}; // Locks operations with element by label value mutable std::vector<std::mutex> label_op_locks_; std::mutex global; @@ -91,20 +92,26 @@ size_t max_elements, size_t M = 16, size_t ef_construction = 200, size_t random_seed = 100, bool allow_replace_deleted = false) - : link_list_locks_(max_elements), - label_op_locks_(MAX_LABEL_OPERATION_LOCKS), + : label_op_locks_(MAX_LABEL_OPERATION_LOCKS), + link_list_locks_(max_elements), element_levels_(max_elements), allow_replace_deleted_(allow_replace_deleted) { max_elements_ = max_elements; num_deleted_ = 0; data_size_ = s->get_data_size(); fstdistfunc_ = s->get_dist_func(); dist_func_param_ = s->get_dist_func_param(); - M_ = M; + if ( M <= 10000 ) { + M_ = M; + } else { + HNSWERR << "warning: M parameter exceeds 10000 which may lead to adverse effects." << std::endl; + HNSWERR << " Cap to 10000 will be applied for the rest of the processing." << std::endl; + M_ = 10000; + } maxM_ = M_; maxM0_ = M_ * 2; ef_construction_ = std::max(ef_construction, M_); ef_ = 10; @@ -121,11 +128,11 @@ if (data_level0_memory_ == nullptr) throw std::runtime_error("Not enough memory"); cur_element_count = 0; - visited_list_pool_ = new VisitedListPool(1, max_elements); + visited_list_pool_ = std::unique_ptr<VisitedListPool>(new VisitedListPool(1, max_elements)); // initializations for special treatment of the first node enterpoint_node_ = -1; maxlevel_ = -1; @@ -137,17 +144,24 @@ revSize_ = 1.0 / mult_; } ~HierarchicalNSW() { + clear(); + } + + void clear() { free(data_level0_memory_); + data_level0_memory_ = nullptr; for (tableint i = 0; i < cur_element_count; i++) { if (element_levels_[i] > 0) free(linkLists_[i]); } free(linkLists_); - delete visited_list_pool_; + linkLists_ = nullptr; + cur_element_count = 0; + visited_list_pool_.reset(nullptr); } struct CompareByFirst { constexpr bool operator()(std::pair<dist_t, tableint> const& a, @@ -290,38 +304,59 @@ return top_candidates; } - template <bool has_deletions, bool collect_metrics = false> + // bare_bone_search means there is no check for deletions and stop condition is ignored in return of extra performance + template <bool bare_bone_search = true, bool collect_metrics = false> std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> - searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, BaseFilterFunctor* isIdAllowed = nullptr) const { + searchBaseLayerST( + tableint ep_id, + const void *data_point, + size_t ef, + BaseFilterFunctor* isIdAllowed = nullptr, + BaseSearchStopCondition<dist_t>* stop_condition = nullptr) const { VisitedList *vl = visited_list_pool_->getFreeVisitedList(); vl_type *visited_array = vl->mass; vl_type visited_array_tag = vl->curV; std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates; std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set; dist_t lowerBound; - if ((!has_deletions || !isMarkedDeleted(ep_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id)))) { - dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); + if (bare_bone_search || + (!isMarkedDeleted(ep_id) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id))))) { + char* ep_data = getDataByInternalId(ep_id); + dist_t dist = fstdistfunc_(data_point, ep_data, dist_func_param_); lowerBound = dist; top_candidates.emplace(dist, ep_id); + if (!bare_bone_search && stop_condition) { + stop_condition->add_point_to_result(getExternalLabel(ep_id), ep_data, dist); + } candidate_set.emplace(-dist, ep_id); } else { lowerBound = std::numeric_limits<dist_t>::max(); candidate_set.emplace(-lowerBound, ep_id); } visited_array[ep_id] = visited_array_tag; while (!candidate_set.empty()) { std::pair<dist_t, tableint> current_node_pair = candidate_set.top(); + dist_t candidate_dist = -current_node_pair.first; - if ((-current_node_pair.first) > lowerBound && - (top_candidates.size() == ef || (!isIdAllowed && !has_deletions))) { + bool flag_stop_search; + if (bare_bone_search) { + flag_stop_search = candidate_dist > lowerBound; + } else { + if (stop_condition) { + flag_stop_search = stop_condition->should_stop_search(candidate_dist, lowerBound); + } else { + flag_stop_search = candidate_dist > lowerBound && top_candidates.size() == ef; + } + } + if (flag_stop_search) { break; } candidate_set.pop(); tableint current_node_id = current_node_pair.second; @@ -352,23 +387,49 @@ visited_array[candidate_id] = visited_array_tag; char *currObj1 = (getDataByInternalId(candidate_id)); dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_); - if (top_candidates.size() < ef || lowerBound > dist) { + bool flag_consider_candidate; + if (!bare_bone_search && stop_condition) { + flag_consider_candidate = stop_condition->should_consider_candidate(dist, lowerBound); + } else { + flag_consider_candidate = top_candidates.size() < ef || lowerBound > dist; + } + + if (flag_consider_candidate) { candidate_set.emplace(-dist, candidate_id); #ifdef USE_SSE _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ + offsetLevel0_, /////////// _MM_HINT_T0); //////////////////////// #endif - if ((!has_deletions || !isMarkedDeleted(candidate_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id)))) + if (bare_bone_search || + (!isMarkedDeleted(candidate_id) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id))))) { top_candidates.emplace(dist, candidate_id); + if (!bare_bone_search && stop_condition) { + stop_condition->add_point_to_result(getExternalLabel(candidate_id), currObj1, dist); + } + } - if (top_candidates.size() > ef) + bool flag_remove_extra = false; + if (!bare_bone_search && stop_condition) { + flag_remove_extra = stop_condition->should_remove_extra(); + } else { + flag_remove_extra = top_candidates.size() > ef; + } + while (flag_remove_extra) { + tableint id = top_candidates.top().second; top_candidates.pop(); + if (!bare_bone_search && stop_condition) { + stop_condition->remove_point_from_result(getExternalLabel(id), getDataByInternalId(id), dist); + flag_remove_extra = stop_condition->should_remove_extra(); + } else { + flag_remove_extra = top_candidates.size() > ef; + } + } if (!top_candidates.empty()) lowerBound = top_candidates.top().first; } } @@ -379,12 +440,12 @@ return top_candidates; } void getNeighborsByHeuristic2( - std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates, - const size_t M) { + std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates, + const size_t M) { if (top_candidates.size() < M) { return; } std::priority_queue<std::pair<dist_t, tableint>> queue_closest; @@ -572,12 +633,11 @@ void resizeIndex(size_t new_max_elements) { if (new_max_elements < cur_element_count) throw std::runtime_error("Cannot resize, max element is less than the current number of elements"); - delete visited_list_pool_; - visited_list_pool_ = new VisitedListPool(1, new_max_elements); + visited_list_pool_.reset(new VisitedListPool(1, new_max_elements)); element_levels_.resize(new_max_elements); std::vector<std::mutex>(new_max_elements).swap(link_list_locks_); @@ -594,11 +654,37 @@ linkLists_ = linkLists_new; max_elements_ = new_max_elements; } + size_t indexFileSize() const { + size_t size = 0; + size += sizeof(offsetLevel0_); + size += sizeof(max_elements_); + size += sizeof(cur_element_count); + size += sizeof(size_data_per_element_); + size += sizeof(label_offset_); + size += sizeof(offsetData_); + size += sizeof(maxlevel_); + size += sizeof(enterpoint_node_); + size += sizeof(maxM_); + size += sizeof(maxM0_); + size += sizeof(M_); + size += sizeof(mult_); + size += sizeof(ef_construction_); + + size += cur_element_count * size_data_per_element_; + + for (size_t i = 0; i < cur_element_count; i++) { + unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0; + size += sizeof(linkListSize); + size += linkListSize; + } + return size; + } + void saveIndex(const std::string &location) { std::ofstream output(location, std::ios::binary); std::streampos position; writeBinaryPOD(output, offsetLevel0_); @@ -632,10 +718,11 @@ std::ifstream input(location, std::ios::binary); if (!input.is_open()) throw std::runtime_error("Cannot open file"); + clear(); // get file size: input.seekg(0, input.end); std::streampos total_filesize = input.tellg(); input.seekg(0, input.beg); @@ -697,11 +784,11 @@ size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); std::vector<std::mutex>(max_elements).swap(link_list_locks_); std::vector<std::mutex>(MAX_LABEL_OPERATION_LOCKS).swap(label_op_locks_); - visited_list_pool_ = new VisitedListPool(1, max_elements); + visited_list_pool_.reset(new VisitedListPool(1, max_elements)); linkLists_ = (char **) malloc(sizeof(void *) * max_elements); if (linkLists_ == nullptr) throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists"); element_levels_ = std::vector<int>(max_elements); @@ -751,11 +838,11 @@ char* data_ptrv = getDataByInternalId(internalId); size_t dim = *((size_t *) dist_func_param_); std::vector<data_t> data; data_t* data_ptr = (data_t*) data_ptrv; - for (int i = 0; i < dim; i++) { + for (size_t i = 0; i < dim; i++) { data.push_back(*data_ptr); data_ptr += 1; } return data; } @@ -1215,15 +1302,16 @@ } } } std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates; - if (num_deleted_) { - top_candidates = searchBaseLayerST<true, true>( + bool bare_bone_search = !num_deleted_ && !isIdAllowed; + if (bare_bone_search) { + top_candidates = searchBaseLayerST<true>( currObj, query_data, std::max(ef_, k), isIdAllowed); } else { - top_candidates = searchBaseLayerST<false, true>( + top_candidates = searchBaseLayerST<false>( currObj, query_data, std::max(ef_, k), isIdAllowed); } while (top_candidates.size() > k) { top_candidates.pop(); @@ -1235,20 +1323,73 @@ } return result; } + std::vector<std::pair<dist_t, labeltype >> + searchStopConditionClosest( + const void *query_data, + BaseSearchStopCondition<dist_t>& stop_condition, + BaseFilterFunctor* isIdAllowed = nullptr) const { + std::vector<std::pair<dist_t, labeltype >> result; + if (cur_element_count == 0) return result; + + tableint currObj = enterpoint_node_; + dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_); + + for (int level = maxlevel_; level > 0; level--) { + bool changed = true; + while (changed) { + changed = false; + unsigned int *data; + + data = (unsigned int *) get_linklist(currObj, level); + int size = getListCount(data); + metric_hops++; + metric_distance_computations+=size; + + tableint *datal = (tableint *) (data + 1); + for (int i = 0; i < size; i++) { + tableint cand = datal[i]; + if (cand < 0 || cand > max_elements_) + throw std::runtime_error("cand error"); + dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_); + + if (d < curdist) { + curdist = d; + currObj = cand; + changed = true; + } + } + } + } + + std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates; + top_candidates = searchBaseLayerST<false>(currObj, query_data, 0, isIdAllowed, &stop_condition); + + size_t sz = top_candidates.size(); + result.resize(sz); + while (!top_candidates.empty()) { + result[--sz] = top_candidates.top(); + top_candidates.pop(); + } + + stop_condition.filter_results(result); + + return result; + } + + void checkIntegrity() { int connections_checked = 0; std::vector <int > inbound_connections_num(cur_element_count, 0); for (int i = 0; i < cur_element_count; i++) { for (int l = 0; l <= element_levels_[i]; l++) { linklistsizeint *ll_cur = get_linklist_at_level(i, l); int size = getListCount(ll_cur); tableint *data = (tableint *) (ll_cur + 1); std::unordered_set<tableint> s; for (int j = 0; j < size; j++) { - assert(data[j] > 0); assert(data[j] < cur_element_count); assert(data[j] != i); inbound_connections_num[data[j]]++; s.insert(data[j]); connections_checked++;