ext/hnswlib/src/hnswalg.h in hnswlib-0.6.2 vs ext/hnswlib/src/hnswalg.h in hnswlib-0.7.0

- old
+ new

@@ -8,1201 +8,1265 @@ #include <assert.h> #include <unordered_set> #include <list> namespace hnswlib { - typedef unsigned int tableint; - typedef unsigned int linklistsizeint; +typedef unsigned int tableint; +typedef unsigned int linklistsizeint; - template<typename dist_t> - class HierarchicalNSW : public AlgorithmInterface<dist_t> { - public: - static const tableint max_update_element_locks = 65536; - HierarchicalNSW() : visited_list_pool_(nullptr), data_level0_memory_(nullptr), linkLists_(nullptr), cur_element_count(0) { } - HierarchicalNSW(SpaceInterface<dist_t> *s) { - } +template<typename dist_t> +class HierarchicalNSW : public AlgorithmInterface<dist_t> { + public: + static const tableint MAX_LABEL_OPERATION_LOCKS = 65536; + static const unsigned char DELETE_MARK = 0x01; - HierarchicalNSW(SpaceInterface<dist_t> *s, const std::string &location, bool nmslib = false, size_t max_elements=0) { - loadIndex(location, s, max_elements); - } + size_t max_elements_{0}; + mutable std::atomic<size_t> cur_element_count{0}; // current number of elements + size_t size_data_per_element_{0}; + size_t size_links_per_element_{0}; + mutable std::atomic<size_t> num_deleted_{0}; // number of deleted elements + size_t M_{0}; + size_t maxM_{0}; + size_t maxM0_{0}; + size_t ef_construction_{0}; + size_t ef_{ 0 }; - HierarchicalNSW(SpaceInterface<dist_t> *s, size_t max_elements, size_t M = 16, size_t ef_construction = 200, size_t random_seed = 100) : - link_list_locks_(max_elements), link_list_update_locks_(max_update_element_locks), element_levels_(max_elements) { - max_elements_ = max_elements; + double mult_{0.0}, revSize_{0.0}; + int maxlevel_{0}; - num_deleted_ = 0; - data_size_ = s->get_data_size(); - fstdistfunc_ = s->get_dist_func(); - dist_func_param_ = s->get_dist_func_param(); - M_ = M; - maxM_ = M_; - maxM0_ = M_ * 2; - ef_construction_ = std::max(ef_construction,M_); - ef_ = 10; + VisitedListPool *visited_list_pool_{nullptr}; - level_generator_.seed(random_seed); - update_probability_generator_.seed(random_seed + 1); + // Locks operations with element by label value + mutable std::vector<std::mutex> label_op_locks_; - size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); - size_data_per_element_ = size_links_level0_ + data_size_ + sizeof(labeltype); - offsetData_ = size_links_level0_; - label_offset_ = size_links_level0_ + data_size_; - offsetLevel0_ = 0; + std::mutex global; + std::vector<std::mutex> link_list_locks_; - data_level0_memory_ = (char *) malloc(max_elements_ * size_data_per_element_); - if (data_level0_memory_ == nullptr) - throw std::runtime_error("Not enough memory"); + tableint enterpoint_node_{0}; - cur_element_count = 0; + size_t size_links_level0_{0}; + size_t offsetData_{0}, offsetLevel0_{0}, label_offset_{ 0 }; - visited_list_pool_ = new VisitedListPool(1, max_elements); + char *data_level0_memory_{nullptr}; + char **linkLists_{nullptr}; + std::vector<int> element_levels_; // keeps level of each element - //initializations for special treatment of the first node - enterpoint_node_ = -1; - maxlevel_ = -1; + size_t data_size_{0}; - linkLists_ = (char **) malloc(sizeof(void *) * max_elements_); - if (linkLists_ == nullptr) - throw std::runtime_error("Not enough memory: HierarchicalNSW failed to allocate linklists"); - size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); - mult_ = 1 / log(1.0 * M_); - revSize_ = 1.0 / mult_; - } + DISTFUNC<dist_t> fstdistfunc_; + void *dist_func_param_{nullptr}; - struct CompareByFirst { - constexpr bool operator()(std::pair<dist_t, tableint> const &a, - std::pair<dist_t, tableint> const &b) const noexcept { - return a.first < b.first; - } - }; + mutable std::mutex label_lookup_lock; // lock for label_lookup_ + std::unordered_map<labeltype, tableint> label_lookup_; - ~HierarchicalNSW() { - if (data_level0_memory_) free(data_level0_memory_); - if (linkLists_) { - for (tableint i = 0; i < cur_element_count; i++) { - if (element_levels_[i] > 0) - if (linkLists_[i]) free(linkLists_[i]); - } - free(linkLists_); - } - if (visited_list_pool_) delete visited_list_pool_; - } + std::default_random_engine level_generator_; + std::default_random_engine update_probability_generator_; - size_t max_elements_; - size_t cur_element_count; - size_t size_data_per_element_; - size_t size_links_per_element_; - size_t num_deleted_; + mutable std::atomic<long> metric_distance_computations{0}; + mutable std::atomic<long> metric_hops{0}; - size_t M_; - size_t maxM_; - size_t maxM0_; - size_t ef_construction_; + bool allow_replace_deleted_ = false; // flag to replace deleted elements (marked as deleted) during insertions - double mult_, revSize_; - int maxlevel_; + std::mutex deleted_elements_lock; // lock for deleted_elements + std::unordered_set<tableint> deleted_elements; // contains internal ids of deleted elements + HierarchicalNSW() { } - VisitedListPool *visited_list_pool_; - std::mutex cur_element_count_guard_; + HierarchicalNSW(SpaceInterface<dist_t> *s) { + } - std::vector<std::mutex> link_list_locks_; - // Locks to prevent race condition during update/insert of an element at same time. - // Note: Locks for additions can also be used to prevent this race condition if the querying of KNN is not exposed along with update/inserts i.e multithread insert/update/query in parallel. - std::vector<std::mutex> link_list_update_locks_; - tableint enterpoint_node_; + HierarchicalNSW( + SpaceInterface<dist_t> *s, + const std::string &location, + bool nmslib = false, + size_t max_elements = 0, + bool allow_replace_deleted = false) + : allow_replace_deleted_(allow_replace_deleted) { + loadIndex(location, s, max_elements); + } - size_t size_links_level0_; - size_t offsetData_, offsetLevel0_; - char *data_level0_memory_; - char **linkLists_; - std::vector<int> element_levels_; + HierarchicalNSW( + SpaceInterface<dist_t> *s, + size_t max_elements, + size_t M = 16, + size_t ef_construction = 200, + size_t random_seed = 100, + bool allow_replace_deleted = false) + : link_list_locks_(max_elements), + label_op_locks_(MAX_LABEL_OPERATION_LOCKS), + element_levels_(max_elements), + allow_replace_deleted_(allow_replace_deleted) { + max_elements_ = max_elements; + num_deleted_ = 0; + data_size_ = s->get_data_size(); + fstdistfunc_ = s->get_dist_func(); + dist_func_param_ = s->get_dist_func_param(); + M_ = M; + maxM_ = M_; + maxM0_ = M_ * 2; + ef_construction_ = std::max(ef_construction, M_); + ef_ = 10; - size_t data_size_; + level_generator_.seed(random_seed); + update_probability_generator_.seed(random_seed + 1); - size_t label_offset_; - DISTFUNC<dist_t> fstdistfunc_; - void *dist_func_param_; - std::unordered_map<labeltype, tableint> label_lookup_; + size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); + size_data_per_element_ = size_links_level0_ + data_size_ + sizeof(labeltype); + offsetData_ = size_links_level0_; + label_offset_ = size_links_level0_ + data_size_; + offsetLevel0_ = 0; - std::default_random_engine level_generator_; - std::default_random_engine update_probability_generator_; + data_level0_memory_ = (char *) malloc(max_elements_ * size_data_per_element_); + if (data_level0_memory_ == nullptr) + throw std::runtime_error("Not enough memory"); - inline labeltype getExternalLabel(tableint internal_id) const { - labeltype return_label; - memcpy(&return_label,(data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), sizeof(labeltype)); - return return_label; - } + cur_element_count = 0; - inline void setExternalLabel(tableint internal_id, labeltype label) const { - memcpy((data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), &label, sizeof(labeltype)); - } + visited_list_pool_ = new VisitedListPool(1, max_elements); - inline labeltype *getExternalLabeLp(tableint internal_id) const { - return (labeltype *) (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_); - } + // initializations for special treatment of the first node + enterpoint_node_ = -1; + maxlevel_ = -1; - inline char *getDataByInternalId(tableint internal_id) const { - return (data_level0_memory_ + internal_id * size_data_per_element_ + offsetData_); + linkLists_ = (char **) malloc(sizeof(void *) * max_elements_); + if (linkLists_ == nullptr) + throw std::runtime_error("Not enough memory: HierarchicalNSW failed to allocate linklists"); + size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); + mult_ = 1 / log(1.0 * M_); + revSize_ = 1.0 / mult_; + } + + + ~HierarchicalNSW() { + free(data_level0_memory_); + for (tableint i = 0; i < cur_element_count; i++) { + if (element_levels_[i] > 0) + free(linkLists_[i]); } + free(linkLists_); + delete visited_list_pool_; + } - int getRandomLevel(double reverse_size) { - std::uniform_real_distribution<double> distribution(0.0, 1.0); - double r = -log(distribution(level_generator_)) * reverse_size; - return (int) r; + + struct CompareByFirst { + constexpr bool operator()(std::pair<dist_t, tableint> const& a, + std::pair<dist_t, tableint> const& b) const noexcept { + return a.first < b.first; } + }; - std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> - searchBaseLayer(tableint ep_id, const void *data_point, int layer) { - VisitedList *vl = visited_list_pool_->getFreeVisitedList(); - vl_type *visited_array = vl->mass; - vl_type visited_array_tag = vl->curV; + void setEf(size_t ef) { + ef_ = ef; + } - std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates; - std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidateSet; - dist_t lowerBound; - if (!isMarkedDeleted(ep_id)) { - dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); - top_candidates.emplace(dist, ep_id); - lowerBound = dist; - candidateSet.emplace(-dist, ep_id); - } else { - lowerBound = std::numeric_limits<dist_t>::max(); - candidateSet.emplace(-lowerBound, ep_id); - } - visited_array[ep_id] = visited_array_tag; + inline std::mutex& getLabelOpMutex(labeltype label) const { + // calculate hash + size_t lock_id = label & (MAX_LABEL_OPERATION_LOCKS - 1); + return label_op_locks_[lock_id]; + } - while (!candidateSet.empty()) { - std::pair<dist_t, tableint> curr_el_pair = candidateSet.top(); - if ((-curr_el_pair.first) > lowerBound && top_candidates.size() == ef_construction_) { - break; - } - candidateSet.pop(); - tableint curNodeNum = curr_el_pair.second; + inline labeltype getExternalLabel(tableint internal_id) const { + labeltype return_label; + memcpy(&return_label, (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), sizeof(labeltype)); + return return_label; + } - std::unique_lock <std::mutex> lock(link_list_locks_[curNodeNum]); - int *data;// = (int *)(linkList0_ + curNodeNum * size_links_per_element0_); - if (layer == 0) { - data = (int*)get_linklist0(curNodeNum); - } else { - data = (int*)get_linklist(curNodeNum, layer); + inline void setExternalLabel(tableint internal_id, labeltype label) const { + memcpy((data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), &label, sizeof(labeltype)); + } + + + inline labeltype *getExternalLabeLp(tableint internal_id) const { + return (labeltype *) (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_); + } + + + inline char *getDataByInternalId(tableint internal_id) const { + return (data_level0_memory_ + internal_id * size_data_per_element_ + offsetData_); + } + + + int getRandomLevel(double reverse_size) { + std::uniform_real_distribution<double> distribution(0.0, 1.0); + double r = -log(distribution(level_generator_)) * reverse_size; + return (int) r; + } + + size_t getMaxElements() { + return max_elements_; + } + + size_t getCurrentElementCount() { + return cur_element_count; + } + + size_t getDeletedCount() { + return num_deleted_; + } + + std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> + searchBaseLayer(tableint ep_id, const void *data_point, int layer) { + VisitedList *vl = visited_list_pool_->getFreeVisitedList(); + vl_type *visited_array = vl->mass; + vl_type visited_array_tag = vl->curV; + + std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates; + std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidateSet; + + dist_t lowerBound; + if (!isMarkedDeleted(ep_id)) { + dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); + top_candidates.emplace(dist, ep_id); + lowerBound = dist; + candidateSet.emplace(-dist, ep_id); + } else { + lowerBound = std::numeric_limits<dist_t>::max(); + candidateSet.emplace(-lowerBound, ep_id); + } + visited_array[ep_id] = visited_array_tag; + + while (!candidateSet.empty()) { + std::pair<dist_t, tableint> curr_el_pair = candidateSet.top(); + if ((-curr_el_pair.first) > lowerBound && top_candidates.size() == ef_construction_) { + break; + } + candidateSet.pop(); + + tableint curNodeNum = curr_el_pair.second; + + std::unique_lock <std::mutex> lock(link_list_locks_[curNodeNum]); + + int *data; // = (int *)(linkList0_ + curNodeNum * size_links_per_element0_); + if (layer == 0) { + data = (int*)get_linklist0(curNodeNum); + } else { + data = (int*)get_linklist(curNodeNum, layer); // data = (int *) (linkLists_[curNodeNum] + (layer - 1) * size_links_per_element_); - } - size_t size = getListCount((linklistsizeint*)data); - tableint *datal = (tableint *) (data + 1); + } + size_t size = getListCount((linklistsizeint*)data); + tableint *datal = (tableint *) (data + 1); #ifdef USE_SSE - _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0); - _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0); - _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0); - _mm_prefetch(getDataByInternalId(*(datal + 1)), _MM_HINT_T0); + _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0); + _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(*(datal + 1)), _MM_HINT_T0); #endif - for (size_t j = 0; j < size; j++) { - tableint candidate_id = *(datal + j); + for (size_t j = 0; j < size; j++) { + tableint candidate_id = *(datal + j); // if (candidate_id == 0) continue; #ifdef USE_SSE - _mm_prefetch((char *) (visited_array + *(datal + j + 1)), _MM_HINT_T0); - _mm_prefetch(getDataByInternalId(*(datal + j + 1)), _MM_HINT_T0); + _mm_prefetch((char *) (visited_array + *(datal + j + 1)), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(*(datal + j + 1)), _MM_HINT_T0); #endif - if (visited_array[candidate_id] == visited_array_tag) continue; - visited_array[candidate_id] = visited_array_tag; - char *currObj1 = (getDataByInternalId(candidate_id)); + if (visited_array[candidate_id] == visited_array_tag) continue; + visited_array[candidate_id] = visited_array_tag; + char *currObj1 = (getDataByInternalId(candidate_id)); - dist_t dist1 = fstdistfunc_(data_point, currObj1, dist_func_param_); - if (top_candidates.size() < ef_construction_ || lowerBound > dist1) { - candidateSet.emplace(-dist1, candidate_id); + dist_t dist1 = fstdistfunc_(data_point, currObj1, dist_func_param_); + if (top_candidates.size() < ef_construction_ || lowerBound > dist1) { + candidateSet.emplace(-dist1, candidate_id); #ifdef USE_SSE - _mm_prefetch(getDataByInternalId(candidateSet.top().second), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(candidateSet.top().second), _MM_HINT_T0); #endif - if (!isMarkedDeleted(candidate_id)) - top_candidates.emplace(dist1, candidate_id); + if (!isMarkedDeleted(candidate_id)) + top_candidates.emplace(dist1, candidate_id); - if (top_candidates.size() > ef_construction_) - top_candidates.pop(); + if (top_candidates.size() > ef_construction_) + top_candidates.pop(); - if (!top_candidates.empty()) - lowerBound = top_candidates.top().first; - } + if (!top_candidates.empty()) + lowerBound = top_candidates.top().first; } } - visited_list_pool_->releaseVisitedList(vl); - - return top_candidates; } + visited_list_pool_->releaseVisitedList(vl); - mutable std::atomic<long> metric_distance_computations; - mutable std::atomic<long> metric_hops; + return top_candidates; + } - template <bool has_deletions, bool collect_metrics=false> - std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> - searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef) const { - VisitedList *vl = visited_list_pool_->getFreeVisitedList(); - vl_type *visited_array = vl->mass; - vl_type visited_array_tag = vl->curV; - std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates; - std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set; + template <bool has_deletions, bool collect_metrics = false> + std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> + searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, BaseFilterFunctor* isIdAllowed = nullptr) const { + VisitedList *vl = visited_list_pool_->getFreeVisitedList(); + vl_type *visited_array = vl->mass; + vl_type visited_array_tag = vl->curV; - dist_t lowerBound; - if (!has_deletions || !isMarkedDeleted(ep_id)) { - dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); - lowerBound = dist; - top_candidates.emplace(dist, ep_id); - candidate_set.emplace(-dist, ep_id); - } else { - lowerBound = std::numeric_limits<dist_t>::max(); - candidate_set.emplace(-lowerBound, ep_id); - } + std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates; + std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set; - visited_array[ep_id] = visited_array_tag; + dist_t lowerBound; + if ((!has_deletions || !isMarkedDeleted(ep_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id)))) { + dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); + lowerBound = dist; + top_candidates.emplace(dist, ep_id); + candidate_set.emplace(-dist, ep_id); + } else { + lowerBound = std::numeric_limits<dist_t>::max(); + candidate_set.emplace(-lowerBound, ep_id); + } - while (!candidate_set.empty()) { + visited_array[ep_id] = visited_array_tag; - std::pair<dist_t, tableint> current_node_pair = candidate_set.top(); + while (!candidate_set.empty()) { + std::pair<dist_t, tableint> current_node_pair = candidate_set.top(); - if ((-current_node_pair.first) > lowerBound && (top_candidates.size() == ef || has_deletions == false)) { - break; - } - candidate_set.pop(); + if ((-current_node_pair.first) > lowerBound && + (top_candidates.size() == ef || (!isIdAllowed && !has_deletions))) { + break; + } + candidate_set.pop(); - tableint current_node_id = current_node_pair.second; - int *data = (int *) get_linklist0(current_node_id); - size_t size = getListCount((linklistsizeint*)data); + tableint current_node_id = current_node_pair.second; + int *data = (int *) get_linklist0(current_node_id); + size_t size = getListCount((linklistsizeint*)data); // bool cur_node_deleted = isMarkedDeleted(current_node_id); - if(collect_metrics){ - metric_hops++; - metric_distance_computations+=size; - } + if (collect_metrics) { + metric_hops++; + metric_distance_computations+=size; + } #ifdef USE_SSE - _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0); - _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0); - _mm_prefetch(data_level0_memory_ + (*(data + 1)) * size_data_per_element_ + offsetData_, _MM_HINT_T0); - _mm_prefetch((char *) (data + 2), _MM_HINT_T0); + _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0); + _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0); + _mm_prefetch(data_level0_memory_ + (*(data + 1)) * size_data_per_element_ + offsetData_, _MM_HINT_T0); + _mm_prefetch((char *) (data + 2), _MM_HINT_T0); #endif - for (size_t j = 1; j <= size; j++) { - int candidate_id = *(data + j); + for (size_t j = 1; j <= size; j++) { + int candidate_id = *(data + j); // if (candidate_id == 0) continue; #ifdef USE_SSE - _mm_prefetch((char *) (visited_array + *(data + j + 1)), _MM_HINT_T0); - _mm_prefetch(data_level0_memory_ + (*(data + j + 1)) * size_data_per_element_ + offsetData_, - _MM_HINT_T0);//////////// + _mm_prefetch((char *) (visited_array + *(data + j + 1)), _MM_HINT_T0); + _mm_prefetch(data_level0_memory_ + (*(data + j + 1)) * size_data_per_element_ + offsetData_, + _MM_HINT_T0); //////////// #endif - if (!(visited_array[candidate_id] == visited_array_tag)) { + if (!(visited_array[candidate_id] == visited_array_tag)) { + visited_array[candidate_id] = visited_array_tag; - visited_array[candidate_id] = visited_array_tag; + char *currObj1 = (getDataByInternalId(candidate_id)); + dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_); - char *currObj1 = (getDataByInternalId(candidate_id)); - dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_); - - if (top_candidates.size() < ef || lowerBound > dist) { - candidate_set.emplace(-dist, candidate_id); + if (top_candidates.size() < ef || lowerBound > dist) { + candidate_set.emplace(-dist, candidate_id); #ifdef USE_SSE - _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ + - offsetLevel0_,/////////// - _MM_HINT_T0);//////////////////////// + _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ + + offsetLevel0_, /////////// + _MM_HINT_T0); //////////////////////// #endif - if (!has_deletions || !isMarkedDeleted(candidate_id)) - top_candidates.emplace(dist, candidate_id); + if ((!has_deletions || !isMarkedDeleted(candidate_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id)))) + top_candidates.emplace(dist, candidate_id); - if (top_candidates.size() > ef) - top_candidates.pop(); + if (top_candidates.size() > ef) + top_candidates.pop(); - if (!top_candidates.empty()) - lowerBound = top_candidates.top().first; - } + if (!top_candidates.empty()) + lowerBound = top_candidates.top().first; } } } + } - visited_list_pool_->releaseVisitedList(vl); - return top_candidates; + visited_list_pool_->releaseVisitedList(vl); + return top_candidates; + } + + + void getNeighborsByHeuristic2( + std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates, + const size_t M) { + if (top_candidates.size() < M) { + return; } - void getNeighborsByHeuristic2( - std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates, - const size_t M) { - if (top_candidates.size() < M) { - return; - } + std::priority_queue<std::pair<dist_t, tableint>> queue_closest; + std::vector<std::pair<dist_t, tableint>> return_list; + while (top_candidates.size() > 0) { + queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second); + top_candidates.pop(); + } - std::priority_queue<std::pair<dist_t, tableint>> queue_closest; - std::vector<std::pair<dist_t, tableint>> return_list; - while (top_candidates.size() > 0) { - queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second); - top_candidates.pop(); - } + while (queue_closest.size()) { + if (return_list.size() >= M) + break; + std::pair<dist_t, tableint> curent_pair = queue_closest.top(); + dist_t dist_to_query = -curent_pair.first; + queue_closest.pop(); + bool good = true; - while (queue_closest.size()) { - if (return_list.size() >= M) + for (std::pair<dist_t, tableint> second_pair : return_list) { + dist_t curdist = + fstdistfunc_(getDataByInternalId(second_pair.second), + getDataByInternalId(curent_pair.second), + dist_func_param_); + if (curdist < dist_to_query) { + good = false; break; - std::pair<dist_t, tableint> curent_pair = queue_closest.top(); - dist_t dist_to_query = -curent_pair.first; - queue_closest.pop(); - bool good = true; - - for (std::pair<dist_t, tableint> second_pair : return_list) { - dist_t curdist = - fstdistfunc_(getDataByInternalId(second_pair.second), - getDataByInternalId(curent_pair.second), - dist_func_param_);; - if (curdist < dist_to_query) { - good = false; - break; - } } - if (good) { - return_list.push_back(curent_pair); - } } - - for (std::pair<dist_t, tableint> curent_pair : return_list) { - top_candidates.emplace(-curent_pair.first, curent_pair.second); + if (good) { + return_list.push_back(curent_pair); } } + for (std::pair<dist_t, tableint> curent_pair : return_list) { + top_candidates.emplace(-curent_pair.first, curent_pair.second); + } + } - linklistsizeint *get_linklist0(tableint internal_id) const { - return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_); - }; - linklistsizeint *get_linklist0(tableint internal_id, char *data_level0_memory_) const { - return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_); - }; + linklistsizeint *get_linklist0(tableint internal_id) const { + return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_); + } - linklistsizeint *get_linklist(tableint internal_id, int level) const { - return (linklistsizeint *) (linkLists_[internal_id] + (level - 1) * size_links_per_element_); - }; - linklistsizeint *get_linklist_at_level(tableint internal_id, int level) const { - return level == 0 ? get_linklist0(internal_id) : get_linklist(internal_id, level); - }; + linklistsizeint *get_linklist0(tableint internal_id, char *data_level0_memory_) const { + return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_); + } - tableint mutuallyConnectNewElement(const void *data_point, tableint cur_c, - std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates, - int level, bool isUpdate) { - size_t Mcurmax = level ? maxM_ : maxM0_; - getNeighborsByHeuristic2(top_candidates, M_); - if (top_candidates.size() > M_) - throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic"); - std::vector<tableint> selectedNeighbors; - selectedNeighbors.reserve(M_); - while (top_candidates.size() > 0) { - selectedNeighbors.push_back(top_candidates.top().second); - top_candidates.pop(); - } + linklistsizeint *get_linklist(tableint internal_id, int level) const { + return (linklistsizeint *) (linkLists_[internal_id] + (level - 1) * size_links_per_element_); + } - tableint next_closest_entry_point = selectedNeighbors.back(); - { - linklistsizeint *ll_cur; - if (level == 0) - ll_cur = get_linklist0(cur_c); - else - ll_cur = get_linklist(cur_c, level); + linklistsizeint *get_linklist_at_level(tableint internal_id, int level) const { + return level == 0 ? get_linklist0(internal_id) : get_linklist(internal_id, level); + } - if (*ll_cur && !isUpdate) { - throw std::runtime_error("The newly inserted element should have blank link list"); - } - setListCount(ll_cur,selectedNeighbors.size()); - tableint *data = (tableint *) (ll_cur + 1); - for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) { - if (data[idx] && !isUpdate) - throw std::runtime_error("Possible memory corruption"); - if (level > element_levels_[selectedNeighbors[idx]]) - throw std::runtime_error("Trying to make a link on a non-existent level"); - data[idx] = selectedNeighbors[idx]; + tableint mutuallyConnectNewElement( + const void *data_point, + tableint cur_c, + std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates, + int level, + bool isUpdate) { + size_t Mcurmax = level ? maxM_ : maxM0_; + getNeighborsByHeuristic2(top_candidates, M_); + if (top_candidates.size() > M_) + throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic"); - } + std::vector<tableint> selectedNeighbors; + selectedNeighbors.reserve(M_); + while (top_candidates.size() > 0) { + selectedNeighbors.push_back(top_candidates.top().second); + top_candidates.pop(); + } + + tableint next_closest_entry_point = selectedNeighbors.back(); + + { + // lock only during the update + // because during the addition the lock for cur_c is already acquired + std::unique_lock <std::mutex> lock(link_list_locks_[cur_c], std::defer_lock); + if (isUpdate) { + lock.lock(); } + linklistsizeint *ll_cur; + if (level == 0) + ll_cur = get_linklist0(cur_c); + else + ll_cur = get_linklist(cur_c, level); + if (*ll_cur && !isUpdate) { + throw std::runtime_error("The newly inserted element should have blank link list"); + } + setListCount(ll_cur, selectedNeighbors.size()); + tableint *data = (tableint *) (ll_cur + 1); for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) { + if (data[idx] && !isUpdate) + throw std::runtime_error("Possible memory corruption"); + if (level > element_levels_[selectedNeighbors[idx]]) + throw std::runtime_error("Trying to make a link on a non-existent level"); - std::unique_lock <std::mutex> lock(link_list_locks_[selectedNeighbors[idx]]); + data[idx] = selectedNeighbors[idx]; + } + } - linklistsizeint *ll_other; - if (level == 0) - ll_other = get_linklist0(selectedNeighbors[idx]); - else - ll_other = get_linklist(selectedNeighbors[idx], level); + for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) { + std::unique_lock <std::mutex> lock(link_list_locks_[selectedNeighbors[idx]]); - size_t sz_link_list_other = getListCount(ll_other); + linklistsizeint *ll_other; + if (level == 0) + ll_other = get_linklist0(selectedNeighbors[idx]); + else + ll_other = get_linklist(selectedNeighbors[idx], level); - if (sz_link_list_other > Mcurmax) - throw std::runtime_error("Bad value of sz_link_list_other"); - if (selectedNeighbors[idx] == cur_c) - throw std::runtime_error("Trying to connect an element to itself"); - if (level > element_levels_[selectedNeighbors[idx]]) - throw std::runtime_error("Trying to make a link on a non-existent level"); + size_t sz_link_list_other = getListCount(ll_other); - tableint *data = (tableint *) (ll_other + 1); + if (sz_link_list_other > Mcurmax) + throw std::runtime_error("Bad value of sz_link_list_other"); + if (selectedNeighbors[idx] == cur_c) + throw std::runtime_error("Trying to connect an element to itself"); + if (level > element_levels_[selectedNeighbors[idx]]) + throw std::runtime_error("Trying to make a link on a non-existent level"); - bool is_cur_c_present = false; - if (isUpdate) { - for (size_t j = 0; j < sz_link_list_other; j++) { - if (data[j] == cur_c) { - is_cur_c_present = true; - break; - } + tableint *data = (tableint *) (ll_other + 1); + + bool is_cur_c_present = false; + if (isUpdate) { + for (size_t j = 0; j < sz_link_list_other; j++) { + if (data[j] == cur_c) { + is_cur_c_present = true; + break; } } + } - // If cur_c is already present in the neighboring connections of `selectedNeighbors[idx]` then no need to modify any connections or run the heuristics. - if (!is_cur_c_present) { - if (sz_link_list_other < Mcurmax) { - data[sz_link_list_other] = cur_c; - setListCount(ll_other, sz_link_list_other + 1); - } else { - // finding the "weakest" element to replace it with the new one - dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]), - dist_func_param_); - // Heuristic: - std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates; - candidates.emplace(d_max, cur_c); + // If cur_c is already present in the neighboring connections of `selectedNeighbors[idx]` then no need to modify any connections or run the heuristics. + if (!is_cur_c_present) { + if (sz_link_list_other < Mcurmax) { + data[sz_link_list_other] = cur_c; + setListCount(ll_other, sz_link_list_other + 1); + } else { + // finding the "weakest" element to replace it with the new one + dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]), + dist_func_param_); + // Heuristic: + std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates; + candidates.emplace(d_max, cur_c); - for (size_t j = 0; j < sz_link_list_other; j++) { - candidates.emplace( - fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(selectedNeighbors[idx]), - dist_func_param_), data[j]); - } + for (size_t j = 0; j < sz_link_list_other; j++) { + candidates.emplace( + fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(selectedNeighbors[idx]), + dist_func_param_), data[j]); + } - getNeighborsByHeuristic2(candidates, Mcurmax); + getNeighborsByHeuristic2(candidates, Mcurmax); - int indx = 0; - while (candidates.size() > 0) { - data[indx] = candidates.top().second; - candidates.pop(); - indx++; - } + int indx = 0; + while (candidates.size() > 0) { + data[indx] = candidates.top().second; + candidates.pop(); + indx++; + } - setListCount(ll_other, indx); - // Nearest K: - /*int indx = -1; - for (int j = 0; j < sz_link_list_other; j++) { - dist_t d = fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(rez[idx]), dist_func_param_); - if (d > d_max) { - indx = j; - d_max = d; - } + setListCount(ll_other, indx); + // Nearest K: + /*int indx = -1; + for (int j = 0; j < sz_link_list_other; j++) { + dist_t d = fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(rez[idx]), dist_func_param_); + if (d > d_max) { + indx = j; + d_max = d; } - if (indx >= 0) { - data[indx] = cur_c; - } */ } + if (indx >= 0) { + data[indx] = cur_c; + } */ } } - - return next_closest_entry_point; } - std::mutex global; - size_t ef_; + return next_closest_entry_point; + } - void setEf(size_t ef) { - ef_ = ef; - } + void resizeIndex(size_t new_max_elements) { + if (new_max_elements < cur_element_count) + throw std::runtime_error("Cannot resize, max element is less than the current number of elements"); - std::priority_queue<std::pair<dist_t, tableint>> searchKnnInternal(void *query_data, int k) { - std::priority_queue<std::pair<dist_t, tableint >> top_candidates; - if (cur_element_count == 0) return top_candidates; - tableint currObj = enterpoint_node_; - dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_); + delete visited_list_pool_; + visited_list_pool_ = new VisitedListPool(1, new_max_elements); - for (size_t level = maxlevel_; level > 0; level--) { - bool changed = true; - while (changed) { - changed = false; - int *data; - data = (int *) get_linklist(currObj,level); - int size = getListCount(data); - tableint *datal = (tableint *) (data + 1); - for (int i = 0; i < size; i++) { - tableint cand = datal[i]; - if (cand < 0 || cand > max_elements_) - throw std::runtime_error("cand error"); - dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_); + element_levels_.resize(new_max_elements); - if (d < curdist) { - curdist = d; - currObj = cand; - changed = true; - } - } - } - } + std::vector<std::mutex>(new_max_elements).swap(link_list_locks_); - if (num_deleted_) { - std::priority_queue<std::pair<dist_t, tableint >> top_candidates1=searchBaseLayerST<true>(currObj, query_data, - ef_); - top_candidates.swap(top_candidates1); - } - else{ - std::priority_queue<std::pair<dist_t, tableint >> top_candidates1=searchBaseLayerST<false>(currObj, query_data, - ef_); - top_candidates.swap(top_candidates1); - } + // Reallocate base layer + char * data_level0_memory_new = (char *) realloc(data_level0_memory_, new_max_elements * size_data_per_element_); + if (data_level0_memory_new == nullptr) + throw std::runtime_error("Not enough memory: resizeIndex failed to allocate base layer"); + data_level0_memory_ = data_level0_memory_new; - while (top_candidates.size() > k) { - top_candidates.pop(); - } - return top_candidates; - }; + // Reallocate all other layers + char ** linkLists_new = (char **) realloc(linkLists_, sizeof(void *) * new_max_elements); + if (linkLists_new == nullptr) + throw std::runtime_error("Not enough memory: resizeIndex failed to allocate other layers"); + linkLists_ = linkLists_new; - void resizeIndex(size_t new_max_elements){ - if (new_max_elements<cur_element_count) - throw std::runtime_error("Cannot resize, max element is less than the current number of elements"); + max_elements_ = new_max_elements; + } - delete visited_list_pool_; - visited_list_pool_ = new VisitedListPool(1, new_max_elements); + void saveIndex(const std::string &location) { + std::ofstream output(location, std::ios::binary); + std::streampos position; + writeBinaryPOD(output, offsetLevel0_); + writeBinaryPOD(output, max_elements_); + writeBinaryPOD(output, cur_element_count); + writeBinaryPOD(output, size_data_per_element_); + writeBinaryPOD(output, label_offset_); + writeBinaryPOD(output, offsetData_); + writeBinaryPOD(output, maxlevel_); + writeBinaryPOD(output, enterpoint_node_); + writeBinaryPOD(output, maxM_); - element_levels_.resize(new_max_elements); + writeBinaryPOD(output, maxM0_); + writeBinaryPOD(output, M_); + writeBinaryPOD(output, mult_); + writeBinaryPOD(output, ef_construction_); - std::vector<std::mutex>(new_max_elements).swap(link_list_locks_); + output.write(data_level0_memory_, cur_element_count * size_data_per_element_); - // Reallocate base layer - char * data_level0_memory_new = (char *) realloc(data_level0_memory_, new_max_elements * size_data_per_element_); - if (data_level0_memory_new == nullptr) - throw std::runtime_error("Not enough memory: resizeIndex failed to allocate base layer"); - data_level0_memory_ = data_level0_memory_new; + for (size_t i = 0; i < cur_element_count; i++) { + unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0; + writeBinaryPOD(output, linkListSize); + if (linkListSize) + output.write(linkLists_[i], linkListSize); + } + output.close(); + } - // Reallocate all other layers - char ** linkLists_new = (char **) realloc(linkLists_, sizeof(void *) * new_max_elements); - if (linkLists_new == nullptr) - throw std::runtime_error("Not enough memory: resizeIndex failed to allocate other layers"); - linkLists_ = linkLists_new; - max_elements_ = new_max_elements; - } + void loadIndex(const std::string &location, SpaceInterface<dist_t> *s, size_t max_elements_i = 0) { + std::ifstream input(location, std::ios::binary); - void saveIndex(const std::string &location) { - std::ofstream output(location, std::ios::binary); - std::streampos position; + if (!input.is_open()) + throw std::runtime_error("Cannot open file"); - writeBinaryPOD(output, offsetLevel0_); - writeBinaryPOD(output, max_elements_); - writeBinaryPOD(output, cur_element_count); - writeBinaryPOD(output, size_data_per_element_); - writeBinaryPOD(output, label_offset_); - writeBinaryPOD(output, offsetData_); - writeBinaryPOD(output, maxlevel_); - writeBinaryPOD(output, enterpoint_node_); - writeBinaryPOD(output, maxM_); + // get file size: + input.seekg(0, input.end); + std::streampos total_filesize = input.tellg(); + input.seekg(0, input.beg); - writeBinaryPOD(output, maxM0_); - writeBinaryPOD(output, M_); - writeBinaryPOD(output, mult_); - writeBinaryPOD(output, ef_construction_); + readBinaryPOD(input, offsetLevel0_); + readBinaryPOD(input, max_elements_); + readBinaryPOD(input, cur_element_count); - output.write(data_level0_memory_, cur_element_count * size_data_per_element_); + size_t max_elements = max_elements_i; + if (max_elements < cur_element_count) + max_elements = max_elements_; + max_elements_ = max_elements; + readBinaryPOD(input, size_data_per_element_); + readBinaryPOD(input, label_offset_); + readBinaryPOD(input, offsetData_); + readBinaryPOD(input, maxlevel_); + readBinaryPOD(input, enterpoint_node_); - for (size_t i = 0; i < cur_element_count; i++) { - unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0; - writeBinaryPOD(output, linkListSize); - if (linkListSize) - output.write(linkLists_[i], linkListSize); - } - output.close(); - } + readBinaryPOD(input, maxM_); + readBinaryPOD(input, maxM0_); + readBinaryPOD(input, M_); + readBinaryPOD(input, mult_); + readBinaryPOD(input, ef_construction_); - void loadIndex(const std::string &location, SpaceInterface<dist_t> *s, size_t max_elements_i=0) { - std::ifstream input(location, std::ios::binary); + data_size_ = s->get_data_size(); + fstdistfunc_ = s->get_dist_func(); + dist_func_param_ = s->get_dist_func_param(); - if (!input.is_open()) - throw std::runtime_error("Cannot open file"); + auto pos = input.tellg(); - // get file size: - input.seekg(0,input.end); - std::streampos total_filesize=input.tellg(); - input.seekg(0,input.beg); + /// Optional - check if index is ok: + input.seekg(cur_element_count * size_data_per_element_, input.cur); + for (size_t i = 0; i < cur_element_count; i++) { + if (input.tellg() < 0 || input.tellg() >= total_filesize) { + throw std::runtime_error("Index seems to be corrupted or unsupported"); + } - readBinaryPOD(input, offsetLevel0_); - readBinaryPOD(input, max_elements_); - readBinaryPOD(input, cur_element_count); + unsigned int linkListSize; + readBinaryPOD(input, linkListSize); + if (linkListSize != 0) { + input.seekg(linkListSize, input.cur); + } + } - size_t max_elements = max_elements_i; - if(max_elements < cur_element_count) - max_elements = max_elements_; - max_elements_ = max_elements; - readBinaryPOD(input, size_data_per_element_); - readBinaryPOD(input, label_offset_); - readBinaryPOD(input, offsetData_); - readBinaryPOD(input, maxlevel_); - readBinaryPOD(input, enterpoint_node_); + // throw exception if it either corrupted or old index + if (input.tellg() != total_filesize) + throw std::runtime_error("Index seems to be corrupted or unsupported"); - readBinaryPOD(input, maxM_); - readBinaryPOD(input, maxM0_); - readBinaryPOD(input, M_); - readBinaryPOD(input, mult_); - readBinaryPOD(input, ef_construction_); + input.clear(); + /// Optional check end + input.seekg(pos, input.beg); - data_size_ = s->get_data_size(); - fstdistfunc_ = s->get_dist_func(); - dist_func_param_ = s->get_dist_func_param(); + data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_); + if (data_level0_memory_ == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0"); + input.read(data_level0_memory_, cur_element_count * size_data_per_element_); - auto pos=input.tellg(); + size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); + size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); + std::vector<std::mutex>(max_elements).swap(link_list_locks_); + std::vector<std::mutex>(MAX_LABEL_OPERATION_LOCKS).swap(label_op_locks_); - /// Optional - check if index is ok: + visited_list_pool_ = new VisitedListPool(1, max_elements); - input.seekg(cur_element_count * size_data_per_element_,input.cur); - for (size_t i = 0; i < cur_element_count; i++) { - if(input.tellg() < 0 || input.tellg()>=total_filesize){ - throw std::runtime_error("Index seems to be corrupted or unsupported"); - } + linkLists_ = (char **) malloc(sizeof(void *) * max_elements); + if (linkLists_ == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists"); + element_levels_ = std::vector<int>(max_elements); + revSize_ = 1.0 / mult_; + ef_ = 10; + for (size_t i = 0; i < cur_element_count; i++) { + label_lookup_[getExternalLabel(i)] = i; + unsigned int linkListSize; + readBinaryPOD(input, linkListSize); + if (linkListSize == 0) { + element_levels_[i] = 0; + linkLists_[i] = nullptr; + } else { + element_levels_[i] = linkListSize / size_links_per_element_; + linkLists_[i] = (char *) malloc(linkListSize); + if (linkLists_[i] == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist"); + input.read(linkLists_[i], linkListSize); + } + } - unsigned int linkListSize; - readBinaryPOD(input, linkListSize); - if (linkListSize != 0) { - input.seekg(linkListSize,input.cur); - } + for (size_t i = 0; i < cur_element_count; i++) { + if (isMarkedDeleted(i)) { + num_deleted_ += 1; + if (allow_replace_deleted_) deleted_elements.insert(i); } + } - // throw exception if it either corrupted or old index - if(input.tellg()!=total_filesize) - throw std::runtime_error("Index seems to be corrupted or unsupported"); + input.close(); - input.clear(); + return; + } - /// Optional check end - input.seekg(pos,input.beg); + template<typename data_t> + std::vector<data_t> getDataByLabel(labeltype label) const { + // lock all operations with element by label + std::unique_lock <std::mutex> lock_label(getLabelOpMutex(label)); - data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_); - if (data_level0_memory_ == nullptr) - throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0"); - input.read(data_level0_memory_, cur_element_count * size_data_per_element_); + std::unique_lock <std::mutex> lock_table(label_lookup_lock); + auto search = label_lookup_.find(label); + if (search == label_lookup_.end() || isMarkedDeleted(search->second)) { + throw std::runtime_error("Label not found"); + } + tableint internalId = search->second; + lock_table.unlock(); - size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); + char* data_ptrv = getDataByInternalId(internalId); + size_t dim = *((size_t *) dist_func_param_); + std::vector<data_t> data; + data_t* data_ptr = (data_t*) data_ptrv; + for (int i = 0; i < dim; i++) { + data.push_back(*data_ptr); + data_ptr += 1; + } + return data; + } - size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); - std::vector<std::mutex>(max_elements).swap(link_list_locks_); - std::vector<std::mutex>(max_update_element_locks).swap(link_list_update_locks_); - visited_list_pool_ = new VisitedListPool(1, max_elements); + /* + * Marks an element with the given label deleted, does NOT really change the current graph. + */ + void markDelete(labeltype label) { + // lock all operations with element by label + std::unique_lock <std::mutex> lock_label(getLabelOpMutex(label)); - linkLists_ = (char **) malloc(sizeof(void *) * max_elements); - if (linkLists_ == nullptr) - throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists"); - element_levels_ = std::vector<int>(max_elements); - revSize_ = 1.0 / mult_; - ef_ = 10; - for (size_t i = 0; i < cur_element_count; i++) { - label_lookup_[getExternalLabel(i)]=i; - unsigned int linkListSize; - readBinaryPOD(input, linkListSize); - if (linkListSize == 0) { - element_levels_[i] = 0; + std::unique_lock <std::mutex> lock_table(label_lookup_lock); + auto search = label_lookup_.find(label); + if (search == label_lookup_.end()) { + throw std::runtime_error("Label not found"); + } + tableint internalId = search->second; + lock_table.unlock(); - linkLists_[i] = nullptr; - } else { - element_levels_[i] = linkListSize / size_links_per_element_; - linkLists_[i] = (char *) malloc(linkListSize); - if (linkLists_[i] == nullptr) - throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist"); - input.read(linkLists_[i], linkListSize); - } - } + markDeletedInternal(internalId); + } - for (size_t i = 0; i < cur_element_count; i++) { - if(isMarkedDeleted(i)) - num_deleted_ += 1; + + /* + * Uses the last 16 bits of the memory for the linked list size to store the mark, + * whereas maxM0_ has to be limited to the lower 16 bits, however, still large enough in almost all cases. + */ + void markDeletedInternal(tableint internalId) { + assert(internalId < cur_element_count); + if (!isMarkedDeleted(internalId)) { + unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2; + *ll_cur |= DELETE_MARK; + num_deleted_ += 1; + if (allow_replace_deleted_) { + std::unique_lock <std::mutex> lock_deleted_elements(deleted_elements_lock); + deleted_elements.insert(internalId); } + } else { + throw std::runtime_error("The requested to delete element is already deleted"); + } + } - input.close(); - return; + /* + * Removes the deleted mark of the node, does NOT really change the current graph. + * + * Note: the method is not safe to use when replacement of deleted elements is enabled, + * because elements marked as deleted can be completely removed by addPoint + */ + void unmarkDelete(labeltype label) { + // lock all operations with element by label + std::unique_lock <std::mutex> lock_label(getLabelOpMutex(label)); + + std::unique_lock <std::mutex> lock_table(label_lookup_lock); + auto search = label_lookup_.find(label); + if (search == label_lookup_.end()) { + throw std::runtime_error("Label not found"); } + tableint internalId = search->second; + lock_table.unlock(); - template<typename data_t> - std::vector<data_t> getDataByLabel(labeltype label) const - { - tableint label_c; - auto search = label_lookup_.find(label); - if (search == label_lookup_.end() || isMarkedDeleted(search->second)) { - throw std::runtime_error("Label not found"); - } - label_c = search->second; + unmarkDeletedInternal(internalId); + } - char* data_ptrv = getDataByInternalId(label_c); - size_t dim = *((size_t *) dist_func_param_); - std::vector<data_t> data; - data_t* data_ptr = (data_t*) data_ptrv; - for (int i = 0; i < dim; i++) { - data.push_back(*data_ptr); - data_ptr += 1; - } - return data; - } - static const unsigned char DELETE_MARK = 0x01; - // static const unsigned char REUSE_MARK = 0x10; - /** - * Marks an element with the given label deleted, does NOT really change the current graph. - * @param label - */ - void markDelete(labeltype label) - { - auto search = label_lookup_.find(label); - if (search == label_lookup_.end()) { - throw std::runtime_error("Label not found"); - } - tableint internalId = search->second; - markDeletedInternal(internalId); - } - /** - * Uses the first 8 bits of the memory for the linked list to store the mark, - * whereas maxM0_ has to be limited to the lower 24 bits, however, still large enough in almost all cases. - * @param internalId - */ - void markDeletedInternal(tableint internalId) { - assert(internalId < cur_element_count); - if (!isMarkedDeleted(internalId)) - { - unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2; - *ll_cur |= DELETE_MARK; - num_deleted_ += 1; + /* + * Remove the deleted mark of the node. + */ + void unmarkDeletedInternal(tableint internalId) { + assert(internalId < cur_element_count); + if (isMarkedDeleted(internalId)) { + unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId)) + 2; + *ll_cur &= ~DELETE_MARK; + num_deleted_ -= 1; + if (allow_replace_deleted_) { + std::unique_lock <std::mutex> lock_deleted_elements(deleted_elements_lock); + deleted_elements.erase(internalId); } - else - { - throw std::runtime_error("The requested to delete element is already deleted"); - } + } else { + throw std::runtime_error("The requested to undelete element is not deleted"); } + } - /** - * Remove the deleted mark of the node, does NOT really change the current graph. - * @param label - */ - void unmarkDelete(labeltype label) - { - auto search = label_lookup_.find(label); - if (search == label_lookup_.end()) { - throw std::runtime_error("Label not found"); - } - tableint internalId = search->second; - unmarkDeletedInternal(internalId); - } - /** - * Remove the deleted mark of the node. - * @param internalId - */ - void unmarkDeletedInternal(tableint internalId) { - assert(internalId < cur_element_count); - if (isMarkedDeleted(internalId)) - { - unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2; - *ll_cur &= ~DELETE_MARK; - num_deleted_ -= 1; - } - else - { - throw std::runtime_error("The requested to undelete element is not deleted"); - } - } + /* + * Checks the first 16 bits of the memory to see if the element is marked deleted. + */ + bool isMarkedDeleted(tableint internalId) const { + unsigned char *ll_cur = ((unsigned char*)get_linklist0(internalId)) + 2; + return *ll_cur & DELETE_MARK; + } - /** - * Checks the first 8 bits of the memory to see if the element is marked deleted. - * @param internalId - * @return - */ - bool isMarkedDeleted(tableint internalId) const { - unsigned char *ll_cur = ((unsigned char*)get_linklist0(internalId))+2; - return *ll_cur & DELETE_MARK; - } - unsigned short int getListCount(linklistsizeint * ptr) const { - return *((unsigned short int *)ptr); + unsigned short int getListCount(linklistsizeint * ptr) const { + return *((unsigned short int *)ptr); + } + + + void setListCount(linklistsizeint * ptr, unsigned short int size) const { + *((unsigned short int*)(ptr))=*((unsigned short int *)&size); + } + + + /* + * Adds point. Updates the point if it is already in the index. + * If replacement of deleted elements is enabled: replaces previously deleted point if any, updating it with new point + */ + void addPoint(const void *data_point, labeltype label, bool replace_deleted = false) { + if ((allow_replace_deleted_ == false) && (replace_deleted == true)) { + throw std::runtime_error("Replacement of deleted elements is disabled in constructor"); } - void setListCount(linklistsizeint * ptr, unsigned short int size) const { - *((unsigned short int*)(ptr))=*((unsigned short int *)&size); + // lock all operations with element by label + std::unique_lock <std::mutex> lock_label(getLabelOpMutex(label)); + if (!replace_deleted) { + addPoint(data_point, label, -1); + return; } + // check if there is vacant place + tableint internal_id_replaced; + std::unique_lock <std::mutex> lock_deleted_elements(deleted_elements_lock); + bool is_vacant_place = !deleted_elements.empty(); + if (is_vacant_place) { + internal_id_replaced = *deleted_elements.begin(); + deleted_elements.erase(internal_id_replaced); + } + lock_deleted_elements.unlock(); - void addPoint(const void *data_point, labeltype label) { - addPoint(data_point, label,-1); + // if there is no vacant place then add or update point + // else add point to vacant place + if (!is_vacant_place) { + addPoint(data_point, label, -1); + } else { + // we assume that there are no concurrent operations on deleted element + labeltype label_replaced = getExternalLabel(internal_id_replaced); + setExternalLabel(internal_id_replaced, label); + + std::unique_lock <std::mutex> lock_table(label_lookup_lock); + label_lookup_.erase(label_replaced); + label_lookup_[label] = internal_id_replaced; + lock_table.unlock(); + + unmarkDeletedInternal(internal_id_replaced); + updatePoint(data_point, internal_id_replaced, 1.0); } + } - void updatePoint(const void *dataPoint, tableint internalId, float updateNeighborProbability) { - // update the feature vector associated with existing point with new vector - memcpy(getDataByInternalId(internalId), dataPoint, data_size_); - int maxLevelCopy = maxlevel_; - tableint entryPointCopy = enterpoint_node_; - // If point to be updated is entry point and graph just contains single element then just return. - if (entryPointCopy == internalId && cur_element_count == 1) - return; + void updatePoint(const void *dataPoint, tableint internalId, float updateNeighborProbability) { + // update the feature vector associated with existing point with new vector + memcpy(getDataByInternalId(internalId), dataPoint, data_size_); - int elemLevel = element_levels_[internalId]; - std::uniform_real_distribution<float> distribution(0.0, 1.0); - for (int layer = 0; layer <= elemLevel; layer++) { - std::unordered_set<tableint> sCand; - std::unordered_set<tableint> sNeigh; - std::vector<tableint> listOneHop = getConnectionsWithLock(internalId, layer); - if (listOneHop.size() == 0) - continue; + int maxLevelCopy = maxlevel_; + tableint entryPointCopy = enterpoint_node_; + // If point to be updated is entry point and graph just contains single element then just return. + if (entryPointCopy == internalId && cur_element_count == 1) + return; - sCand.insert(internalId); + int elemLevel = element_levels_[internalId]; + std::uniform_real_distribution<float> distribution(0.0, 1.0); + for (int layer = 0; layer <= elemLevel; layer++) { + std::unordered_set<tableint> sCand; + std::unordered_set<tableint> sNeigh; + std::vector<tableint> listOneHop = getConnectionsWithLock(internalId, layer); + if (listOneHop.size() == 0) + continue; - for (auto&& elOneHop : listOneHop) { - sCand.insert(elOneHop); + sCand.insert(internalId); - if (distribution(update_probability_generator_) > updateNeighborProbability) - continue; + for (auto&& elOneHop : listOneHop) { + sCand.insert(elOneHop); - sNeigh.insert(elOneHop); + if (distribution(update_probability_generator_) > updateNeighborProbability) + continue; - std::vector<tableint> listTwoHop = getConnectionsWithLock(elOneHop, layer); - for (auto&& elTwoHop : listTwoHop) { - sCand.insert(elTwoHop); - } + sNeigh.insert(elOneHop); + + std::vector<tableint> listTwoHop = getConnectionsWithLock(elOneHop, layer); + for (auto&& elTwoHop : listTwoHop) { + sCand.insert(elTwoHop); } + } - for (auto&& neigh : sNeigh) { - // if (neigh == internalId) - // continue; + for (auto&& neigh : sNeigh) { + // if (neigh == internalId) + // continue; - std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates; - size_t size = sCand.find(neigh) == sCand.end() ? sCand.size() : sCand.size() - 1; // sCand guaranteed to have size >= 1 - size_t elementsToKeep = std::min(ef_construction_, size); - for (auto&& cand : sCand) { - if (cand == neigh) - continue; + std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates; + size_t size = sCand.find(neigh) == sCand.end() ? sCand.size() : sCand.size() - 1; // sCand guaranteed to have size >= 1 + size_t elementsToKeep = std::min(ef_construction_, size); + for (auto&& cand : sCand) { + if (cand == neigh) + continue; - dist_t distance = fstdistfunc_(getDataByInternalId(neigh), getDataByInternalId(cand), dist_func_param_); - if (candidates.size() < elementsToKeep) { + dist_t distance = fstdistfunc_(getDataByInternalId(neigh), getDataByInternalId(cand), dist_func_param_); + if (candidates.size() < elementsToKeep) { + candidates.emplace(distance, cand); + } else { + if (distance < candidates.top().first) { + candidates.pop(); candidates.emplace(distance, cand); - } else { - if (distance < candidates.top().first) { - candidates.pop(); - candidates.emplace(distance, cand); - } } } + } - // Retrieve neighbours using heuristic and set connections. - getNeighborsByHeuristic2(candidates, layer == 0 ? maxM0_ : maxM_); + // Retrieve neighbours using heuristic and set connections. + getNeighborsByHeuristic2(candidates, layer == 0 ? maxM0_ : maxM_); - { - std::unique_lock <std::mutex> lock(link_list_locks_[neigh]); - linklistsizeint *ll_cur; - ll_cur = get_linklist_at_level(neigh, layer); - size_t candSize = candidates.size(); - setListCount(ll_cur, candSize); - tableint *data = (tableint *) (ll_cur + 1); - for (size_t idx = 0; idx < candSize; idx++) { - data[idx] = candidates.top().second; - candidates.pop(); - } + { + std::unique_lock <std::mutex> lock(link_list_locks_[neigh]); + linklistsizeint *ll_cur; + ll_cur = get_linklist_at_level(neigh, layer); + size_t candSize = candidates.size(); + setListCount(ll_cur, candSize); + tableint *data = (tableint *) (ll_cur + 1); + for (size_t idx = 0; idx < candSize; idx++) { + data[idx] = candidates.top().second; + candidates.pop(); } } } + } - repairConnectionsForUpdate(dataPoint, entryPointCopy, internalId, elemLevel, maxLevelCopy); - }; + repairConnectionsForUpdate(dataPoint, entryPointCopy, internalId, elemLevel, maxLevelCopy); + } - void repairConnectionsForUpdate(const void *dataPoint, tableint entryPointInternalId, tableint dataPointInternalId, int dataPointLevel, int maxLevel) { - tableint currObj = entryPointInternalId; - if (dataPointLevel < maxLevel) { - dist_t curdist = fstdistfunc_(dataPoint, getDataByInternalId(currObj), dist_func_param_); - for (int level = maxLevel; level > dataPointLevel; level--) { - bool changed = true; - while (changed) { - changed = false; - unsigned int *data; - std::unique_lock <std::mutex> lock(link_list_locks_[currObj]); - data = get_linklist_at_level(currObj,level); - int size = getListCount(data); - tableint *datal = (tableint *) (data + 1); + + void repairConnectionsForUpdate( + const void *dataPoint, + tableint entryPointInternalId, + tableint dataPointInternalId, + int dataPointLevel, + int maxLevel) { + tableint currObj = entryPointInternalId; + if (dataPointLevel < maxLevel) { + dist_t curdist = fstdistfunc_(dataPoint, getDataByInternalId(currObj), dist_func_param_); + for (int level = maxLevel; level > dataPointLevel; level--) { + bool changed = true; + while (changed) { + changed = false; + unsigned int *data; + std::unique_lock <std::mutex> lock(link_list_locks_[currObj]); + data = get_linklist_at_level(currObj, level); + int size = getListCount(data); + tableint *datal = (tableint *) (data + 1); #ifdef USE_SSE - _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0); #endif - for (int i = 0; i < size; i++) { + for (int i = 0; i < size; i++) { #ifdef USE_SSE - _mm_prefetch(getDataByInternalId(*(datal + i + 1)), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(*(datal + i + 1)), _MM_HINT_T0); #endif - tableint cand = datal[i]; - dist_t d = fstdistfunc_(dataPoint, getDataByInternalId(cand), dist_func_param_); - if (d < curdist) { - curdist = d; - currObj = cand; - changed = true; - } + tableint cand = datal[i]; + dist_t d = fstdistfunc_(dataPoint, getDataByInternalId(cand), dist_func_param_); + if (d < curdist) { + curdist = d; + currObj = cand; + changed = true; } } } } + } - if (dataPointLevel > maxLevel) - throw std::runtime_error("Level of item to be updated cannot be bigger than max level"); + if (dataPointLevel > maxLevel) + throw std::runtime_error("Level of item to be updated cannot be bigger than max level"); - for (int level = dataPointLevel; level >= 0; level--) { - std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> topCandidates = searchBaseLayer( - currObj, dataPoint, level); + for (int level = dataPointLevel; level >= 0; level--) { + std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> topCandidates = searchBaseLayer( + currObj, dataPoint, level); - std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> filteredTopCandidates; - while (topCandidates.size() > 0) { - if (topCandidates.top().second != dataPointInternalId) - filteredTopCandidates.push(topCandidates.top()); + std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> filteredTopCandidates; + while (topCandidates.size() > 0) { + if (topCandidates.top().second != dataPointInternalId) + filteredTopCandidates.push(topCandidates.top()); - topCandidates.pop(); - } + topCandidates.pop(); + } - // Since element_levels_ is being used to get `dataPointLevel`, there could be cases where `topCandidates` could just contains entry point itself. - // To prevent self loops, the `topCandidates` is filtered and thus can be empty. - if (filteredTopCandidates.size() > 0) { - bool epDeleted = isMarkedDeleted(entryPointInternalId); - if (epDeleted) { - filteredTopCandidates.emplace(fstdistfunc_(dataPoint, getDataByInternalId(entryPointInternalId), dist_func_param_), entryPointInternalId); - if (filteredTopCandidates.size() > ef_construction_) - filteredTopCandidates.pop(); - } - - currObj = mutuallyConnectNewElement(dataPoint, dataPointInternalId, filteredTopCandidates, level, true); + // Since element_levels_ is being used to get `dataPointLevel`, there could be cases where `topCandidates` could just contains entry point itself. + // To prevent self loops, the `topCandidates` is filtered and thus can be empty. + if (filteredTopCandidates.size() > 0) { + bool epDeleted = isMarkedDeleted(entryPointInternalId); + if (epDeleted) { + filteredTopCandidates.emplace(fstdistfunc_(dataPoint, getDataByInternalId(entryPointInternalId), dist_func_param_), entryPointInternalId); + if (filteredTopCandidates.size() > ef_construction_) + filteredTopCandidates.pop(); } + + currObj = mutuallyConnectNewElement(dataPoint, dataPointInternalId, filteredTopCandidates, level, true); } } + } - std::vector<tableint> getConnectionsWithLock(tableint internalId, int level) { - std::unique_lock <std::mutex> lock(link_list_locks_[internalId]); - unsigned int *data = get_linklist_at_level(internalId, level); - int size = getListCount(data); - std::vector<tableint> result(size); - tableint *ll = (tableint *) (data + 1); - memcpy(result.data(), ll,size * sizeof(tableint)); - return result; - }; - tableint addPoint(const void *data_point, labeltype label, int level) { + std::vector<tableint> getConnectionsWithLock(tableint internalId, int level) { + std::unique_lock <std::mutex> lock(link_list_locks_[internalId]); + unsigned int *data = get_linklist_at_level(internalId, level); + int size = getListCount(data); + std::vector<tableint> result(size); + tableint *ll = (tableint *) (data + 1); + memcpy(result.data(), ll, size * sizeof(tableint)); + return result; + } - tableint cur_c = 0; - { - // Checking if the element with the same label already exists - // if so, updating it *instead* of creating a new element. - std::unique_lock <std::mutex> templock_curr(cur_element_count_guard_); - auto search = label_lookup_.find(label); - if (search != label_lookup_.end()) { - tableint existingInternalId = search->second; - templock_curr.unlock(); - std::unique_lock <std::mutex> lock_el_update(link_list_update_locks_[(existingInternalId & (max_update_element_locks - 1))]); - + tableint addPoint(const void *data_point, labeltype label, int level) { + tableint cur_c = 0; + { + // Checking if the element with the same label already exists + // if so, updating it *instead* of creating a new element. + std::unique_lock <std::mutex> lock_table(label_lookup_lock); + auto search = label_lookup_.find(label); + if (search != label_lookup_.end()) { + tableint existingInternalId = search->second; + if (allow_replace_deleted_) { if (isMarkedDeleted(existingInternalId)) { - unmarkDeletedInternal(existingInternalId); + throw std::runtime_error("Can't use addPoint to update deleted elements if replacement of deleted elements is enabled."); } - updatePoint(data_point, existingInternalId, 1.0); + } + lock_table.unlock(); - return existingInternalId; + if (isMarkedDeleted(existingInternalId)) { + unmarkDeletedInternal(existingInternalId); } + updatePoint(data_point, existingInternalId, 1.0); - if (cur_element_count >= max_elements_) { - throw std::runtime_error("The number of elements exceeds the specified limit"); - }; + return existingInternalId; + } - cur_c = cur_element_count; - cur_element_count++; - label_lookup_[label] = cur_c; + if (cur_element_count >= max_elements_) { + throw std::runtime_error("The number of elements exceeds the specified limit"); } - // Take update lock to prevent race conditions on an element with insertion/update at the same time. - std::unique_lock <std::mutex> lock_el_update(link_list_update_locks_[(cur_c & (max_update_element_locks - 1))]); - std::unique_lock <std::mutex> lock_el(link_list_locks_[cur_c]); - int curlevel = getRandomLevel(mult_); - if (level > 0) - curlevel = level; + cur_c = cur_element_count; + cur_element_count++; + label_lookup_[label] = cur_c; + } - element_levels_[cur_c] = curlevel; + std::unique_lock <std::mutex> lock_el(link_list_locks_[cur_c]); + int curlevel = getRandomLevel(mult_); + if (level > 0) + curlevel = level; + element_levels_[cur_c] = curlevel; - std::unique_lock <std::mutex> templock(global); - int maxlevelcopy = maxlevel_; - if (curlevel <= maxlevelcopy) - templock.unlock(); - tableint currObj = enterpoint_node_; - tableint enterpoint_copy = enterpoint_node_; + std::unique_lock <std::mutex> templock(global); + int maxlevelcopy = maxlevel_; + if (curlevel <= maxlevelcopy) + templock.unlock(); + tableint currObj = enterpoint_node_; + tableint enterpoint_copy = enterpoint_node_; + memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_); - memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_); + // Initialisation of the data and label + memcpy(getExternalLabeLp(cur_c), &label, sizeof(labeltype)); + memcpy(getDataByInternalId(cur_c), data_point, data_size_); - // Initialisation of the data and label - memcpy(getExternalLabeLp(cur_c), &label, sizeof(labeltype)); - memcpy(getDataByInternalId(cur_c), data_point, data_size_); + if (curlevel) { + linkLists_[cur_c] = (char *) malloc(size_links_per_element_ * curlevel + 1); + if (linkLists_[cur_c] == nullptr) + throw std::runtime_error("Not enough memory: addPoint failed to allocate linklist"); + memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1); + } + if ((signed)currObj != -1) { + if (curlevel < maxlevelcopy) { + dist_t curdist = fstdistfunc_(data_point, getDataByInternalId(currObj), dist_func_param_); + for (int level = maxlevelcopy; level > curlevel; level--) { + bool changed = true; + while (changed) { + changed = false; + unsigned int *data; + std::unique_lock <std::mutex> lock(link_list_locks_[currObj]); + data = get_linklist(currObj, level); + int size = getListCount(data); - if (curlevel) { - linkLists_[cur_c] = (char *) malloc(size_links_per_element_ * curlevel + 1); - if (linkLists_[cur_c] == nullptr) - throw std::runtime_error("Not enough memory: addPoint failed to allocate linklist"); - memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1); - } - - if ((signed)currObj != -1) { - - if (curlevel < maxlevelcopy) { - - dist_t curdist = fstdistfunc_(data_point, getDataByInternalId(currObj), dist_func_param_); - for (int level = maxlevelcopy; level > curlevel; level--) { - - - bool changed = true; - while (changed) { - changed = false; - unsigned int *data; - std::unique_lock <std::mutex> lock(link_list_locks_[currObj]); - data = get_linklist(currObj,level); - int size = getListCount(data); - - tableint *datal = (tableint *) (data + 1); - for (int i = 0; i < size; i++) { - tableint cand = datal[i]; - if (cand < 0 || cand > max_elements_) - throw std::runtime_error("cand error"); - dist_t d = fstdistfunc_(data_point, getDataByInternalId(cand), dist_func_param_); - if (d < curdist) { - curdist = d; - currObj = cand; - changed = true; - } + tableint *datal = (tableint *) (data + 1); + for (int i = 0; i < size; i++) { + tableint cand = datal[i]; + if (cand < 0 || cand > max_elements_) + throw std::runtime_error("cand error"); + dist_t d = fstdistfunc_(data_point, getDataByInternalId(cand), dist_func_param_); + if (d < curdist) { + curdist = d; + currObj = cand; + changed = true; } } } } + } - bool epDeleted = isMarkedDeleted(enterpoint_copy); - for (int level = std::min(curlevel, maxlevelcopy); level >= 0; level--) { - if (level > maxlevelcopy || level < 0) // possible? - throw std::runtime_error("Level error"); + bool epDeleted = isMarkedDeleted(enterpoint_copy); + for (int level = std::min(curlevel, maxlevelcopy); level >= 0; level--) { + if (level > maxlevelcopy || level < 0) // possible? + throw std::runtime_error("Level error"); - std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates = searchBaseLayer( - currObj, data_point, level); - if (epDeleted) { - top_candidates.emplace(fstdistfunc_(data_point, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy); - if (top_candidates.size() > ef_construction_) - top_candidates.pop(); - } - currObj = mutuallyConnectNewElement(data_point, cur_c, top_candidates, level, false); + std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates = searchBaseLayer( + currObj, data_point, level); + if (epDeleted) { + top_candidates.emplace(fstdistfunc_(data_point, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy); + if (top_candidates.size() > ef_construction_) + top_candidates.pop(); } - - - } else { - // Do nothing for the first element - enterpoint_node_ = 0; - maxlevel_ = curlevel; - + currObj = mutuallyConnectNewElement(data_point, cur_c, top_candidates, level, false); } + } else { + // Do nothing for the first element + enterpoint_node_ = 0; + maxlevel_ = curlevel; + } - //Releasing lock for the maximum level - if (curlevel > maxlevelcopy) { - enterpoint_node_ = cur_c; - maxlevel_ = curlevel; - } - return cur_c; - }; + // Releasing lock for the maximum level + if (curlevel > maxlevelcopy) { + enterpoint_node_ = cur_c; + maxlevel_ = curlevel; + } + return cur_c; + } - std::priority_queue<std::pair<dist_t, labeltype >> - searchKnn(const void *query_data, size_t k) const { - std::priority_queue<std::pair<dist_t, labeltype >> result; - if (cur_element_count == 0) return result; - tableint currObj = enterpoint_node_; - dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_); + std::priority_queue<std::pair<dist_t, labeltype >> + searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const { + std::priority_queue<std::pair<dist_t, labeltype >> result; + if (cur_element_count == 0) return result; - for (int level = maxlevel_; level > 0; level--) { - bool changed = true; - while (changed) { - changed = false; - unsigned int *data; + tableint currObj = enterpoint_node_; + dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_); - data = (unsigned int *) get_linklist(currObj, level); - int size = getListCount(data); - metric_hops++; - metric_distance_computations+=size; + for (int level = maxlevel_; level > 0; level--) { + bool changed = true; + while (changed) { + changed = false; + unsigned int *data; - tableint *datal = (tableint *) (data + 1); - for (int i = 0; i < size; i++) { - tableint cand = datal[i]; - if (cand < 0 || cand > max_elements_) - throw std::runtime_error("cand error"); - dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_); + data = (unsigned int *) get_linklist(currObj, level); + int size = getListCount(data); + metric_hops++; + metric_distance_computations+=size; - if (d < curdist) { - curdist = d; - currObj = cand; - changed = true; - } + tableint *datal = (tableint *) (data + 1); + for (int i = 0; i < size; i++) { + tableint cand = datal[i]; + if (cand < 0 || cand > max_elements_) + throw std::runtime_error("cand error"); + dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_); + + if (d < curdist) { + curdist = d; + currObj = cand; + changed = true; } } } + } - std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates; - if (num_deleted_) { - top_candidates=searchBaseLayerST<true,true>( - currObj, query_data, std::max(ef_, k)); - } - else{ - top_candidates=searchBaseLayerST<false,true>( - currObj, query_data, std::max(ef_, k)); - } + std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates; + if (num_deleted_) { + top_candidates = searchBaseLayerST<true, true>( + currObj, query_data, std::max(ef_, k), isIdAllowed); + } else { + top_candidates = searchBaseLayerST<false, true>( + currObj, query_data, std::max(ef_, k), isIdAllowed); + } - while (top_candidates.size() > k) { - top_candidates.pop(); - } - while (top_candidates.size() > 0) { - std::pair<dist_t, tableint> rez = top_candidates.top(); - result.push(std::pair<dist_t, labeltype>(rez.first, getExternalLabel(rez.second))); - top_candidates.pop(); - } - return result; - }; + while (top_candidates.size() > k) { + top_candidates.pop(); + } + while (top_candidates.size() > 0) { + std::pair<dist_t, tableint> rez = top_candidates.top(); + result.push(std::pair<dist_t, labeltype>(rez.first, getExternalLabel(rez.second))); + top_candidates.pop(); + } + return result; + } - void checkIntegrity(){ - int connections_checked=0; - std::vector <int > inbound_connections_num(cur_element_count,0); - for(int i = 0;i < cur_element_count; i++){ - for(int l = 0;l <= element_levels_[i]; l++){ - linklistsizeint *ll_cur = get_linklist_at_level(i,l); - int size = getListCount(ll_cur); - tableint *data = (tableint *) (ll_cur + 1); - std::unordered_set<tableint> s; - for (int j=0; j<size; j++){ - assert(data[j] > 0); - assert(data[j] < cur_element_count); - assert (data[j] != i); - inbound_connections_num[data[j]]++; - s.insert(data[j]); - connections_checked++; - } - assert(s.size() == size); + void checkIntegrity() { + int connections_checked = 0; + std::vector <int > inbound_connections_num(cur_element_count, 0); + for (int i = 0; i < cur_element_count; i++) { + for (int l = 0; l <= element_levels_[i]; l++) { + linklistsizeint *ll_cur = get_linklist_at_level(i, l); + int size = getListCount(ll_cur); + tableint *data = (tableint *) (ll_cur + 1); + std::unordered_set<tableint> s; + for (int j = 0; j < size; j++) { + assert(data[j] > 0); + assert(data[j] < cur_element_count); + assert(data[j] != i); + inbound_connections_num[data[j]]++; + s.insert(data[j]); + connections_checked++; } + assert(s.size() == size); } - if(cur_element_count > 1){ - int min1=inbound_connections_num[0], max1=inbound_connections_num[0]; - for(int i=0; i < cur_element_count; i++){ - assert(inbound_connections_num[i] > 0); - min1=std::min(inbound_connections_num[i],min1); - max1=std::max(inbound_connections_num[i],max1); - } - std::cout << "Min inbound: " << min1 << ", Max inbound:" << max1 << "\n"; + } + if (cur_element_count > 1) { + int min1 = inbound_connections_num[0], max1 = inbound_connections_num[0]; + for (int i=0; i < cur_element_count; i++) { + assert(inbound_connections_num[i] > 0); + min1 = std::min(inbound_connections_num[i], min1); + max1 = std::max(inbound_connections_num[i], max1); } - std::cout << "integrity ok, checked " << connections_checked << " connections\n"; - + std::cout << "Min inbound: " << min1 << ", Max inbound:" << max1 << "\n"; } - - }; - -} + std::cout << "integrity ok, checked " << connections_checked << " connections\n"; + } +}; +} // namespace hnswlib