ext/hnswlib/src/hnswalg.h in hnswlib-0.6.2 vs ext/hnswlib/src/hnswalg.h in hnswlib-0.7.0
- old
+ new
@@ -8,1201 +8,1265 @@
#include <assert.h>
#include <unordered_set>
#include <list>
namespace hnswlib {
- typedef unsigned int tableint;
- typedef unsigned int linklistsizeint;
+typedef unsigned int tableint;
+typedef unsigned int linklistsizeint;
- template<typename dist_t>
- class HierarchicalNSW : public AlgorithmInterface<dist_t> {
- public:
- static const tableint max_update_element_locks = 65536;
- HierarchicalNSW() : visited_list_pool_(nullptr), data_level0_memory_(nullptr), linkLists_(nullptr), cur_element_count(0) { }
- HierarchicalNSW(SpaceInterface<dist_t> *s) {
- }
+template<typename dist_t>
+class HierarchicalNSW : public AlgorithmInterface<dist_t> {
+ public:
+ static const tableint MAX_LABEL_OPERATION_LOCKS = 65536;
+ static const unsigned char DELETE_MARK = 0x01;
- HierarchicalNSW(SpaceInterface<dist_t> *s, const std::string &location, bool nmslib = false, size_t max_elements=0) {
- loadIndex(location, s, max_elements);
- }
+ size_t max_elements_{0};
+ mutable std::atomic<size_t> cur_element_count{0}; // current number of elements
+ size_t size_data_per_element_{0};
+ size_t size_links_per_element_{0};
+ mutable std::atomic<size_t> num_deleted_{0}; // number of deleted elements
+ size_t M_{0};
+ size_t maxM_{0};
+ size_t maxM0_{0};
+ size_t ef_construction_{0};
+ size_t ef_{ 0 };
- HierarchicalNSW(SpaceInterface<dist_t> *s, size_t max_elements, size_t M = 16, size_t ef_construction = 200, size_t random_seed = 100) :
- link_list_locks_(max_elements), link_list_update_locks_(max_update_element_locks), element_levels_(max_elements) {
- max_elements_ = max_elements;
+ double mult_{0.0}, revSize_{0.0};
+ int maxlevel_{0};
- num_deleted_ = 0;
- data_size_ = s->get_data_size();
- fstdistfunc_ = s->get_dist_func();
- dist_func_param_ = s->get_dist_func_param();
- M_ = M;
- maxM_ = M_;
- maxM0_ = M_ * 2;
- ef_construction_ = std::max(ef_construction,M_);
- ef_ = 10;
+ VisitedListPool *visited_list_pool_{nullptr};
- level_generator_.seed(random_seed);
- update_probability_generator_.seed(random_seed + 1);
+ // Locks operations with element by label value
+ mutable std::vector<std::mutex> label_op_locks_;
- size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint);
- size_data_per_element_ = size_links_level0_ + data_size_ + sizeof(labeltype);
- offsetData_ = size_links_level0_;
- label_offset_ = size_links_level0_ + data_size_;
- offsetLevel0_ = 0;
+ std::mutex global;
+ std::vector<std::mutex> link_list_locks_;
- data_level0_memory_ = (char *) malloc(max_elements_ * size_data_per_element_);
- if (data_level0_memory_ == nullptr)
- throw std::runtime_error("Not enough memory");
+ tableint enterpoint_node_{0};
- cur_element_count = 0;
+ size_t size_links_level0_{0};
+ size_t offsetData_{0}, offsetLevel0_{0}, label_offset_{ 0 };
- visited_list_pool_ = new VisitedListPool(1, max_elements);
+ char *data_level0_memory_{nullptr};
+ char **linkLists_{nullptr};
+ std::vector<int> element_levels_; // keeps level of each element
- //initializations for special treatment of the first node
- enterpoint_node_ = -1;
- maxlevel_ = -1;
+ size_t data_size_{0};
- linkLists_ = (char **) malloc(sizeof(void *) * max_elements_);
- if (linkLists_ == nullptr)
- throw std::runtime_error("Not enough memory: HierarchicalNSW failed to allocate linklists");
- size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint);
- mult_ = 1 / log(1.0 * M_);
- revSize_ = 1.0 / mult_;
- }
+ DISTFUNC<dist_t> fstdistfunc_;
+ void *dist_func_param_{nullptr};
- struct CompareByFirst {
- constexpr bool operator()(std::pair<dist_t, tableint> const &a,
- std::pair<dist_t, tableint> const &b) const noexcept {
- return a.first < b.first;
- }
- };
+ mutable std::mutex label_lookup_lock; // lock for label_lookup_
+ std::unordered_map<labeltype, tableint> label_lookup_;
- ~HierarchicalNSW() {
- if (data_level0_memory_) free(data_level0_memory_);
- if (linkLists_) {
- for (tableint i = 0; i < cur_element_count; i++) {
- if (element_levels_[i] > 0)
- if (linkLists_[i]) free(linkLists_[i]);
- }
- free(linkLists_);
- }
- if (visited_list_pool_) delete visited_list_pool_;
- }
+ std::default_random_engine level_generator_;
+ std::default_random_engine update_probability_generator_;
- size_t max_elements_;
- size_t cur_element_count;
- size_t size_data_per_element_;
- size_t size_links_per_element_;
- size_t num_deleted_;
+ mutable std::atomic<long> metric_distance_computations{0};
+ mutable std::atomic<long> metric_hops{0};
- size_t M_;
- size_t maxM_;
- size_t maxM0_;
- size_t ef_construction_;
+ bool allow_replace_deleted_ = false; // flag to replace deleted elements (marked as deleted) during insertions
- double mult_, revSize_;
- int maxlevel_;
+ std::mutex deleted_elements_lock; // lock for deleted_elements
+ std::unordered_set<tableint> deleted_elements; // contains internal ids of deleted elements
+ HierarchicalNSW() { }  // default constructor: empty body, members keep their in-class initializers
- VisitedListPool *visited_list_pool_;
- std::mutex cur_element_count_guard_;
+ HierarchicalNSW(SpaceInterface<dist_t> *s) {  // `s` is accepted but not used; the body is empty
+ }
- std::vector<std::mutex> link_list_locks_;
- // Locks to prevent race condition during update/insert of an element at same time.
- // Note: Locks for additions can also be used to prevent this race condition if the querying of KNN is not exposed along with update/inserts i.e multithread insert/update/query in parallel.
- std::vector<std::mutex> link_list_update_locks_;
- tableint enterpoint_node_;
+ HierarchicalNSW(  // constructs the index by calling loadIndex() on `location`
+ SpaceInterface<dist_t> *s,
+ const std::string &location,
+ bool nmslib = false,  // not referenced anywhere in this constructor
+ size_t max_elements = 0,
+ bool allow_replace_deleted = false)
+ : allow_replace_deleted_(allow_replace_deleted) {
+ loadIndex(location, s, max_elements);  // note the argument order: location first, then the space
+ }
- size_t size_links_level0_;
- size_t offsetData_, offsetLevel0_;
- char *data_level0_memory_;
- char **linkLists_;
- std::vector<int> element_levels_;
+ HierarchicalNSW(  // constructs an empty index with storage for max_elements elements
+ SpaceInterface<dist_t> *s,
+ size_t max_elements,
+ size_t M = 16,
+ size_t ef_construction = 200,
+ size_t random_seed = 100,
+ bool allow_replace_deleted = false)
+ : link_list_locks_(max_elements),  // NOTE(review): members are initialized in declaration order, not in the order listed here
+ label_op_locks_(MAX_LABEL_OPERATION_LOCKS),
+ element_levels_(max_elements),
+ allow_replace_deleted_(allow_replace_deleted) {
+ max_elements_ = max_elements;
+ num_deleted_ = 0;
+ data_size_ = s->get_data_size();
+ fstdistfunc_ = s->get_dist_func();
+ dist_func_param_ = s->get_dist_func_param();
+ M_ = M;
+ maxM_ = M_;  // link cap used for levels above 0 (see Mcurmax in mutuallyConnectNewElement)
+ maxM0_ = M_ * 2;  // link cap used for level 0
+ ef_construction_ = std::max(ef_construction, M_);  // ef_construction_ is never smaller than M_
+ ef_ = 10;
- size_t data_size_;
+ level_generator_.seed(random_seed);
+ update_probability_generator_.seed(random_seed + 1);  // seeded differently from level_generator_
- size_t label_offset_;
- DISTFUNC<dist_t> fstdistfunc_;
- void *dist_func_param_;
- std::unordered_map<labeltype, tableint> label_lookup_;
+ size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint);  // maxM0_ link slots plus one count field
+ size_data_per_element_ = size_links_level0_ + data_size_ + sizeof(labeltype);  // level-0 record = links + vector data + label
+ offsetData_ = size_links_level0_;
+ label_offset_ = size_links_level0_ + data_size_;
+ offsetLevel0_ = 0;
- std::default_random_engine level_generator_;
- std::default_random_engine update_probability_generator_;
+ data_level0_memory_ = (char *) malloc(max_elements_ * size_data_per_element_);  // one flat buffer for all level-0 records
+ if (data_level0_memory_ == nullptr)
+ throw std::runtime_error("Not enough memory");
- inline labeltype getExternalLabel(tableint internal_id) const {
- labeltype return_label;
- memcpy(&return_label,(data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), sizeof(labeltype));
- return return_label;
- }
+ cur_element_count = 0;
- inline void setExternalLabel(tableint internal_id, labeltype label) const {
- memcpy((data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), &label, sizeof(labeltype));
- }
+ visited_list_pool_ = new VisitedListPool(1, max_elements);
- inline labeltype *getExternalLabeLp(tableint internal_id) const {
- return (labeltype *) (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_);
- }
+ // initializations for special treatment of the first node
+ enterpoint_node_ = -1;  // tableint is unsigned int, so -1 wraps to its largest value
+ maxlevel_ = -1;
- inline char *getDataByInternalId(tableint internal_id) const {
- return (data_level0_memory_ + internal_id * size_data_per_element_ + offsetData_);
+ linkLists_ = (char **) malloc(sizeof(void *) * max_elements_);  // one pointer slot per element
+ if (linkLists_ == nullptr)
+ throw std::runtime_error("Not enough memory: HierarchicalNSW failed to allocate linklists");  // NOTE(review): a throwing constructor skips the destructor, so data_level0_memory_ and visited_list_pool_ are not released here
+ size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint);  // maxM_ link slots plus one count field
+ mult_ = 1 / log(1.0 * M_);
+ revSize_ = 1.0 / mult_;  // i.e. log(M_)
+ }
+
+
+ ~HierarchicalNSW() {  // frees the level-0 buffer, the per-element link lists and the visited-list pool
+ free(data_level0_memory_);  // free(nullptr) is a no-op, so no null check is needed
+ for (tableint i = 0; i < cur_element_count; i++) {  // loop body never runs while cur_element_count is 0
+ if (element_levels_[i] > 0)  // only elements with a level above 0 are freed here
+ free(linkLists_[i]);
}
+ free(linkLists_);
+ delete visited_list_pool_;  // deleting a null pointer is a no-op
+ }
- int getRandomLevel(double reverse_size) {
- std::uniform_real_distribution<double> distribution(0.0, 1.0);
- double r = -log(distribution(level_generator_)) * reverse_size;
- return (int) r;
+
+ struct CompareByFirst {  // compares (dist_t, tableint) pairs by .first only; .second is ignored
+ constexpr bool operator()(std::pair<dist_t, tableint> const& a,
+ std::pair<dist_t, tableint> const& b) const noexcept {
+ return a.first < b.first;  // as a std::priority_queue comparator this keeps the largest .first on top
}
+ };
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
- searchBaseLayer(tableint ep_id, const void *data_point, int layer) {
- VisitedList *vl = visited_list_pool_->getFreeVisitedList();
- vl_type *visited_array = vl->mass;
- vl_type visited_array_tag = vl->curV;
+ void setEf(size_t ef) {  // stores ef in ef_; no validation is performed
+ ef_ = ef;
+ }
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidateSet;
- dist_t lowerBound;
- if (!isMarkedDeleted(ep_id)) {
- dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
- top_candidates.emplace(dist, ep_id);
- lowerBound = dist;
- candidateSet.emplace(-dist, ep_id);
- } else {
- lowerBound = std::numeric_limits<dist_t>::max();
- candidateSet.emplace(-lowerBound, ep_id);
- }
- visited_array[ep_id] = visited_array_tag;
+ inline std::mutex& getLabelOpMutex(labeltype label) const {  // returns the mutex from label_op_locks_ assigned to `label`
+ // calculate hash
+ size_t lock_id = label & (MAX_LABEL_OPERATION_LOCKS - 1);  // 65536 is a power of two, so the mask equals label % 65536
+ return label_op_locks_[lock_id];  // labels with the same low 16 bits share one mutex
+ }
- while (!candidateSet.empty()) {
- std::pair<dist_t, tableint> curr_el_pair = candidateSet.top();
- if ((-curr_el_pair.first) > lowerBound && top_candidates.size() == ef_construction_) {
- break;
- }
- candidateSet.pop();
- tableint curNodeNum = curr_el_pair.second;
+ inline labeltype getExternalLabel(tableint internal_id) const {  // returns the label stored in element internal_id's level-0 record
+ labeltype return_label;
+ memcpy(&return_label, (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), sizeof(labeltype));  // record start + label_offset_, copied bytewise
+ return return_label;
+ }
- std::unique_lock <std::mutex> lock(link_list_locks_[curNodeNum]);
- int *data;// = (int *)(linkList0_ + curNodeNum * size_links_per_element0_);
- if (layer == 0) {
- data = (int*)get_linklist0(curNodeNum);
- } else {
- data = (int*)get_linklist(curNodeNum, layer);
+ inline void setExternalLabel(tableint internal_id, labeltype label) const {  // overwrites the stored label; const because only the pointed-to buffer changes
+ memcpy((data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), &label, sizeof(labeltype));
+ }
+
+
+ inline labeltype *getExternalLabeLp(tableint internal_id) const {  // returns a pointer to the label field in place, without copying
+ return (labeltype *) (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_);
+ }
+
+
+ inline char *getDataByInternalId(tableint internal_id) const {  // returns the address at offsetData_ inside the element's level-0 record
+ return (data_level0_memory_ + internal_id * size_data_per_element_ + offsetData_);
+ }
+
+
+ int getRandomLevel(double reverse_size) {  // returns -ln(u) * reverse_size truncated to int, with u drawn from level_generator_
+ std::uniform_real_distribution<double> distribution(0.0, 1.0);
+ double r = -log(distribution(level_generator_)) * reverse_size;  // NOTE(review): a draw of exactly 0.0 would make this -log(0) = inf; confirm that cannot reach the cast below
+ return (int) r;  // the cast truncates toward zero
+ }
+
+ size_t getMaxElements() {  // returns max_elements_
+ return max_elements_;
+ }
+
+ size_t getCurrentElementCount() {  // returns the current value of the atomic cur_element_count
+ return cur_element_count;
+ }
+
+ size_t getDeletedCount() {  // returns the current value of the atomic num_deleted_
+ return num_deleted_;
+ }
+
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
+ searchBaseLayer(tableint ep_id, const void *data_point, int layer) {  // searches one layer starting at ep_id; returns at most ef_construction_ non-deleted (distance, id) pairs
+ VisitedList *vl = visited_list_pool_->getFreeVisitedList();
+ vl_type *visited_array = vl->mass;
+ vl_type visited_array_tag = vl->curV;
+
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidateSet;
+
+ dist_t lowerBound;
+ if (!isMarkedDeleted(ep_id)) {
+ dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
+ top_candidates.emplace(dist, ep_id);
+ lowerBound = dist;
+ candidateSet.emplace(-dist, ep_id);  // distances are stored negated so the closest candidate is on top
+ } else {
+ lowerBound = std::numeric_limits<dist_t>::max();  // deleted entry point: expanded below but not added to the results
+ candidateSet.emplace(-lowerBound, ep_id);
+ }
+ visited_array[ep_id] = visited_array_tag;
+
+ while (!candidateSet.empty()) {
+ std::pair<dist_t, tableint> curr_el_pair = candidateSet.top();
+ if ((-curr_el_pair.first) > lowerBound && top_candidates.size() == ef_construction_) {  // result set is full and the closest pending candidate is farther than its worst entry
+ break;
+ }
+ candidateSet.pop();
+
+ tableint curNodeNum = curr_el_pair.second;
+
+ std::unique_lock <std::mutex> lock(link_list_locks_[curNodeNum]);  // held until the end of this loop iteration
+
+ int *data; // = (int *)(linkList0_ + curNodeNum * size_links_per_element0_);
+ if (layer == 0) {
+ data = (int*)get_linklist0(curNodeNum);
+ } else {
+ data = (int*)get_linklist(curNodeNum, layer);
// data = (int *) (linkLists_[curNodeNum] + (layer - 1) * size_links_per_element_);
- }
- size_t size = getListCount((linklistsizeint*)data);
- tableint *datal = (tableint *) (data + 1);
+ }
+ size_t size = getListCount((linklistsizeint*)data);
+ tableint *datal = (tableint *) (data + 1);  // neighbour ids start one slot after the count field
#ifdef USE_SSE
- _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0);
- _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0);
- _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0);
- _mm_prefetch(getDataByInternalId(*(datal + 1)), _MM_HINT_T0);
+ _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0);
+ _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0);
+ _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0);
+ _mm_prefetch(getDataByInternalId(*(datal + 1)), _MM_HINT_T0);
#endif
- for (size_t j = 0; j < size; j++) {
- tableint candidate_id = *(datal + j);
+ for (size_t j = 0; j < size; j++) {
+ tableint candidate_id = *(datal + j);
// if (candidate_id == 0) continue;
#ifdef USE_SSE
- _mm_prefetch((char *) (visited_array + *(datal + j + 1)), _MM_HINT_T0);
- _mm_prefetch(getDataByInternalId(*(datal + j + 1)), _MM_HINT_T0);
+ _mm_prefetch((char *) (visited_array + *(datal + j + 1)), _MM_HINT_T0);
+ _mm_prefetch(getDataByInternalId(*(datal + j + 1)), _MM_HINT_T0);
#endif
- if (visited_array[candidate_id] == visited_array_tag) continue;
- visited_array[candidate_id] = visited_array_tag;
- char *currObj1 = (getDataByInternalId(candidate_id));
+ if (visited_array[candidate_id] == visited_array_tag) continue;  // already seen during this search
+ visited_array[candidate_id] = visited_array_tag;
+ char *currObj1 = (getDataByInternalId(candidate_id));
- dist_t dist1 = fstdistfunc_(data_point, currObj1, dist_func_param_);
- if (top_candidates.size() < ef_construction_ || lowerBound > dist1) {
- candidateSet.emplace(-dist1, candidate_id);
+ dist_t dist1 = fstdistfunc_(data_point, currObj1, dist_func_param_);
+ if (top_candidates.size() < ef_construction_ || lowerBound > dist1) {
+ candidateSet.emplace(-dist1, candidate_id);
#ifdef USE_SSE
- _mm_prefetch(getDataByInternalId(candidateSet.top().second), _MM_HINT_T0);
+ _mm_prefetch(getDataByInternalId(candidateSet.top().second), _MM_HINT_T0);
#endif
- if (!isMarkedDeleted(candidate_id))
- top_candidates.emplace(dist1, candidate_id);
+ if (!isMarkedDeleted(candidate_id))  // deleted nodes are queued for expansion above but are not added to the results
+ top_candidates.emplace(dist1, candidate_id);
- if (top_candidates.size() > ef_construction_)
- top_candidates.pop();
+ if (top_candidates.size() > ef_construction_)
+ top_candidates.pop();  // drops the entry with the largest distance
- if (!top_candidates.empty())
- lowerBound = top_candidates.top().first;
- }
+ if (!top_candidates.empty())
+ lowerBound = top_candidates.top().first;
}
}
- visited_list_pool_->releaseVisitedList(vl);
-
- return top_candidates;
}
+ visited_list_pool_->releaseVisitedList(vl);
- mutable std::atomic<long> metric_distance_computations;
- mutable std::atomic<long> metric_hops;
+ return top_candidates;
+ }
- template <bool has_deletions, bool collect_metrics=false>
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
- searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef) const {
- VisitedList *vl = visited_list_pool_->getFreeVisitedList();
- vl_type *visited_array = vl->mass;
- vl_type visited_array_tag = vl->curV;
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set;
+ template <bool has_deletions, bool collect_metrics = false>
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
+ searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, BaseFilterFunctor* isIdAllowed = nullptr) const {  // level-0 search from ep_id returning at most ef results; no link-list lock is taken in this body
+ VisitedList *vl = visited_list_pool_->getFreeVisitedList();
+ vl_type *visited_array = vl->mass;
+ vl_type visited_array_tag = vl->curV;
- dist_t lowerBound;
- if (!has_deletions || !isMarkedDeleted(ep_id)) {
- dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
- lowerBound = dist;
- top_candidates.emplace(dist, ep_id);
- candidate_set.emplace(-dist, ep_id);
- } else {
- lowerBound = std::numeric_limits<dist_t>::max();
- candidate_set.emplace(-lowerBound, ep_id);
- }
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set;
- visited_array[ep_id] = visited_array_tag;
+ dist_t lowerBound;
+ if ((!has_deletions || !isMarkedDeleted(ep_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id)))) {  // the filter is called with the external label, not the internal id
+ dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
+ lowerBound = dist;
+ top_candidates.emplace(dist, ep_id);
+ candidate_set.emplace(-dist, ep_id);  // distances are stored negated so the closest candidate is on top
+ } else {
+ lowerBound = std::numeric_limits<dist_t>::max();  // entry point is deleted or filtered out: expanded below but not added to the results
+ candidate_set.emplace(-lowerBound, ep_id);
+ }
- while (!candidate_set.empty()) {
+ visited_array[ep_id] = visited_array_tag;
- std::pair<dist_t, tableint> current_node_pair = candidate_set.top();
+ while (!candidate_set.empty()) {
+ std::pair<dist_t, tableint> current_node_pair = candidate_set.top();
- if ((-current_node_pair.first) > lowerBound && (top_candidates.size() == ef || has_deletions == false)) {
- break;
- }
- candidate_set.pop();
+ if ((-current_node_pair.first) > lowerBound &&
+ (top_candidates.size() == ef || (!isIdAllowed && !has_deletions))) {  // with no filter and no deletions the search may stop before ef results are collected
+ break;
+ }
+ candidate_set.pop();
- tableint current_node_id = current_node_pair.second;
- int *data = (int *) get_linklist0(current_node_id);
- size_t size = getListCount((linklistsizeint*)data);
+ tableint current_node_id = current_node_pair.second;
+ int *data = (int *) get_linklist0(current_node_id);  // always reads the level-0 list
+ size_t size = getListCount((linklistsizeint*)data);
// bool cur_node_deleted = isMarkedDeleted(current_node_id);
- if(collect_metrics){
- metric_hops++;
- metric_distance_computations+=size;
- }
+ if (collect_metrics) {
+ metric_hops++;
+ metric_distance_computations+=size;  // adds the full list size, including neighbours later skipped as already visited
+ }
#ifdef USE_SSE
- _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0);
- _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0);
- _mm_prefetch(data_level0_memory_ + (*(data + 1)) * size_data_per_element_ + offsetData_, _MM_HINT_T0);
- _mm_prefetch((char *) (data + 2), _MM_HINT_T0);
+ _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0);
+ _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0);
+ _mm_prefetch(data_level0_memory_ + (*(data + 1)) * size_data_per_element_ + offsetData_, _MM_HINT_T0);
+ _mm_prefetch((char *) (data + 2), _MM_HINT_T0);
#endif
- for (size_t j = 1; j <= size; j++) {
- int candidate_id = *(data + j);
+ for (size_t j = 1; j <= size; j++) {  // index 0 holds the count field, so neighbour ids are at 1..size
+ int candidate_id = *(data + j);
// if (candidate_id == 0) continue;
#ifdef USE_SSE
- _mm_prefetch((char *) (visited_array + *(data + j + 1)), _MM_HINT_T0);
- _mm_prefetch(data_level0_memory_ + (*(data + j + 1)) * size_data_per_element_ + offsetData_,
- _MM_HINT_T0);////////////
+ _mm_prefetch((char *) (visited_array + *(data + j + 1)), _MM_HINT_T0);
+ _mm_prefetch(data_level0_memory_ + (*(data + j + 1)) * size_data_per_element_ + offsetData_,
+ _MM_HINT_T0); ////////////
#endif
- if (!(visited_array[candidate_id] == visited_array_tag)) {
+ if (!(visited_array[candidate_id] == visited_array_tag)) {
+ visited_array[candidate_id] = visited_array_tag;
- visited_array[candidate_id] = visited_array_tag;
+ char *currObj1 = (getDataByInternalId(candidate_id));
+ dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_);
- char *currObj1 = (getDataByInternalId(candidate_id));
- dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_);
-
- if (top_candidates.size() < ef || lowerBound > dist) {
- candidate_set.emplace(-dist, candidate_id);
+ if (top_candidates.size() < ef || lowerBound > dist) {
+ candidate_set.emplace(-dist, candidate_id);
#ifdef USE_SSE
- _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ +
- offsetLevel0_,///////////
- _MM_HINT_T0);////////////////////////
+ _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ +
+ offsetLevel0_, ///////////
+ _MM_HINT_T0); ////////////////////////
#endif
- if (!has_deletions || !isMarkedDeleted(candidate_id))
- top_candidates.emplace(dist, candidate_id);
+ if ((!has_deletions || !isMarkedDeleted(candidate_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id))))  // deleted or filtered-out nodes are queued for expansion above but are not added to the results
+ top_candidates.emplace(dist, candidate_id);
- if (top_candidates.size() > ef)
- top_candidates.pop();
+ if (top_candidates.size() > ef)
+ top_candidates.pop();  // drops the entry with the largest distance
- if (!top_candidates.empty())
- lowerBound = top_candidates.top().first;
- }
+ if (!top_candidates.empty())
+ lowerBound = top_candidates.top().first;
}
}
}
+ }
- visited_list_pool_->releaseVisitedList(vl);
- return top_candidates;
+ visited_list_pool_->releaseVisitedList(vl);
+ return top_candidates;
+ }
+
+
+ void getNeighborsByHeuristic2(  // prunes top_candidates in place to at most M entries
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
+ const size_t M) {
+ if (top_candidates.size() < M) {  // fewer than M candidates: leave the queue untouched
+ return;
}
- void getNeighborsByHeuristic2(
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
- const size_t M) {
- if (top_candidates.size() < M) {
- return;
- }
+ std::priority_queue<std::pair<dist_t, tableint>> queue_closest;
+ std::vector<std::pair<dist_t, tableint>> return_list;
+ while (top_candidates.size() > 0) {
+ queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second);  // negated so the smallest distance is popped first
+ top_candidates.pop();
+ }
- std::priority_queue<std::pair<dist_t, tableint>> queue_closest;
- std::vector<std::pair<dist_t, tableint>> return_list;
- while (top_candidates.size() > 0) {
- queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second);
- top_candidates.pop();
- }
+ while (queue_closest.size()) {
+ if (return_list.size() >= M)
+ break;
+ std::pair<dist_t, tableint> curent_pair = queue_closest.top();
+ dist_t dist_to_query = -curent_pair.first;  // undo the negation to recover the original distance
+ queue_closest.pop();
+ bool good = true;
- while (queue_closest.size()) {
- if (return_list.size() >= M)
+ for (std::pair<dist_t, tableint> second_pair : return_list) {  // compare against every candidate kept so far
+ dist_t curdist =
+ fstdistfunc_(getDataByInternalId(second_pair.second),
+ getDataByInternalId(curent_pair.second),
+ dist_func_param_);
+ if (curdist < dist_to_query) {  // an already-kept element is closer to this candidate than dist_to_query: reject it
+ good = false;
break;
- std::pair<dist_t, tableint> curent_pair = queue_closest.top();
- dist_t dist_to_query = -curent_pair.first;
- queue_closest.pop();
- bool good = true;
-
- for (std::pair<dist_t, tableint> second_pair : return_list) {
- dist_t curdist =
- fstdistfunc_(getDataByInternalId(second_pair.second),
- getDataByInternalId(curent_pair.second),
- dist_func_param_);;
- if (curdist < dist_to_query) {
- good = false;
- break;
- }
}
- if (good) {
- return_list.push_back(curent_pair);
- }
}
-
- for (std::pair<dist_t, tableint> curent_pair : return_list) {
- top_candidates.emplace(-curent_pair.first, curent_pair.second);
+ if (good) {
+ return_list.push_back(curent_pair);
}
}
+ for (std::pair<dist_t, tableint> curent_pair : return_list) {
+ top_candidates.emplace(-curent_pair.first, curent_pair.second);  // negate again so the queue holds positive distances
+ }
+ }
- linklistsizeint *get_linklist0(tableint internal_id) const {
- return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_);
- };
- linklistsizeint *get_linklist0(tableint internal_id, char *data_level0_memory_) const {
- return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_);
- };
+ linklistsizeint *get_linklist0(tableint internal_id) const {  // returns the address at offsetLevel0_ inside the element's level-0 record
+ return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_);
+ }
- linklistsizeint *get_linklist(tableint internal_id, int level) const {
- return (linklistsizeint *) (linkLists_[internal_id] + (level - 1) * size_links_per_element_);
- };
- linklistsizeint *get_linklist_at_level(tableint internal_id, int level) const {
- return level == 0 ? get_linklist0(internal_id) : get_linklist(internal_id, level);
- };
+ linklistsizeint *get_linklist0(tableint internal_id, char *data_level0_memory_) const {  // same computation on a caller-supplied buffer; the parameter shadows the member of the same name
+ return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_);
+ }
- tableint mutuallyConnectNewElement(const void *data_point, tableint cur_c,
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
- int level, bool isUpdate) {
- size_t Mcurmax = level ? maxM_ : maxM0_;
- getNeighborsByHeuristic2(top_candidates, M_);
- if (top_candidates.size() > M_)
- throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic");
- std::vector<tableint> selectedNeighbors;
- selectedNeighbors.reserve(M_);
- while (top_candidates.size() > 0) {
- selectedNeighbors.push_back(top_candidates.top().second);
- top_candidates.pop();
- }
+ linklistsizeint *get_linklist(tableint internal_id, int level) const {  // returns the list for `level` inside linkLists_[internal_id]
+ return (linklistsizeint *) (linkLists_[internal_id] + (level - 1) * size_links_per_element_);  // level 1 sits at offset 0; level 0 would give a negative offset
+ }
- tableint next_closest_entry_point = selectedNeighbors.back();
- {
- linklistsizeint *ll_cur;
- if (level == 0)
- ll_cur = get_linklist0(cur_c);
- else
- ll_cur = get_linklist(cur_c, level);
+ linklistsizeint *get_linklist_at_level(tableint internal_id, int level) const {  // dispatches to get_linklist0 for level 0, otherwise to get_linklist
+ return level == 0 ? get_linklist0(internal_id) : get_linklist(internal_id, level);
+ }
- if (*ll_cur && !isUpdate) {
- throw std::runtime_error("The newly inserted element should have blank link list");
- }
- setListCount(ll_cur,selectedNeighbors.size());
- tableint *data = (tableint *) (ll_cur + 1);
- for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) {
- if (data[idx] && !isUpdate)
- throw std::runtime_error("Possible memory corruption");
- if (level > element_levels_[selectedNeighbors[idx]])
- throw std::runtime_error("Trying to make a link on a non-existent level");
- data[idx] = selectedNeighbors[idx];
+ tableint mutuallyConnectNewElement(
+ const void *data_point,
+ tableint cur_c,
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
+ int level,
+ bool isUpdate) {
+ size_t Mcurmax = level ? maxM_ : maxM0_;
+ getNeighborsByHeuristic2(top_candidates, M_);
+ if (top_candidates.size() > M_)
+ throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic");
- }
+ std::vector<tableint> selectedNeighbors;
+ selectedNeighbors.reserve(M_);
+ while (top_candidates.size() > 0) {
+ selectedNeighbors.push_back(top_candidates.top().second);
+ top_candidates.pop();
+ }
+
+ tableint next_closest_entry_point = selectedNeighbors.back();
+
+ {
+ // lock only during the update
+ // because during the addition the lock for cur_c is already acquired
+ std::unique_lock <std::mutex> lock(link_list_locks_[cur_c], std::defer_lock);
+ if (isUpdate) {
+ lock.lock();
}
+ linklistsizeint *ll_cur;
+ if (level == 0)
+ ll_cur = get_linklist0(cur_c);
+ else
+ ll_cur = get_linklist(cur_c, level);
+ if (*ll_cur && !isUpdate) {
+ throw std::runtime_error("The newly inserted element should have blank link list");
+ }
+ setListCount(ll_cur, selectedNeighbors.size());
+ tableint *data = (tableint *) (ll_cur + 1);
for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) {
+ if (data[idx] && !isUpdate)
+ throw std::runtime_error("Possible memory corruption");
+ if (level > element_levels_[selectedNeighbors[idx]])
+ throw std::runtime_error("Trying to make a link on a non-existent level");
- std::unique_lock <std::mutex> lock(link_list_locks_[selectedNeighbors[idx]]);
+ data[idx] = selectedNeighbors[idx];
+ }
+ }
- linklistsizeint *ll_other;
- if (level == 0)
- ll_other = get_linklist0(selectedNeighbors[idx]);
- else
- ll_other = get_linklist(selectedNeighbors[idx], level);
+ for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) {
+ std::unique_lock <std::mutex> lock(link_list_locks_[selectedNeighbors[idx]]);
- size_t sz_link_list_other = getListCount(ll_other);
+ linklistsizeint *ll_other;
+ if (level == 0)
+ ll_other = get_linklist0(selectedNeighbors[idx]);
+ else
+ ll_other = get_linklist(selectedNeighbors[idx], level);
- if (sz_link_list_other > Mcurmax)
- throw std::runtime_error("Bad value of sz_link_list_other");
- if (selectedNeighbors[idx] == cur_c)
- throw std::runtime_error("Trying to connect an element to itself");
- if (level > element_levels_[selectedNeighbors[idx]])
- throw std::runtime_error("Trying to make a link on a non-existent level");
+ size_t sz_link_list_other = getListCount(ll_other);
- tableint *data = (tableint *) (ll_other + 1);
+ if (sz_link_list_other > Mcurmax)
+ throw std::runtime_error("Bad value of sz_link_list_other");
+ if (selectedNeighbors[idx] == cur_c)
+ throw std::runtime_error("Trying to connect an element to itself");
+ if (level > element_levels_[selectedNeighbors[idx]])
+ throw std::runtime_error("Trying to make a link on a non-existent level");
- bool is_cur_c_present = false;
- if (isUpdate) {
- for (size_t j = 0; j < sz_link_list_other; j++) {
- if (data[j] == cur_c) {
- is_cur_c_present = true;
- break;
- }
+ tableint *data = (tableint *) (ll_other + 1);
+
+ bool is_cur_c_present = false;
+ if (isUpdate) {
+ for (size_t j = 0; j < sz_link_list_other; j++) {
+ if (data[j] == cur_c) {
+ is_cur_c_present = true;
+ break;
}
}
+ }
- // If cur_c is already present in the neighboring connections of `selectedNeighbors[idx]` then no need to modify any connections or run the heuristics.
- if (!is_cur_c_present) {
- if (sz_link_list_other < Mcurmax) {
- data[sz_link_list_other] = cur_c;
- setListCount(ll_other, sz_link_list_other + 1);
- } else {
- // finding the "weakest" element to replace it with the new one
- dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]),
- dist_func_param_);
- // Heuristic:
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates;
- candidates.emplace(d_max, cur_c);
+ // If cur_c is already present in the neighboring connections of `selectedNeighbors[idx]` then no need to modify any connections or run the heuristics.
+ if (!is_cur_c_present) {
+ if (sz_link_list_other < Mcurmax) {
+ data[sz_link_list_other] = cur_c;
+ setListCount(ll_other, sz_link_list_other + 1);
+ } else {
+ // finding the "weakest" element to replace it with the new one
+ dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]),
+ dist_func_param_);
+ // Heuristic:
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates;
+ candidates.emplace(d_max, cur_c);
- for (size_t j = 0; j < sz_link_list_other; j++) {
- candidates.emplace(
- fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(selectedNeighbors[idx]),
- dist_func_param_), data[j]);
- }
+ for (size_t j = 0; j < sz_link_list_other; j++) {
+ candidates.emplace(
+ fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(selectedNeighbors[idx]),
+ dist_func_param_), data[j]);
+ }
- getNeighborsByHeuristic2(candidates, Mcurmax);
+ getNeighborsByHeuristic2(candidates, Mcurmax);
- int indx = 0;
- while (candidates.size() > 0) {
- data[indx] = candidates.top().second;
- candidates.pop();
- indx++;
- }
+ int indx = 0;
+ while (candidates.size() > 0) {
+ data[indx] = candidates.top().second;
+ candidates.pop();
+ indx++;
+ }
- setListCount(ll_other, indx);
- // Nearest K:
- /*int indx = -1;
- for (int j = 0; j < sz_link_list_other; j++) {
- dist_t d = fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(rez[idx]), dist_func_param_);
- if (d > d_max) {
- indx = j;
- d_max = d;
- }
+ setListCount(ll_other, indx);
+ // Nearest K:
+ /*int indx = -1;
+ for (int j = 0; j < sz_link_list_other; j++) {
+ dist_t d = fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(rez[idx]), dist_func_param_);
+ if (d > d_max) {
+ indx = j;
+ d_max = d;
}
- if (indx >= 0) {
- data[indx] = cur_c;
- } */
}
+ if (indx >= 0) {
+ data[indx] = cur_c;
+ } */
}
}
-
- return next_closest_entry_point;
}
- std::mutex global;
- size_t ef_;
+ return next_closest_entry_point;
+ }
- void setEf(size_t ef) {
- ef_ = ef;
- }
+ /*
+  * Changes the capacity of the index to new_max_elements.
+  * Throws std::runtime_error when the requested capacity is smaller than the number of
+  * stored elements, or when one of the reallocations fails.
+  */
+ void resizeIndex(size_t new_max_elements) {
+     if (cur_element_count > new_max_elements) {
+         throw std::runtime_error("Cannot resize, max element is less than the current number of elements");
+     }
- std::priority_queue<std::pair<dist_t, tableint>> searchKnnInternal(void *query_data, int k) {
- std::priority_queue<std::pair<dist_t, tableint >> top_candidates;
- if (cur_element_count == 0) return top_candidates;
- tableint currObj = enterpoint_node_;
- dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);
+
+     // The visited-list pool is constructed with the capacity, so it is rebuilt.
+     delete visited_list_pool_;
+     visited_list_pool_ = new VisitedListPool(1, new_max_elements);
- for (size_t level = maxlevel_; level > 0; level--) {
- bool changed = true;
- while (changed) {
- changed = false;
- int *data;
- data = (int *) get_linklist(currObj,level);
- int size = getListCount(data);
- tableint *datal = (tableint *) (data + 1);
- for (int i = 0; i < size; i++) {
- tableint cand = datal[i];
- if (cand < 0 || cand > max_elements_)
- throw std::runtime_error("cand error");
- dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_);
+
+     element_levels_.resize(new_max_elements);
- if (d < curdist) {
- curdist = d;
- currObj = cand;
- changed = true;
- }
- }
- }
- }
+
+     {
+         // std::mutex is neither copyable nor movable, so a freshly sized vector is
+         // swapped in; the previous locks are destroyed when this scope closes.
+         std::vector<std::mutex> fresh_locks(new_max_elements);
+         fresh_locks.swap(link_list_locks_);
+     }
- if (num_deleted_) {
- std::priority_queue<std::pair<dist_t, tableint >> top_candidates1=searchBaseLayerST<true>(currObj, query_data,
- ef_);
- top_candidates.swap(top_candidates1);
- }
- else{
- std::priority_queue<std::pair<dist_t, tableint >> top_candidates1=searchBaseLayerST<false>(currObj, query_data,
- ef_);
- top_candidates.swap(top_candidates1);
- }
+
+     // Reallocate base layer. The member is only overwritten once realloc succeeded,
+     // so the old buffer stays reachable on failure.
+     void *grown_level0 = realloc(data_level0_memory_, new_max_elements * size_data_per_element_);
+     if (!grown_level0) {
+         throw std::runtime_error("Not enough memory: resizeIndex failed to allocate base layer");
+     }
+     data_level0_memory_ = static_cast<char *>(grown_level0);
- while (top_candidates.size() > k) {
- top_candidates.pop();
- }
- return top_candidates;
- };
+
+     // Reallocate all other layers (the per-element table of upper-level link lists).
+     void *grown_link_table = realloc(linkLists_, sizeof(void *) * new_max_elements);
+     if (!grown_link_table) {
+         throw std::runtime_error("Not enough memory: resizeIndex failed to allocate other layers");
+     }
+     linkLists_ = static_cast<char **>(grown_link_table);
- void resizeIndex(size_t new_max_elements){
- if (new_max_elements<cur_element_count)
- throw std::runtime_error("Cannot resize, max element is less than the current number of elements");
+
+     // Publish the new capacity last, after every allocation went through.
+     max_elements_ = new_max_elements;
+ }
- delete visited_list_pool_;
- visited_list_pool_ = new VisitedListPool(1, new_max_elements);
+ /*
+  * Writes the index to the binary file `location`.
+  * Layout: the header fields below in this exact order, then the level-0 block for the
+  * first cur_element_count elements, then for each element an unsigned int byte count
+  * followed by that many bytes of upper-level link lists (count 0 means none).
+  * Throws std::runtime_error if the file cannot be opened for writing.
+  */
+ void saveIndex(const std::string &location) {
+     std::ofstream output(location, std::ios::binary);
+     // Without this check an unwritable path would silently produce no index at all;
+     // loadIndex reports the same condition with the same message.
+     if (!output.is_open())
+         throw std::runtime_error("Cannot open file");
+
+     // Header: loadIndex reads these fields back in the same order.
+     writeBinaryPOD(output, offsetLevel0_);
+     writeBinaryPOD(output, max_elements_);
+     writeBinaryPOD(output, cur_element_count);
+     writeBinaryPOD(output, size_data_per_element_);
+     writeBinaryPOD(output, label_offset_);
+     writeBinaryPOD(output, offsetData_);
+     writeBinaryPOD(output, maxlevel_);
+     writeBinaryPOD(output, enterpoint_node_);
+     writeBinaryPOD(output, maxM_);
- element_levels_.resize(new_max_elements);
+     writeBinaryPOD(output, maxM0_);
+     writeBinaryPOD(output, M_);
+     writeBinaryPOD(output, mult_);
+     writeBinaryPOD(output, ef_construction_);
- std::vector<std::mutex>(new_max_elements).swap(link_list_locks_);
+
+     // Level-0 block: only the slots that are in use are written, not the full capacity.
+     output.write(data_level0_memory_, cur_element_count * size_data_per_element_);
- // Reallocate base layer
- char * data_level0_memory_new = (char *) realloc(data_level0_memory_, new_max_elements * size_data_per_element_);
- if (data_level0_memory_new == nullptr)
- throw std::runtime_error("Not enough memory: resizeIndex failed to allocate base layer");
- data_level0_memory_ = data_level0_memory_new;
+
+     // Upper levels: one length-prefixed record per element.
+     for (size_t i = 0; i < cur_element_count; i++) {
+         unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0;
+         writeBinaryPOD(output, linkListSize);
+         if (linkListSize)
+             output.write(linkLists_[i], linkListSize);
+     }
+     output.close();
+ }
- // Reallocate all other layers
- char ** linkLists_new = (char **) realloc(linkLists_, sizeof(void *) * new_max_elements);
- if (linkLists_new == nullptr)
- throw std::runtime_error("Not enough memory: resizeIndex failed to allocate other layers");
- linkLists_ = linkLists_new;
- max_elements_ = new_max_elements;
- }
+ void loadIndex(const std::string &location, SpaceInterface<dist_t> *s, size_t max_elements_i = 0) {  // Restores the index from the binary file written by saveIndex; `s` supplies data size and distance function; throws std::runtime_error on open failure, failed consistency check or failed allocation.
+ std::ifstream input(location, std::ios::binary);
- void saveIndex(const std::string &location) {
- std::ofstream output(location, std::ios::binary);
- std::streampos position;
+ if (!input.is_open())
+ throw std::runtime_error("Cannot open file");
- writeBinaryPOD(output, offsetLevel0_);
- writeBinaryPOD(output, max_elements_);
- writeBinaryPOD(output, cur_element_count);
- writeBinaryPOD(output, size_data_per_element_);
- writeBinaryPOD(output, label_offset_);
- writeBinaryPOD(output, offsetData_);
- writeBinaryPOD(output, maxlevel_);
- writeBinaryPOD(output, enterpoint_node_);
- writeBinaryPOD(output, maxM_);
+ // Determine the total file size (needed by the consistency check below), then rewind.
+ input.seekg(0, input.end);
+ std::streampos total_filesize = input.tellg();
+ input.seekg(0, input.beg);
- writeBinaryPOD(output, maxM0_);
- writeBinaryPOD(output, M_);
- writeBinaryPOD(output, mult_);
- writeBinaryPOD(output, ef_construction_);
+ readBinaryPOD(input, offsetLevel0_);  // header fields are read in the order saveIndex writes them
+ readBinaryPOD(input, max_elements_);
+ readBinaryPOD(input, cur_element_count);
- output.write(data_level0_memory_, cur_element_count * size_data_per_element_);
+ size_t max_elements = max_elements_i;  // capacity requested by the caller (default 0)
+ if (max_elements < cur_element_count)  // request too small for the stored elements:
+ max_elements = max_elements_;  // fall back to the capacity recorded in the file
+ max_elements_ = max_elements;
+ readBinaryPOD(input, size_data_per_element_);
+ readBinaryPOD(input, label_offset_);
+ readBinaryPOD(input, offsetData_);
+ readBinaryPOD(input, maxlevel_);
+ readBinaryPOD(input, enterpoint_node_);
- for (size_t i = 0; i < cur_element_count; i++) {
- unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0;
- writeBinaryPOD(output, linkListSize);
- if (linkListSize)
- output.write(linkLists_[i], linkListSize);
- }
- output.close();
- }
+ readBinaryPOD(input, maxM_);
+ readBinaryPOD(input, maxM0_);
+ readBinaryPOD(input, M_);
+ readBinaryPOD(input, mult_);
+ readBinaryPOD(input, ef_construction_);
- void loadIndex(const std::string &location, SpaceInterface<dist_t> *s, size_t max_elements_i=0) {
- std::ifstream input(location, std::ios::binary);
+ data_size_ = s->get_data_size();  // these three come from the space object, not from the file
+ fstdistfunc_ = s->get_dist_func();
+ dist_func_param_ = s->get_dist_func_param();
- if (!input.is_open())
- throw std::runtime_error("Cannot open file");
+ auto pos = input.tellg();  // end of the header; the real read restarts from here after the check
- // get file size:
- input.seekg(0,input.end);
- std::streampos total_filesize=input.tellg();
- input.seekg(0,input.beg);
+ /// Consistency check (dry run): skip the level-0 block, walk every length-prefixed link-list record, and require that the walk ends exactly at end of file.
+ input.seekg(cur_element_count * size_data_per_element_, input.cur);
+ for (size_t i = 0; i < cur_element_count; i++) {
+ if (input.tellg() < 0 || input.tellg() >= total_filesize) {  // stream failed, or a record would start at/after EOF
+ throw std::runtime_error("Index seems to be corrupted or unsupported");
+ }
- readBinaryPOD(input, offsetLevel0_);
- readBinaryPOD(input, max_elements_);
- readBinaryPOD(input, cur_element_count);
+ unsigned int linkListSize;
+ readBinaryPOD(input, linkListSize);
+ if (linkListSize != 0) {
+ input.seekg(linkListSize, input.cur);  // skip the payload, only the framing is validated
+ }
+ }
- size_t max_elements = max_elements_i;
- if(max_elements < cur_element_count)
- max_elements = max_elements_;
- max_elements_ = max_elements;
- readBinaryPOD(input, size_data_per_element_);
- readBinaryPOD(input, label_offset_);
- readBinaryPOD(input, offsetData_);
- readBinaryPOD(input, maxlevel_);
- readBinaryPOD(input, enterpoint_node_);
+ // Throw if the records do not end exactly at end of file: the index is either corrupted or in an old/unsupported format.
+ if (input.tellg() != total_filesize)
+ throw std::runtime_error("Index seems to be corrupted or unsupported");
- readBinaryPOD(input, maxM_);
- readBinaryPOD(input, maxM0_);
- readBinaryPOD(input, M_);
- readBinaryPOD(input, mult_);
- readBinaryPOD(input, ef_construction_);
+ input.clear();  // reset any stream state flags left by the dry run before seeking back
+ /// End of consistency check: return to the first byte after the header.
+ input.seekg(pos, input.beg);
- data_size_ = s->get_data_size();
- fstdistfunc_ = s->get_dist_func();
- dist_func_param_ = s->get_dist_func_param();
+ data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_);  // sized for the capacity, but only cur_element_count slots are filled from the file
+ if (data_level0_memory_ == nullptr)
+ throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0");
+ input.read(data_level0_memory_, cur_element_count * size_data_per_element_);
- auto pos=input.tellg();
+ size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint);  // one count header plus maxM_ neighbour ids per upper level
+ size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint);
+ std::vector<std::mutex>(max_elements).swap(link_list_locks_);  // swap in freshly sized lock vectors
+ std::vector<std::mutex>(MAX_LABEL_OPERATION_LOCKS).swap(label_op_locks_);
- /// Optional - check if index is ok:
+ visited_list_pool_ = new VisitedListPool(1, max_elements);
- input.seekg(cur_element_count * size_data_per_element_,input.cur);
- for (size_t i = 0; i < cur_element_count; i++) {
- if(input.tellg() < 0 || input.tellg()>=total_filesize){
- throw std::runtime_error("Index seems to be corrupted or unsupported");
- }
+ linkLists_ = (char **) malloc(sizeof(void *) * max_elements);
+ if (linkLists_ == nullptr)
+ throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists");
+ element_levels_ = std::vector<int>(max_elements);
+ revSize_ = 1.0 / mult_;
+ ef_ = 10;  // ef_ is not stored in the file (saveIndex does not write it); it is reset to 10
+ for (size_t i = 0; i < cur_element_count; i++) {  // second pass: actually load the upper-level link lists
+ label_lookup_[getExternalLabel(i)] = i;  // rebuild the label -> internal id map from the loaded level-0 data
+ unsigned int linkListSize;
+ readBinaryPOD(input, linkListSize);
+ if (linkListSize == 0) {
+ element_levels_[i] = 0;
+ linkLists_[i] = nullptr;
+ } else {
+ element_levels_[i] = linkListSize / size_links_per_element_;  // the level is recovered from the record length
+ linkLists_[i] = (char *) malloc(linkListSize);
+ if (linkLists_[i] == nullptr)
+ throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist");
+ input.read(linkLists_[i], linkListSize);
+ }
+ }
- unsigned int linkListSize;
- readBinaryPOD(input, linkListSize);
- if (linkListSize != 0) {
- input.seekg(linkListSize,input.cur);
- }
+ for (size_t i = 0; i < cur_element_count; i++) {  // recount deletions from the flags stored in the level-0 data
+ if (isMarkedDeleted(i)) {
+ num_deleted_ += 1;
+ if (allow_replace_deleted_) deleted_elements.insert(i);  // reusable slots are tracked only when replacement is enabled
}
+ }
- // throw exception if it either corrupted or old index
- if(input.tellg()!=total_filesize)
- throw std::runtime_error("Index seems to be corrupted or unsupported");
+ input.close();
- input.clear();
+ return;
+ }
- /// Optional check end
- input.seekg(pos,input.beg);
+ /*
+  * Returns a copy of the vector stored for `label`, read as `dim` consecutive data_t
+  * values, where dim is the size_t that dist_func_param_ points to.
+  * Throws std::runtime_error("Label not found") if the label is unknown or the element
+  * is marked deleted.
+  */
+ template<typename data_t>
+ std::vector<data_t> getDataByLabel(labeltype label) const {
+     // lock all operations with element by label
+     std::unique_lock <std::mutex> lock_label(getLabelOpMutex(label));
- data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_);
- if (data_level0_memory_ == nullptr)
- throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0");
- input.read(data_level0_memory_, cur_element_count * size_data_per_element_);
+
+     // The lookup table is only held long enough to resolve the internal id.
+     std::unique_lock <std::mutex> lock_table(label_lookup_lock);
+     auto search = label_lookup_.find(label);
+     if (search == label_lookup_.end() || isMarkedDeleted(search->second)) {
+         throw std::runtime_error("Label not found");
+     }
+     tableint internalId = search->second;
+     lock_table.unlock();
- size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint);
+
+     char* data_ptrv = getDataByInternalId(internalId);
+     size_t dim = *((size_t *) dist_func_param_);
+     const data_t* data_ptr = (const data_t*) data_ptrv;
+     // Range construction allocates once and avoids comparing a signed loop index
+     // against the unsigned dimension.
+     return std::vector<data_t>(data_ptr, data_ptr + dim);
+ }
- size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint);
- std::vector<std::mutex>(max_elements).swap(link_list_locks_);
- std::vector<std::mutex>(max_update_element_locks).swap(link_list_update_locks_);
- visited_list_pool_ = new VisitedListPool(1, max_elements);
+ /*
+  * Flags the element registered under `label` as deleted; the graph links are left as they are.
+  * Throws std::runtime_error("Label not found") when the label is not in the lookup table.
+  */
+ void markDelete(labeltype label) {
+     // serialize every operation that targets this label
+     std::unique_lock <std::mutex> label_guard(getLabelOpMutex(label));
- linkLists_ = (char **) malloc(sizeof(void *) * max_elements);
- if (linkLists_ == nullptr)
- throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists");
- element_levels_ = std::vector<int>(max_elements);
- revSize_ = 1.0 / mult_;
- ef_ = 10;
- for (size_t i = 0; i < cur_element_count; i++) {
- label_lookup_[getExternalLabel(i)]=i;
- unsigned int linkListSize;
- readBinaryPOD(input, linkListSize);
- if (linkListSize == 0) {
- element_levels_[i] = 0;
+
+     tableint target_id;
+     {
+         // the table lock covers only the label -> internal id resolution
+         std::unique_lock <std::mutex> table_guard(label_lookup_lock);
+         auto entry = label_lookup_.find(label);
+         if (entry == label_lookup_.end()) {
+             throw std::runtime_error("Label not found");
+         }
+         target_id = entry->second;
+     }
- linkLists_[i] = nullptr;
- } else {
- element_levels_[i] = linkListSize / size_links_per_element_;
- linkLists_[i] = (char *) malloc(linkListSize);
- if (linkLists_[i] == nullptr)
- throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist");
- input.read(linkLists_[i], linkListSize);
- }
- }
+
+     markDeletedInternal(target_id);
+ }
- for (size_t i = 0; i < cur_element_count; i++) {
- if(isMarkedDeleted(i))
- num_deleted_ += 1;
+
+ /*
+  * Flags the element with the given internal id as deleted.
+  * The flag is the DELETE_MARK bit of the byte at offset 2 of the element's level-0
+  * link-list header; the leading unsigned short of that header is the neighbour count
+  * read by getListCount, which is why the count has to fit into 16 bits.
+  * Throws std::runtime_error if the element is already flagged.
+  */
+ void markDeletedInternal(tableint internalId) {
+     assert(internalId < cur_element_count);
+     if (isMarkedDeleted(internalId)) {
+         throw std::runtime_error("The requested to delete element is already deleted");
+     }
+
+     unsigned char *flag_byte = reinterpret_cast<unsigned char *>(get_linklist0(internalId)) + 2;
+     *flag_byte |= DELETE_MARK;
+     num_deleted_ += 1;
+
+     // The freed slot is recorded only when replacement of deleted elements is enabled.
+     if (allow_replace_deleted_) {
+         std::unique_lock <std::mutex> deleted_guard(deleted_elements_lock);
+         deleted_elements.insert(internalId);
+     }
+ }
- input.close();
- return;
+ /*
+ * Removes the deleted mark from the element registered under `label`; does NOT change the current graph.
+ * Throws std::runtime_error("Label not found") if the label is not in the lookup table.
+ *
+ * Note: the method is not safe to use when replacement of deleted elements is enabled,
+ * because elements marked as deleted can be completely removed by addPoint
+ */
+ void unmarkDelete(labeltype label) {
+ // lock all operations with element by label
+ std::unique_lock <std::mutex> lock_label(getLabelOpMutex(label));
+
+ std::unique_lock <std::mutex> lock_table(label_lookup_lock);  // held only while the label is resolved
+ auto search = label_lookup_.find(label);
+ if (search == label_lookup_.end()) {
+ throw std::runtime_error("Label not found");
}
+ tableint internalId = search->second;
+ lock_table.unlock();  // released before the flag itself is touched
- template<typename data_t>
- std::vector<data_t> getDataByLabel(labeltype label) const
- {
- tableint label_c;
- auto search = label_lookup_.find(label);
- if (search == label_lookup_.end() || isMarkedDeleted(search->second)) {
- throw std::runtime_error("Label not found");
- }
- label_c = search->second;
+ unmarkDeletedInternal(internalId);
+ }
- char* data_ptrv = getDataByInternalId(label_c);
- size_t dim = *((size_t *) dist_func_param_);
- std::vector<data_t> data;
- data_t* data_ptr = (data_t*) data_ptrv;
- for (int i = 0; i < dim; i++) {
- data.push_back(*data_ptr);
- data_ptr += 1;
- }
- return data;
- }
- static const unsigned char DELETE_MARK = 0x01;
- // static const unsigned char REUSE_MARK = 0x10;
- /**
- * Marks an element with the given label deleted, does NOT really change the current graph.
- * @param label
- */
- void markDelete(labeltype label)
- {
- auto search = label_lookup_.find(label);
- if (search == label_lookup_.end()) {
- throw std::runtime_error("Label not found");
- }
- tableint internalId = search->second;
- markDeletedInternal(internalId);
- }
- /**
- * Uses the first 8 bits of the memory for the linked list to store the mark,
- * whereas maxM0_ has to be limited to the lower 24 bits, however, still large enough in almost all cases.
- * @param internalId
- */
- void markDeletedInternal(tableint internalId) {
- assert(internalId < cur_element_count);
- if (!isMarkedDeleted(internalId))
- {
- unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2;
- *ll_cur |= DELETE_MARK;
- num_deleted_ += 1;
+ /*
+ * Removes the deleted mark of the element with the given internal id.
+ * Clears the DELETE_MARK bit in the byte at offset 2 of the level-0 link-list header,
+ * decrements num_deleted_ and, when replacement of deleted elements is enabled,
+ * drops the id from deleted_elements.
+ * Throws std::runtime_error if the element is not marked deleted.
+ */
+ void unmarkDeletedInternal(tableint internalId) {
+ assert(internalId < cur_element_count);
+ if (isMarkedDeleted(internalId)) {
+ unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId)) + 2;  // same byte that markDeletedInternal sets
+ *ll_cur &= ~DELETE_MARK;
+ num_deleted_ -= 1;
+ if (allow_replace_deleted_) {
+ std::unique_lock <std::mutex> lock_deleted_elements(deleted_elements_lock);
+ deleted_elements.erase(internalId);  // the slot is no longer available for reuse
}
- else
- {
- throw std::runtime_error("The requested to delete element is already deleted");
- }
+ } else {
+ throw std::runtime_error("The requested to undelete element is not deleted");
}
+ }
- /**
- * Remove the deleted mark of the node, does NOT really change the current graph.
- * @param label
- */
- void unmarkDelete(labeltype label)
- {
- auto search = label_lookup_.find(label);
- if (search == label_lookup_.end()) {
- throw std::runtime_error("Label not found");
- }
- tableint internalId = search->second;
- unmarkDeletedInternal(internalId);
- }
- /**
- * Remove the deleted mark of the node.
- * @param internalId
- */
- void unmarkDeletedInternal(tableint internalId) {
- assert(internalId < cur_element_count);
- if (isMarkedDeleted(internalId))
- {
- unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2;
- *ll_cur &= ~DELETE_MARK;
- num_deleted_ -= 1;
- }
- else
- {
- throw std::runtime_error("The requested to undelete element is not deleted");
- }
- }
+ /*
+  * Tells whether the element with the given internal id carries the deleted mark.
+  * The mark is the DELETE_MARK bit of the byte at offset 2 of the element's level-0
+  * link-list header (the byte that follows the leading unsigned short count).
+  */
+ bool isMarkedDeleted(tableint internalId) const {
+     const unsigned char *flag_byte = reinterpret_cast<const unsigned char *>(get_linklist0(internalId)) + 2;
+     return (*flag_byte & DELETE_MARK) != 0;
+ }
- /**
- * Checks the first 8 bits of the memory to see if the element is marked deleted.
- * @param internalId
- * @return
- */
- bool isMarkedDeleted(tableint internalId) const {
- unsigned char *ll_cur = ((unsigned char*)get_linklist0(internalId))+2;
- return *ll_cur & DELETE_MARK;
- }
- unsigned short int getListCount(linklistsizeint * ptr) const {
- return *((unsigned short int *)ptr);
+ // Returns the neighbour count stored in the leading unsigned short of a link-list header.
+ unsigned short int getListCount(linklistsizeint * ptr) const {
+     const unsigned short int *count_field = reinterpret_cast<const unsigned short int *>(ptr);
+     return *count_field;
+ }
+
+
+ // Stores `size` in the leading unsigned short of a link-list header; the remaining
+ // bytes of the header are left untouched.
+ void setListCount(linklistsizeint * ptr, unsigned short int size) const {
+     unsigned short int *count_field = reinterpret_cast<unsigned short int *>(ptr);
+     *count_field = size;
+ }
+
+
+ /*
+ * Adds point. Updates the point if it is already in the index.
+ * If replacement of deleted elements is enabled: replaces previously deleted point if any, updating it with new point
+ *
+ * data_point: the vector to store; label: external label of the point;
+ * replace_deleted: when true, reuse the slot of an element marked deleted if one is available.
+ * Throws std::runtime_error if replace_deleted is requested while allow_replace_deleted_ is false.
+ */
+ void addPoint(const void *data_point, labeltype label, bool replace_deleted = false) {
+ if ((allow_replace_deleted_ == false) && (replace_deleted == true)) {
+ throw std::runtime_error("Replacement of deleted elements is disabled in constructor");
}
- void setListCount(linklistsizeint * ptr, unsigned short int size) const {
- *((unsigned short int*)(ptr))=*((unsigned short int *)&size);
+ // lock all operations with element by label
+ std::unique_lock <std::mutex> lock_label(getLabelOpMutex(label));
+ if (!replace_deleted) {
+ addPoint(data_point, label, -1);  // plain add-or-update path via the three-argument overload
+ return;
}
+ // check if there is a vacant place, i.e. a slot of an element marked deleted
+ tableint internal_id_replaced;  // only assigned (and only read) when is_vacant_place is true
+ std::unique_lock <std::mutex> lock_deleted_elements(deleted_elements_lock);
+ bool is_vacant_place = !deleted_elements.empty();
+ if (is_vacant_place) {
+ internal_id_replaced = *deleted_elements.begin();
+ deleted_elements.erase(internal_id_replaced);  // claim the slot while the set is still locked
+ }
+ lock_deleted_elements.unlock();
- void addPoint(const void *data_point, labeltype label) {
- addPoint(data_point, label,-1);
+ // if there is no vacant place then add or update point
+ // else add point to vacant place
+ if (!is_vacant_place) {
+ addPoint(data_point, label, -1);
+ } else {
+ // we assume that there are no concurrent operations on deleted element
+ labeltype label_replaced = getExternalLabel(internal_id_replaced);
+ setExternalLabel(internal_id_replaced, label);
+
+ std::unique_lock <std::mutex> lock_table(label_lookup_lock);
+ label_lookup_.erase(label_replaced);  // the old label no longer resolves to this slot
+ label_lookup_[label] = internal_id_replaced;
+ lock_table.unlock();
+
+ unmarkDeletedInternal(internal_id_replaced);  // clears the mark; its own erase from deleted_elements finds nothing because the id was removed above
+ updatePoint(data_point, internal_id_replaced, 1.0);  // overwrite the vector and rewire links; 1.0 is the updateNeighborProbability
}
+ }
- void updatePoint(const void *dataPoint, tableint internalId, float updateNeighborProbability) {
- // update the feature vector associated with existing point with new vector
- memcpy(getDataByInternalId(internalId), dataPoint, data_size_);
- int maxLevelCopy = maxlevel_;
- tableint entryPointCopy = enterpoint_node_;
- // If point to be updated is entry point and graph just contains single element then just return.
- if (entryPointCopy == internalId && cur_element_count == 1)
- return;
+ void updatePoint(const void *dataPoint, tableint internalId, float updateNeighborProbability) {  // Overwrites the stored vector of internalId with dataPoint, re-selects the links of (a random share of) its neighbours on every layer, then repairs the element's own connections.
+ // update the feature vector associated with existing point with new vector
+ memcpy(getDataByInternalId(internalId), dataPoint, data_size_);
- int elemLevel = element_levels_[internalId];
- std::uniform_real_distribution<float> distribution(0.0, 1.0);
- for (int layer = 0; layer <= elemLevel; layer++) {
- std::unordered_set<tableint> sCand;
- std::unordered_set<tableint> sNeigh;
- std::vector<tableint> listOneHop = getConnectionsWithLock(internalId, layer);
- if (listOneHop.size() == 0)
- continue;
+ int maxLevelCopy = maxlevel_;  // snapshot of the top level and entry point, used for the repair step at the end
+ tableint entryPointCopy = enterpoint_node_;
+ // If point to be updated is entry point and graph just contains single element then just return.
+ if (entryPointCopy == internalId && cur_element_count == 1)
+ return;
- sCand.insert(internalId);
+ int elemLevel = element_levels_[internalId];
+ std::uniform_real_distribution<float> distribution(0.0, 1.0);
+ for (int layer = 0; layer <= elemLevel; layer++) {  // every layer the element is present on
+ std::unordered_set<tableint> sCand;  // candidate pool: the element, its one-hop and selected two-hop neighbours
+ std::unordered_set<tableint> sNeigh;  // one-hop neighbours whose link lists will be rebuilt
+ std::vector<tableint> listOneHop = getConnectionsWithLock(internalId, layer);
+ if (listOneHop.size() == 0)
+ continue;
- for (auto&& elOneHop : listOneHop) {
- sCand.insert(elOneHop);
+ sCand.insert(internalId);
- if (distribution(update_probability_generator_) > updateNeighborProbability)
- continue;
+ for (auto&& elOneHop : listOneHop) {
+ sCand.insert(elOneHop);
- sNeigh.insert(elOneHop);
+ if (distribution(update_probability_generator_) > updateNeighborProbability)  // random draw: this neighbour is skipped when the draw exceeds the given probability
+ continue;
- std::vector<tableint> listTwoHop = getConnectionsWithLock(elOneHop, layer);
- for (auto&& elTwoHop : listTwoHop) {
- sCand.insert(elTwoHop);
- }
+ sNeigh.insert(elOneHop);
+
+ std::vector<tableint> listTwoHop = getConnectionsWithLock(elOneHop, layer);
+ for (auto&& elTwoHop : listTwoHop) {
+ sCand.insert(elTwoHop);
}
+ }
- for (auto&& neigh : sNeigh) {
- // if (neigh == internalId)
- // continue;
+ for (auto&& neigh : sNeigh) {  // rebuild the link list of each selected neighbour
+ // if (neigh == internalId)
+ // continue;
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates;
- size_t size = sCand.find(neigh) == sCand.end() ? sCand.size() : sCand.size() - 1; // sCand guaranteed to have size >= 1
- size_t elementsToKeep = std::min(ef_construction_, size);
- for (auto&& cand : sCand) {
- if (cand == neigh)
- continue;
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates;
+ size_t size = sCand.find(neigh) == sCand.end() ? sCand.size() : sCand.size() - 1; // sCand guaranteed to have size >= 1
+ size_t elementsToKeep = std::min(ef_construction_, size);
+ for (auto&& cand : sCand) {
+ if (cand == neigh)  // an element is never a candidate for its own link list
+ continue;
- dist_t distance = fstdistfunc_(getDataByInternalId(neigh), getDataByInternalId(cand), dist_func_param_);
- if (candidates.size() < elementsToKeep) {
+ dist_t distance = fstdistfunc_(getDataByInternalId(neigh), getDataByInternalId(cand), dist_func_param_);
+ if (candidates.size() < elementsToKeep) {
+ candidates.emplace(distance, cand);
+ } else {
+ if (distance < candidates.top().first) {  // closer than the current top(): swap it in (presumably top() is the farthest kept candidate; CompareByFirst is defined elsewhere)
+ candidates.pop();
candidates.emplace(distance, cand);
- } else {
- if (distance < candidates.top().first) {
- candidates.pop();
- candidates.emplace(distance, cand);
- }
}
}
+ }
- // Retrieve neighbours using heuristic and set connections.
- getNeighborsByHeuristic2(candidates, layer == 0 ? maxM0_ : maxM_);
+ // Retrieve neighbours using heuristic and set connections.
+ getNeighborsByHeuristic2(candidates, layer == 0 ? maxM0_ : maxM_);
- {
- std::unique_lock <std::mutex> lock(link_list_locks_[neigh]);
- linklistsizeint *ll_cur;
- ll_cur = get_linklist_at_level(neigh, layer);
- size_t candSize = candidates.size();
- setListCount(ll_cur, candSize);
- tableint *data = (tableint *) (ll_cur + 1);
- for (size_t idx = 0; idx < candSize; idx++) {
- data[idx] = candidates.top().second;
- candidates.pop();
- }
+ {
+ std::unique_lock <std::mutex> lock(link_list_locks_[neigh]);  // the neighbour's list is rewritten under its own lock
+ linklistsizeint *ll_cur;
+ ll_cur = get_linklist_at_level(neigh, layer);
+ size_t candSize = candidates.size();
+ setListCount(ll_cur, candSize);
+ tableint *data = (tableint *) (ll_cur + 1);  // neighbour ids start right after the count header
+ for (size_t idx = 0; idx < candSize; idx++) {
+ data[idx] = candidates.top().second;
+ candidates.pop();
}
}
}
+ }
- repairConnectionsForUpdate(dataPoint, entryPointCopy, internalId, elemLevel, maxLevelCopy);
- };
+ repairConnectionsForUpdate(dataPoint, entryPointCopy, internalId, elemLevel, maxLevelCopy);  // finally re-link the updated element itself, starting from the snapshotted entry point
+ }
- void repairConnectionsForUpdate(const void *dataPoint, tableint entryPointInternalId, tableint dataPointInternalId, int dataPointLevel, int maxLevel) {
- tableint currObj = entryPointInternalId;
- if (dataPointLevel < maxLevel) {
- dist_t curdist = fstdistfunc_(dataPoint, getDataByInternalId(currObj), dist_func_param_);
- for (int level = maxLevel; level > dataPointLevel; level--) {
- bool changed = true;
- while (changed) {
- changed = false;
- unsigned int *data;
- std::unique_lock <std::mutex> lock(link_list_locks_[currObj]);
- data = get_linklist_at_level(currObj,level);
- int size = getListCount(data);
- tableint *datal = (tableint *) (data + 1);
+
+ void repairConnectionsForUpdate(  // Re-searches neighbors for an already-stored element and passes them to mutuallyConnectNewElement (last argument true) on each of its levels.
+ const void *dataPoint,  // vector of the element being repaired; used as the query side of every fstdistfunc_ call
+ tableint entryPointInternalId,  // internal id at which the descent starts
+ tableint dataPointInternalId,  // internal id of the element being repaired; removed from its own candidate set below
+ int dataPointLevel,  // highest level of that element; std::runtime_error is thrown if it exceeds maxLevel
+ int maxLevel) {  // level at which the descent starts
+ tableint currObj = entryPointInternalId;  // current closest node found so far
+ if (dataPointLevel < maxLevel) {  // levels above the element's own: only walk towards dataPoint, no link list is written in this branch
+ dist_t curdist = fstdistfunc_(dataPoint, getDataByInternalId(currObj), dist_func_param_);
+ for (int level = maxLevel; level > dataPointLevel; level--) {
+ bool changed = true;
+ while (changed) {  // repeat on this level until no neighbor of currObj is closer than currObj
+ changed = false;
+ unsigned int *data;
+ std::unique_lock <std::mutex> lock(link_list_locks_[currObj]);  // locks the list of the node being scanned; released at the end of this while-iteration
+ data = get_linklist_at_level(currObj, level);
+ int size = getListCount(data);  // neighbor count as reported by getListCount for this list
+ tableint *datal = (tableint *) (data + 1);  // neighbor ids start one unsigned int past `data`
#ifdef USE_SSE
- _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0);
+ _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0);  // NOTE(review): *datal is read even when size == 0 - confirm that slot is always readable
#endif
- for (int i = 0; i < size; i++) {
+ for (int i = 0; i < size; i++) {
#ifdef USE_SSE
- _mm_prefetch(getDataByInternalId(*(datal + i + 1)), _MM_HINT_T0);
+ _mm_prefetch(getDataByInternalId(*(datal + i + 1)), _MM_HINT_T0);  // NOTE(review): for i == size - 1 this reads datal[size], one slot past the counted neighbors - confirm the buffer always has room
#endif
- tableint cand = datal[i];
- dist_t d = fstdistfunc_(dataPoint, getDataByInternalId(cand), dist_func_param_);
- if (d < curdist) {
- curdist = d;
- currObj = cand;
- changed = true;
- }
+ tableint cand = datal[i];
+ dist_t d = fstdistfunc_(dataPoint, getDataByInternalId(cand), dist_func_param_);
+ if (d < curdist) {  // strictly closer neighbor found: move to it and rescan from there
+ curdist = d;
+ currObj = cand;
+ changed = true;
}
}
}
}
+ }
- if (dataPointLevel > maxLevel)
- throw std::runtime_error("Level of item to be updated cannot be bigger than max level");
+ if (dataPointLevel > maxLevel)  // checked only here; in this case the descent branch above was skipped
+ throw std::runtime_error("Level of item to be updated cannot be bigger than max level");
- for (int level = dataPointLevel; level >= 0; level--) {
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> topCandidates = searchBaseLayer(
- currObj, dataPoint, level);
+ for (int level = dataPointLevel; level >= 0; level--) {  // the element's own levels, from its top level down to 0
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> topCandidates = searchBaseLayer(
+ currObj, dataPoint, level);
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> filteredTopCandidates;
- while (topCandidates.size() > 0) {
- if (topCandidates.top().second != dataPointInternalId)
- filteredTopCandidates.push(topCandidates.top());
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> filteredTopCandidates;
+ while (topCandidates.size() > 0) {  // copy every candidate except the element itself
+ if (topCandidates.top().second != dataPointInternalId)
+ filteredTopCandidates.push(topCandidates.top());
- topCandidates.pop();
- }
+ topCandidates.pop();
+ }
- // Since element_levels_ is being used to get `dataPointLevel`, there could be cases where `topCandidates` could just contains entry point itself.
- // To prevent self loops, the `topCandidates` is filtered and thus can be empty.
- if (filteredTopCandidates.size() > 0) {
- bool epDeleted = isMarkedDeleted(entryPointInternalId);
- if (epDeleted) {
- filteredTopCandidates.emplace(fstdistfunc_(dataPoint, getDataByInternalId(entryPointInternalId), dist_func_param_), entryPointInternalId);
- if (filteredTopCandidates.size() > ef_construction_)
- filteredTopCandidates.pop();
- }
-
- currObj = mutuallyConnectNewElement(dataPoint, dataPointInternalId, filteredTopCandidates, level, true);
+ // Since element_levels_ is being used to get `dataPointLevel`, there can be cases where `topCandidates` contains only the entry point itself.
+ // To prevent self loops, `topCandidates` is filtered above, so the filtered set can be empty; such a level is skipped.
+ if (filteredTopCandidates.size() > 0) {
+ bool epDeleted = isMarkedDeleted(entryPointInternalId);
+ if (epDeleted) {  // entry point is marked deleted: add it to the candidates explicitly
+ filteredTopCandidates.emplace(fstdistfunc_(dataPoint, getDataByInternalId(entryPointInternalId), dist_func_param_), entryPointInternalId);
+ if (filteredTopCandidates.size() > ef_construction_)  // keep the candidate set capped at ef_construction_ entries
+ filteredTopCandidates.pop();
}
+
+ currObj = mutuallyConnectNewElement(dataPoint, dataPointInternalId, filteredTopCandidates, level, true);  // returned id is the start node for searchBaseLayer on the next lower level
}
}
+ }
- std::vector<tableint> getConnectionsWithLock(tableint internalId, int level) {
- std::unique_lock <std::mutex> lock(link_list_locks_[internalId]);
- unsigned int *data = get_linklist_at_level(internalId, level);
- int size = getListCount(data);
- std::vector<tableint> result(size);
- tableint *ll = (tableint *) (data + 1);
- memcpy(result.data(), ll,size * sizeof(tableint));
- return result;
- };
- tableint addPoint(const void *data_point, labeltype label, int level) {
+ std::vector<tableint> getConnectionsWithLock(tableint internalId, int level) {  // Returns a copy of internalId's neighbor ids at `level`, taken while holding that element's link-list lock.
+ std::unique_lock <std::mutex> lock(link_list_locks_[internalId]);  // held until the function returns
+ unsigned int *data = get_linklist_at_level(internalId, level);
+ int size = getListCount(data);  // number of neighbor ids to copy
+ std::vector<tableint> result(size);
+ tableint *ll = (tableint *) (data + 1);  // neighbor ids start one unsigned int past `data`
+ memcpy(result.data(), ll, size * sizeof(tableint));  // snapshot, so the caller can use it after the lock is gone
+ return result;
+ }
- tableint cur_c = 0;
- {
- // Checking if the element with the same label already exists
- // if so, updating it *instead* of creating a new element.
- std::unique_lock <std::mutex> templock_curr(cur_element_count_guard_);
- auto search = label_lookup_.find(label);
- if (search != label_lookup_.end()) {
- tableint existingInternalId = search->second;
- templock_curr.unlock();
- std::unique_lock <std::mutex> lock_el_update(link_list_update_locks_[(existingInternalId & (max_update_element_locks - 1))]);
-
+ tableint addPoint(const void *data_point, labeltype label, int level) {  // Stores data_point under `label` (updating in place if the label already exists) and returns its internal id; throws std::runtime_error on the conditions below.
+ tableint cur_c = 0;  // internal id of the new element; assigned under lock_table below
+ {
+ // Checking if the element with the same label already exists
+ // if so, updating it *instead* of creating a new element.
+ std::unique_lock <std::mutex> lock_table(label_lookup_lock);  // covers the label_lookup_ access and, for a new element, the cur_element_count increment in this scope
+ auto search = label_lookup_.find(label);
+ if (search != label_lookup_.end()) {
+ tableint existingInternalId = search->second;
+ if (allow_replace_deleted_) {  // with replacement of deleted elements enabled, a deleted element may not be updated through this path
if (isMarkedDeleted(existingInternalId)) {
- unmarkDeletedInternal(existingInternalId);
+ throw std::runtime_error("Can't use addPoint to update deleted elements if replacement of deleted elements is enabled.");
}
- updatePoint(data_point, existingInternalId, 1.0);
+ }
+ lock_table.unlock();  // the update below runs without the label table lock
- return existingInternalId;
+ if (isMarkedDeleted(existingInternalId)) {  // updating a deleted element first clears its deleted mark
+ unmarkDeletedInternal(existingInternalId);
}
+ updatePoint(data_point, existingInternalId, 1.0);
- if (cur_element_count >= max_elements_) {
- throw std::runtime_error("The number of elements exceeds the specified limit");
- };
+ return existingInternalId;
+ }
- cur_c = cur_element_count;
- cur_element_count++;
- label_lookup_[label] = cur_c;
+ if (cur_element_count >= max_elements_) {  // index is full
+ throw std::runtime_error("The number of elements exceeds the specified limit");
}
- // Take update lock to prevent race conditions on an element with insertion/update at the same time.
- std::unique_lock <std::mutex> lock_el_update(link_list_update_locks_[(cur_c & (max_update_element_locks - 1))]);
- std::unique_lock <std::mutex> lock_el(link_list_locks_[cur_c]);
- int curlevel = getRandomLevel(mult_);
- if (level > 0)
- curlevel = level;
+ cur_c = cur_element_count;  // the next unused internal id
+ cur_element_count++;
+ label_lookup_[label] = cur_c;
+ }
- element_levels_[cur_c] = curlevel;
+ std::unique_lock <std::mutex> lock_el(link_list_locks_[cur_c]);  // the new element's link-list lock, held until return
+ int curlevel = getRandomLevel(mult_);
+ if (level > 0)  // a positive `level` argument overrides the random level
+ curlevel = level;
+ element_levels_[cur_c] = curlevel;
- std::unique_lock <std::mutex> templock(global);
- int maxlevelcopy = maxlevel_;
- if (curlevel <= maxlevelcopy)
- templock.unlock();
- tableint currObj = enterpoint_node_;
- tableint enterpoint_copy = enterpoint_node_;
+ std::unique_lock <std::mutex> templock(global);  // stays locked for the whole insertion only when this element will raise maxlevel_
+ int maxlevelcopy = maxlevel_;
+ if (curlevel <= maxlevelcopy)
+ templock.unlock();
+ tableint currObj = enterpoint_node_;
+ tableint enterpoint_copy = enterpoint_node_;
+ memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_);  // zero this element's slot in data_level0_memory_ before filling it
- memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_);
+ // Initialisation of the data and label
+ memcpy(getExternalLabeLp(cur_c), &label, sizeof(labeltype));
+ memcpy(getDataByInternalId(cur_c), data_point, data_size_);
- // Initialisation of the data and label
- memcpy(getExternalLabeLp(cur_c), &label, sizeof(labeltype));
- memcpy(getDataByInternalId(cur_c), data_point, data_size_);
+ if (curlevel) {  // link lists for levels above 0 are allocated separately, only when curlevel > 0
+ linkLists_[cur_c] = (char *) malloc(size_links_per_element_ * curlevel + 1);
+ if (linkLists_[cur_c] == nullptr)
+ throw std::runtime_error("Not enough memory: addPoint failed to allocate linklist");
+ memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1);
+ }
+ if ((signed)currObj != -1) {  // presumably -1 means "no entry point yet" (the else branch below handles the first element)
+ if (curlevel < maxlevelcopy) {  // levels above the new element's own: only walk towards data_point
+ dist_t curdist = fstdistfunc_(data_point, getDataByInternalId(currObj), dist_func_param_);
+ for (int level = maxlevelcopy; level > curlevel; level--) {  // note: this loop variable shadows the `level` parameter
+ bool changed = true;
+ while (changed) {  // repeat on this level until no neighbor of currObj is closer
+ changed = false;
+ unsigned int *data;
+ std::unique_lock <std::mutex> lock(link_list_locks_[currObj]);  // locks the list of the node being scanned; released at the end of this while-iteration
+ data = get_linklist(currObj, level);
+ int size = getListCount(data);
- if (curlevel) {
- linkLists_[cur_c] = (char *) malloc(size_links_per_element_ * curlevel + 1);
- if (linkLists_[cur_c] == nullptr)
- throw std::runtime_error("Not enough memory: addPoint failed to allocate linklist");
- memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1);
- }
-
- if ((signed)currObj != -1) {
-
- if (curlevel < maxlevelcopy) {
-
- dist_t curdist = fstdistfunc_(data_point, getDataByInternalId(currObj), dist_func_param_);
- for (int level = maxlevelcopy; level > curlevel; level--) {
-
-
- bool changed = true;
- while (changed) {
- changed = false;
- unsigned int *data;
- std::unique_lock <std::mutex> lock(link_list_locks_[currObj]);
- data = get_linklist(currObj,level);
- int size = getListCount(data);
-
- tableint *datal = (tableint *) (data + 1);
- for (int i = 0; i < size; i++) {
- tableint cand = datal[i];
- if (cand < 0 || cand > max_elements_)
- throw std::runtime_error("cand error");
- dist_t d = fstdistfunc_(data_point, getDataByInternalId(cand), dist_func_param_);
- if (d < curdist) {
- curdist = d;
- currObj = cand;
- changed = true;
- }
+ tableint *datal = (tableint *) (data + 1);  // neighbor ids start one unsigned int past `data`
+ for (int i = 0; i < size; i++) {
+ tableint cand = datal[i];
+ if (cand < 0 || cand > max_elements_)  // NOTE(review): tableint is unsigned, so `cand < 0` is never true, and `>` lets cand == max_elements_ through - confirm whether `>=` was intended
+ throw std::runtime_error("cand error");
+ dist_t d = fstdistfunc_(data_point, getDataByInternalId(cand), dist_func_param_);
+ if (d < curdist) {  // strictly closer neighbor found: move to it and rescan from there
+ curdist = d;
+ currObj = cand;
+ changed = true;
}
}
}
}
+ }
- bool epDeleted = isMarkedDeleted(enterpoint_copy);
- for (int level = std::min(curlevel, maxlevelcopy); level >= 0; level--) {
- if (level > maxlevelcopy || level < 0) // possible?
- throw std::runtime_error("Level error");
+ bool epDeleted = isMarkedDeleted(enterpoint_copy);
+ for (int level = std::min(curlevel, maxlevelcopy); level >= 0; level--) {  // levels the new element shares with the existing graph, top to bottom
+ if (level > maxlevelcopy || level < 0) // possible? (the loop header already keeps level within [0, min(curlevel, maxlevelcopy)])
+ throw std::runtime_error("Level error");
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates = searchBaseLayer(
- currObj, data_point, level);
- if (epDeleted) {
- top_candidates.emplace(fstdistfunc_(data_point, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy);
- if (top_candidates.size() > ef_construction_)
- top_candidates.pop();
- }
- currObj = mutuallyConnectNewElement(data_point, cur_c, top_candidates, level, false);
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates = searchBaseLayer(
+ currObj, data_point, level);
+ if (epDeleted) {  // entry point is marked deleted: add it to the candidates explicitly
+ top_candidates.emplace(fstdistfunc_(data_point, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy);
+ if (top_candidates.size() > ef_construction_)  // keep the candidate set capped at ef_construction_ entries
+ top_candidates.pop();
}
-
-
- } else {
- // Do nothing for the first element
- enterpoint_node_ = 0;
- maxlevel_ = curlevel;
-
+ currObj = mutuallyConnectNewElement(data_point, cur_c, top_candidates, level, false);  // returned id is the start node for searchBaseLayer on the next lower level
}
+ } else {
+ // First element: there is nothing to connect to, so it only becomes the entry point
+ enterpoint_node_ = 0;
+ maxlevel_ = curlevel;
+ }
- //Releasing lock for the maximum level
- if (curlevel > maxlevelcopy) {
- enterpoint_node_ = cur_c;
- maxlevel_ = curlevel;
- }
- return cur_c;
- };
+ // New element is above the previous maximum level: it becomes the entry point (templock on `global` was not unlocked above in this case)
+ if (curlevel > maxlevelcopy) {
+ enterpoint_node_ = cur_c;
+ maxlevel_ = curlevel;
+ }
+ return cur_c;
+ }
- std::priority_queue<std::pair<dist_t, labeltype >>
- searchKnn(const void *query_data, size_t k) const {
- std::priority_queue<std::pair<dist_t, labeltype >> result;
- if (cur_element_count == 0) return result;
- tableint currObj = enterpoint_node_;
- dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);
+ std::priority_queue<std::pair<dist_t, labeltype >>  // at most k (distance, external label) pairs; default priority_queue ordering, so top() is the largest pair
+ searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const {  // isIdAllowed is only forwarded to searchBaseLayerST; nullptr by default
+ std::priority_queue<std::pair<dist_t, labeltype >> result;
+ if (cur_element_count == 0) return result;  // empty index: empty result
- for (int level = maxlevel_; level > 0; level--) {
- bool changed = true;
- while (changed) {
- changed = false;
- unsigned int *data;
+ tableint currObj = enterpoint_node_;  // descent starts at the current entry point
+ dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);
- data = (unsigned int *) get_linklist(currObj, level);
- int size = getListCount(data);
- metric_hops++;
- metric_distance_computations+=size;
+ for (int level = maxlevel_; level > 0; level--) {  // levels maxlevel_ .. 1: move to the closest node on each; no link-list locks are taken in this loop
+ bool changed = true;
+ while (changed) {  // repeat on this level until no neighbor of currObj is closer
+ changed = false;
+ unsigned int *data;
- tableint *datal = (tableint *) (data + 1);
- for (int i = 0; i < size; i++) {
- tableint cand = datal[i];
- if (cand < 0 || cand > max_elements_)
- throw std::runtime_error("cand error");
- dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_);
+ data = (unsigned int *) get_linklist(currObj, level);
+ int size = getListCount(data);
+ metric_hops++;  // one scanned link list counts as one hop
+ metric_distance_computations+=size;  // one distance is computed per listed neighbor below
- if (d < curdist) {
- curdist = d;
- currObj = cand;
- changed = true;
- }
+ tableint *datal = (tableint *) (data + 1);  // neighbor ids start one unsigned int past `data`
+ for (int i = 0; i < size; i++) {
+ tableint cand = datal[i];
+ if (cand < 0 || cand > max_elements_)  // NOTE(review): tableint is unsigned, so `cand < 0` is never true, and `>` lets cand == max_elements_ through - confirm whether `>=` was intended
+ throw std::runtime_error("cand error");
+ dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_);
+
+ if (d < curdist) {  // strictly closer neighbor found: move to it and rescan from there
+ curdist = d;
+ currObj = cand;
+ changed = true;
}
}
}
+ }
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
- if (num_deleted_) {
- top_candidates=searchBaseLayerST<true,true>(
- currObj, query_data, std::max(ef_, k));
- }
- else{
- top_candidates=searchBaseLayerST<false,true>(
- currObj, query_data, std::max(ef_, k));
- }
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
+ if (num_deleted_) {  // the first template argument is true only when some elements are marked deleted (presumably enables skipping them - confirm in searchBaseLayerST)
+ top_candidates = searchBaseLayerST<true, true>(
+ currObj, query_data, std::max(ef_, k), isIdAllowed);  // search width is never smaller than k
+ } else {
+ top_candidates = searchBaseLayerST<false, true>(
+ currObj, query_data, std::max(ef_, k), isIdAllowed);
+ }
- while (top_candidates.size() > k) {
- top_candidates.pop();
- }
- while (top_candidates.size() > 0) {
- std::pair<dist_t, tableint> rez = top_candidates.top();
- result.push(std::pair<dist_t, labeltype>(rez.first, getExternalLabel(rez.second)));
- top_candidates.pop();
- }
- return result;
- };
+ while (top_candidates.size() > k) {  // the search may return more than k entries; drop from the top until k remain
+ top_candidates.pop();
+ }
+ while (top_candidates.size() > 0) {  // translate internal ids to external labels
+ std::pair<dist_t, tableint> rez = top_candidates.top();
+ result.push(std::pair<dist_t, labeltype>(rez.first, getExternalLabel(rez.second)));
+ top_candidates.pop();
+ }
+ return result;
+ }
- void checkIntegrity(){
- int connections_checked=0;
- std::vector <int > inbound_connections_num(cur_element_count,0);
- for(int i = 0;i < cur_element_count; i++){
- for(int l = 0;l <= element_levels_[i]; l++){
- linklistsizeint *ll_cur = get_linklist_at_level(i,l);
- int size = getListCount(ll_cur);
- tableint *data = (tableint *) (ll_cur + 1);
- std::unordered_set<tableint> s;
- for (int j=0; j<size; j++){
- assert(data[j] > 0);
- assert(data[j] < cur_element_count);
- assert (data[j] != i);
- inbound_connections_num[data[j]]++;
- s.insert(data[j]);
- connections_checked++;
- }
- assert(s.size() == size);
+ void checkIntegrity() {  // Walks every link list of every element and assert()s basic invariants, then prints a summary to std::cout; the checks are no-ops when NDEBUG is defined.
+ int connections_checked = 0;
+ std::vector <int > inbound_connections_num(cur_element_count, 0);  // per element: how many links point at it
+ for (int i = 0; i < cur_element_count; i++) {
+ for (int l = 0; l <= element_levels_[i]; l++) {  // every level the element is present on
+ linklistsizeint *ll_cur = get_linklist_at_level(i, l);
+ int size = getListCount(ll_cur);
+ tableint *data = (tableint *) (ll_cur + 1);  // neighbor ids start one linklistsizeint past the list start
+ std::unordered_set<tableint> s;  // distinct neighbor ids of this list, used for the duplicate check below
+ for (int j = 0; j < size; j++) {
+ assert(data[j] > 0);  // NOTE(review): this also rejects links to internal id 0, which addPoint assigns to the first element - confirm intent
+ assert(data[j] < cur_element_count);  // link must point at an existing element
+ assert(data[j] != i);  // no self loops
+ inbound_connections_num[data[j]]++;
+ s.insert(data[j]);
+ connections_checked++;
}
+ assert(s.size() == size);  // no neighbor appears twice in one list
}
- if(cur_element_count > 1){
- int min1=inbound_connections_num[0], max1=inbound_connections_num[0];
- for(int i=0; i < cur_element_count; i++){
- assert(inbound_connections_num[i] > 0);
- min1=std::min(inbound_connections_num[i],min1);
- max1=std::max(inbound_connections_num[i],max1);
- }
- std::cout << "Min inbound: " << min1 << ", Max inbound:" << max1 << "\n";
+ }
+ if (cur_element_count > 1) {  // inbound statistics are only checked with at least two elements
+ int min1 = inbound_connections_num[0], max1 = inbound_connections_num[0];
+ for (int i=0; i < cur_element_count; i++) {
+ assert(inbound_connections_num[i] > 0);  // every element must have at least one inbound link
+ min1 = std::min(inbound_connections_num[i], min1);
+ max1 = std::max(inbound_connections_num[i], max1);
}
- std::cout << "integrity ok, checked " << connections_checked << " connections\n";
-
+ std::cout << "Min inbound: " << min1 << ", Max inbound:" << max1 << "\n";
}
-
- };
-
-}
+ std::cout << "integrity ok, checked " << connections_checked << " connections\n";
+ }
+};
+} // namespace hnswlib