ext/hnswlib/src/hnswalg.h in hnswlib-0.8.1 vs ext/hnswlib/src/hnswalg.h in hnswlib-0.9.0
- old
+ new
@@ -6,10 +6,11 @@
#include <random>
#include <stdlib.h>
#include <assert.h>
#include <unordered_set>
#include <list>
+#include <memory>
namespace hnswlib {
typedef unsigned int tableint;
typedef unsigned int linklistsizeint;
@@ -31,11 +32,11 @@
size_t ef_{ 0 };
double mult_{0.0}, revSize_{0.0};
int maxlevel_{0};
- VisitedListPool *visited_list_pool_{nullptr};
+ std::unique_ptr<VisitedListPool> visited_list_pool_{nullptr};
// Locks operations with element by label value
mutable std::vector<std::mutex> label_op_locks_;
std::mutex global;
@@ -91,20 +92,26 @@
size_t max_elements,
size_t M = 16,
size_t ef_construction = 200,
size_t random_seed = 100,
bool allow_replace_deleted = false)
- : link_list_locks_(max_elements),
- label_op_locks_(MAX_LABEL_OPERATION_LOCKS),
+ : label_op_locks_(MAX_LABEL_OPERATION_LOCKS),
+ link_list_locks_(max_elements),
element_levels_(max_elements),
allow_replace_deleted_(allow_replace_deleted) {
max_elements_ = max_elements;
num_deleted_ = 0;
data_size_ = s->get_data_size();
fstdistfunc_ = s->get_dist_func();
dist_func_param_ = s->get_dist_func_param();
- M_ = M;
+ if ( M <= 10000 ) {
+ M_ = M;
+ } else {
+ HNSWERR << "warning: M parameter exceeds 10000 which may lead to adverse effects." << std::endl;
+ HNSWERR << " Cap to 10000 will be applied for the rest of the processing." << std::endl;
+ M_ = 10000;
+ }
maxM_ = M_;
maxM0_ = M_ * 2;
ef_construction_ = std::max(ef_construction, M_);
ef_ = 10;
@@ -121,11 +128,11 @@
if (data_level0_memory_ == nullptr)
throw std::runtime_error("Not enough memory");
cur_element_count = 0;
- visited_list_pool_ = new VisitedListPool(1, max_elements);
+ visited_list_pool_ = std::unique_ptr<VisitedListPool>(new VisitedListPool(1, max_elements));
// initializations for special treatment of the first node
enterpoint_node_ = -1;
maxlevel_ = -1;
@@ -137,17 +144,24 @@
revSize_ = 1.0 / mult_;
}
~HierarchicalNSW() {
+ clear();
+ }
+
+ void clear() {
free(data_level0_memory_);
+ data_level0_memory_ = nullptr;
for (tableint i = 0; i < cur_element_count; i++) {
if (element_levels_[i] > 0)
free(linkLists_[i]);
}
free(linkLists_);
- delete visited_list_pool_;
+ linkLists_ = nullptr;
+ cur_element_count = 0;
+ visited_list_pool_.reset(nullptr);
}
struct CompareByFirst {
constexpr bool operator()(std::pair<dist_t, tableint> const& a,
@@ -290,38 +304,59 @@
return top_candidates;
}
- template <bool has_deletions, bool collect_metrics = false>
+ // bare_bone_search means there is no check for deletions and stop condition is ignored in return of extra performance
+ template <bool bare_bone_search = true, bool collect_metrics = false>
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
- searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, BaseFilterFunctor* isIdAllowed = nullptr) const {
+ searchBaseLayerST(
+ tableint ep_id,
+ const void *data_point,
+ size_t ef,
+ BaseFilterFunctor* isIdAllowed = nullptr,
+ BaseSearchStopCondition<dist_t>* stop_condition = nullptr) const {
VisitedList *vl = visited_list_pool_->getFreeVisitedList();
vl_type *visited_array = vl->mass;
vl_type visited_array_tag = vl->curV;
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set;
dist_t lowerBound;
- if ((!has_deletions || !isMarkedDeleted(ep_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id)))) {
- dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
+ if (bare_bone_search ||
+ (!isMarkedDeleted(ep_id) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id))))) {
+ char* ep_data = getDataByInternalId(ep_id);
+ dist_t dist = fstdistfunc_(data_point, ep_data, dist_func_param_);
lowerBound = dist;
top_candidates.emplace(dist, ep_id);
+ if (!bare_bone_search && stop_condition) {
+ stop_condition->add_point_to_result(getExternalLabel(ep_id), ep_data, dist);
+ }
candidate_set.emplace(-dist, ep_id);
} else {
lowerBound = std::numeric_limits<dist_t>::max();
candidate_set.emplace(-lowerBound, ep_id);
}
visited_array[ep_id] = visited_array_tag;
while (!candidate_set.empty()) {
std::pair<dist_t, tableint> current_node_pair = candidate_set.top();
+ dist_t candidate_dist = -current_node_pair.first;
- if ((-current_node_pair.first) > lowerBound &&
- (top_candidates.size() == ef || (!isIdAllowed && !has_deletions))) {
+ bool flag_stop_search;
+ if (bare_bone_search) {
+ flag_stop_search = candidate_dist > lowerBound;
+ } else {
+ if (stop_condition) {
+ flag_stop_search = stop_condition->should_stop_search(candidate_dist, lowerBound);
+ } else {
+ flag_stop_search = candidate_dist > lowerBound && top_candidates.size() == ef;
+ }
+ }
+ if (flag_stop_search) {
break;
}
candidate_set.pop();
tableint current_node_id = current_node_pair.second;
@@ -352,23 +387,49 @@
visited_array[candidate_id] = visited_array_tag;
char *currObj1 = (getDataByInternalId(candidate_id));
dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_);
- if (top_candidates.size() < ef || lowerBound > dist) {
+ bool flag_consider_candidate;
+ if (!bare_bone_search && stop_condition) {
+ flag_consider_candidate = stop_condition->should_consider_candidate(dist, lowerBound);
+ } else {
+ flag_consider_candidate = top_candidates.size() < ef || lowerBound > dist;
+ }
+
+ if (flag_consider_candidate) {
candidate_set.emplace(-dist, candidate_id);
#ifdef USE_SSE
_mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ +
offsetLevel0_, ///////////
_MM_HINT_T0); ////////////////////////
#endif
- if ((!has_deletions || !isMarkedDeleted(candidate_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id))))
+ if (bare_bone_search ||
+ (!isMarkedDeleted(candidate_id) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id))))) {
top_candidates.emplace(dist, candidate_id);
+ if (!bare_bone_search && stop_condition) {
+ stop_condition->add_point_to_result(getExternalLabel(candidate_id), currObj1, dist);
+ }
+ }
- if (top_candidates.size() > ef)
+ bool flag_remove_extra = false;
+ if (!bare_bone_search && stop_condition) {
+ flag_remove_extra = stop_condition->should_remove_extra();
+ } else {
+ flag_remove_extra = top_candidates.size() > ef;
+ }
+ while (flag_remove_extra) {
+ tableint id = top_candidates.top().second;
top_candidates.pop();
+ if (!bare_bone_search && stop_condition) {
+ stop_condition->remove_point_from_result(getExternalLabel(id), getDataByInternalId(id), dist);
+ flag_remove_extra = stop_condition->should_remove_extra();
+ } else {
+ flag_remove_extra = top_candidates.size() > ef;
+ }
+ }
if (!top_candidates.empty())
lowerBound = top_candidates.top().first;
}
}
@@ -379,12 +440,12 @@
return top_candidates;
}
void getNeighborsByHeuristic2(
- std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
- const size_t M) {
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
+ const size_t M) {
if (top_candidates.size() < M) {
return;
}
std::priority_queue<std::pair<dist_t, tableint>> queue_closest;
@@ -572,12 +633,11 @@
void resizeIndex(size_t new_max_elements) {
if (new_max_elements < cur_element_count)
throw std::runtime_error("Cannot resize, max element is less than the current number of elements");
- delete visited_list_pool_;
- visited_list_pool_ = new VisitedListPool(1, new_max_elements);
+ visited_list_pool_.reset(new VisitedListPool(1, new_max_elements));
element_levels_.resize(new_max_elements);
std::vector<std::mutex>(new_max_elements).swap(link_list_locks_);
@@ -594,11 +654,37 @@
linkLists_ = linkLists_new;
max_elements_ = new_max_elements;
}
+ size_t indexFileSize() const {
+ size_t size = 0;
+ size += sizeof(offsetLevel0_);
+ size += sizeof(max_elements_);
+ size += sizeof(cur_element_count);
+ size += sizeof(size_data_per_element_);
+ size += sizeof(label_offset_);
+ size += sizeof(offsetData_);
+ size += sizeof(maxlevel_);
+ size += sizeof(enterpoint_node_);
+ size += sizeof(maxM_);
+ size += sizeof(maxM0_);
+ size += sizeof(M_);
+ size += sizeof(mult_);
+ size += sizeof(ef_construction_);
+
+ size += cur_element_count * size_data_per_element_;
+
+ for (size_t i = 0; i < cur_element_count; i++) {
+ unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0;
+ size += sizeof(linkListSize);
+ size += linkListSize;
+ }
+ return size;
+ }
+
void saveIndex(const std::string &location) {
std::ofstream output(location, std::ios::binary);
std::streampos position;
writeBinaryPOD(output, offsetLevel0_);
@@ -632,10 +718,11 @@
std::ifstream input(location, std::ios::binary);
if (!input.is_open())
throw std::runtime_error("Cannot open file");
+ clear();
// get file size:
input.seekg(0, input.end);
std::streampos total_filesize = input.tellg();
input.seekg(0, input.beg);
@@ -697,11 +784,11 @@
size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint);
std::vector<std::mutex>(max_elements).swap(link_list_locks_);
std::vector<std::mutex>(MAX_LABEL_OPERATION_LOCKS).swap(label_op_locks_);
- visited_list_pool_ = new VisitedListPool(1, max_elements);
+ visited_list_pool_.reset(new VisitedListPool(1, max_elements));
linkLists_ = (char **) malloc(sizeof(void *) * max_elements);
if (linkLists_ == nullptr)
throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists");
element_levels_ = std::vector<int>(max_elements);
@@ -751,11 +838,11 @@
char* data_ptrv = getDataByInternalId(internalId);
size_t dim = *((size_t *) dist_func_param_);
std::vector<data_t> data;
data_t* data_ptr = (data_t*) data_ptrv;
- for (int i = 0; i < dim; i++) {
+ for (size_t i = 0; i < dim; i++) {
data.push_back(*data_ptr);
data_ptr += 1;
}
return data;
}
@@ -1215,15 +1302,16 @@
}
}
}
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
- if (num_deleted_) {
- top_candidates = searchBaseLayerST<true, true>(
+ bool bare_bone_search = !num_deleted_ && !isIdAllowed;
+ if (bare_bone_search) {
+ top_candidates = searchBaseLayerST<true>(
currObj, query_data, std::max(ef_, k), isIdAllowed);
} else {
- top_candidates = searchBaseLayerST<false, true>(
+ top_candidates = searchBaseLayerST<false>(
currObj, query_data, std::max(ef_, k), isIdAllowed);
}
while (top_candidates.size() > k) {
top_candidates.pop();
@@ -1235,20 +1323,73 @@
}
return result;
}
+ std::vector<std::pair<dist_t, labeltype >>
+ searchStopConditionClosest(
+ const void *query_data,
+ BaseSearchStopCondition<dist_t>& stop_condition,
+ BaseFilterFunctor* isIdAllowed = nullptr) const {
+ std::vector<std::pair<dist_t, labeltype >> result;
+ if (cur_element_count == 0) return result;
+
+ tableint currObj = enterpoint_node_;
+ dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);
+
+ for (int level = maxlevel_; level > 0; level--) {
+ bool changed = true;
+ while (changed) {
+ changed = false;
+ unsigned int *data;
+
+ data = (unsigned int *) get_linklist(currObj, level);
+ int size = getListCount(data);
+ metric_hops++;
+ metric_distance_computations+=size;
+
+ tableint *datal = (tableint *) (data + 1);
+ for (int i = 0; i < size; i++) {
+ tableint cand = datal[i];
+ if (cand < 0 || cand > max_elements_)
+ throw std::runtime_error("cand error");
+ dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_);
+
+ if (d < curdist) {
+ curdist = d;
+ currObj = cand;
+ changed = true;
+ }
+ }
+ }
+ }
+
+ std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
+ top_candidates = searchBaseLayerST<false>(currObj, query_data, 0, isIdAllowed, &stop_condition);
+
+ size_t sz = top_candidates.size();
+ result.resize(sz);
+ while (!top_candidates.empty()) {
+ result[--sz] = top_candidates.top();
+ top_candidates.pop();
+ }
+
+ stop_condition.filter_results(result);
+
+ return result;
+ }
+
+
void checkIntegrity() {
int connections_checked = 0;
std::vector <int > inbound_connections_num(cur_element_count, 0);
for (int i = 0; i < cur_element_count; i++) {
for (int l = 0; l <= element_levels_[i]; l++) {
linklistsizeint *ll_cur = get_linklist_at_level(i, l);
int size = getListCount(ll_cur);
tableint *data = (tableint *) (ll_cur + 1);
std::unordered_set<tableint> s;
for (int j = 0; j < size; j++) {
- assert(data[j] > 0);
assert(data[j] < cur_element_count);
assert(data[j] != i);
inbound_connections_num[data[j]]++;
s.insert(data[j]);
connections_checked++;