ext/hnswlib/src/bruteforce.h in hnswlib-0.6.2 vs ext/hnswlib/src/bruteforce.h in hnswlib-0.7.0
- old
+ new
@@ -1,157 +1,168 @@
#pragma once
#include <unordered_map>
#include <fstream>
#include <mutex>
#include <algorithm>
+#include <assert.h>
namespace hnswlib {
- template<typename dist_t>
- class BruteforceSearch : public AlgorithmInterface<dist_t> {
- public:
- BruteforceSearch() : data_(nullptr) { }
- BruteforceSearch(SpaceInterface <dist_t> *s) {
+template<typename dist_t>
+class BruteforceSearch : public AlgorithmInterface<dist_t> {
+ public:
+ char *data_;
+ size_t maxelements_;
+ size_t cur_element_count;
+ size_t size_per_element_;
- }
- BruteforceSearch(SpaceInterface<dist_t> *s, const std::string &location) {
- loadIndex(location, s);
- }
+ size_t data_size_;
+ DISTFUNC <dist_t> fstdistfunc_;
+ void *dist_func_param_;
+ std::mutex index_lock;
- BruteforceSearch(SpaceInterface <dist_t> *s, size_t maxElements) {
- maxelements_ = maxElements;
- data_size_ = s->get_data_size();
- fstdistfunc_ = s->get_dist_func();
- dist_func_param_ = s->get_dist_func_param();
- size_per_element_ = data_size_ + sizeof(labeltype);
- data_ = (char *) malloc(maxElements * size_per_element_);
- if (data_ == nullptr)
- std::runtime_error("Not enough memory: BruteforceSearch failed to allocate data");
- cur_element_count = 0;
- }
+ std::unordered_map<labeltype, size_t > dict_external_to_internal;
- ~BruteforceSearch() {
- if (data_) free(data_);
- }
+ BruteforceSearch() : data_(nullptr) { }
- char *data_;
- size_t maxelements_;
- size_t cur_element_count;
- size_t size_per_element_;
+ BruteforceSearch(SpaceInterface <dist_t> *s)
+ : data_(nullptr),
+ maxelements_(0),
+ cur_element_count(0),
+ size_per_element_(0),
+ data_size_(0),
+ dist_func_param_(nullptr) {
+ }
- size_t data_size_;
- DISTFUNC <dist_t> fstdistfunc_;
- void *dist_func_param_;
- std::mutex index_lock;
- std::unordered_map<labeltype,size_t > dict_external_to_internal;
+ BruteforceSearch(SpaceInterface<dist_t> *s, const std::string &location)
+ : data_(nullptr),
+ maxelements_(0),
+ cur_element_count(0),
+ size_per_element_(0),
+ data_size_(0),
+ dist_func_param_(nullptr) {
+ loadIndex(location, s);
+ }
- void addPoint(const void *datapoint, labeltype label) {
- int idx;
- {
- std::unique_lock<std::mutex> lock(index_lock);
+ BruteforceSearch(SpaceInterface <dist_t> *s, size_t maxElements) {
+ maxelements_ = maxElements;
+ data_size_ = s->get_data_size();
+ fstdistfunc_ = s->get_dist_func();
+ dist_func_param_ = s->get_dist_func_param();
+ size_per_element_ = data_size_ + sizeof(labeltype);
+ data_ = (char *) malloc(maxElements * size_per_element_);
+ if (data_ == nullptr)
+ throw std::runtime_error("Not enough memory: BruteforceSearch failed to allocate data");
+ cur_element_count = 0;
+ }
+ ~BruteforceSearch() {
+ free(data_);
+ }
- auto search=dict_external_to_internal.find(label);
- if (search != dict_external_to_internal.end()) {
- idx=search->second;
+
+ void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) {
+ int idx;
+ {
+ std::unique_lock<std::mutex> lock(index_lock);
+
+ auto search = dict_external_to_internal.find(label);
+ if (search != dict_external_to_internal.end()) {
+ idx = search->second;
+ } else {
+ if (cur_element_count >= maxelements_) {
+ throw std::runtime_error("The number of elements exceeds the specified limit\n");
}
- else{
- if (cur_element_count >= maxelements_) {
- throw std::runtime_error("The number of elements exceeds the specified limit\n");
- }
- idx=cur_element_count;
- dict_external_to_internal[label] = idx;
- cur_element_count++;
- }
+ idx = cur_element_count;
+ dict_external_to_internal[label] = idx;
+ cur_element_count++;
}
- memcpy(data_ + size_per_element_ * idx + data_size_, &label, sizeof(labeltype));
- memcpy(data_ + size_per_element_ * idx, datapoint, data_size_);
+ }
+ memcpy(data_ + size_per_element_ * idx + data_size_, &label, sizeof(labeltype));
+ memcpy(data_ + size_per_element_ * idx, datapoint, data_size_);
+ }
+ void removePoint(labeltype cur_external) {
+ size_t cur_c = dict_external_to_internal[cur_external];
+ dict_external_to_internal.erase(cur_external);
- };
+ labeltype label = *((labeltype*)(data_ + size_per_element_ * (cur_element_count-1) + data_size_));
+ dict_external_to_internal[label] = cur_c;
+ memcpy(data_ + size_per_element_ * cur_c,
+ data_ + size_per_element_ * (cur_element_count-1),
+ data_size_+sizeof(labeltype));
+ cur_element_count--;
+ }
- void removePoint(labeltype cur_external) {
- size_t cur_c=dict_external_to_internal[cur_external];
- dict_external_to_internal.erase(cur_external);
-
- labeltype label=*((labeltype*)(data_ + size_per_element_ * (cur_element_count-1) + data_size_));
- dict_external_to_internal[label]=cur_c;
- memcpy(data_ + size_per_element_ * cur_c,
- data_ + size_per_element_ * (cur_element_count-1),
- data_size_+sizeof(labeltype));
- cur_element_count--;
-
+ std::priority_queue<std::pair<dist_t, labeltype >>
+ searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const {
+ assert(k <= cur_element_count);
+ std::priority_queue<std::pair<dist_t, labeltype >> topResults;
+ if (cur_element_count == 0) return topResults;
+ for (int i = 0; i < k; i++) {
+ dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
+ labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_));
+ if ((!isIdAllowed) || (*isIdAllowed)(label)) {
+ topResults.push(std::pair<dist_t, labeltype>(dist, label));
+ }
}
+ dist_t lastdist = topResults.empty() ? std::numeric_limits<dist_t>::max() : topResults.top().first;
+ for (int i = k; i < cur_element_count; i++) {
+ dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
+ if (dist <= lastdist) {
+ labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_));
+ if ((!isIdAllowed) || (*isIdAllowed)(label)) {
+ topResults.push(std::pair<dist_t, labeltype>(dist, label));
+ }
+ if (topResults.size() > k)
+ topResults.pop();
-
- std::priority_queue<std::pair<dist_t, labeltype >>
- searchKnn(const void *query_data, size_t k) const {
- std::priority_queue<std::pair<dist_t, labeltype >> topResults;
- if (cur_element_count == 0) return topResults;
- for (int i = 0; i < k; i++) {
- dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
- topResults.push(std::pair<dist_t, labeltype>(dist, *((labeltype *) (data_ + size_per_element_ * i +
- data_size_))));
- }
- dist_t lastdist = topResults.top().first;
- for (int i = k; i < cur_element_count; i++) {
- dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
- if (dist <= lastdist) {
- topResults.push(std::pair<dist_t, labeltype>(dist, *((labeltype *) (data_ + size_per_element_ * i +
- data_size_))));
- if (topResults.size() > k)
- topResults.pop();
+ if (!topResults.empty()) {
lastdist = topResults.top().first;
}
-
}
- return topResults;
- };
-
- void saveIndex(const std::string &location) {
- std::ofstream output(location, std::ios::binary);
- std::streampos position;
-
- writeBinaryPOD(output, maxelements_);
- writeBinaryPOD(output, size_per_element_);
- writeBinaryPOD(output, cur_element_count);
-
- output.write(data_, maxelements_ * size_per_element_);
-
- output.close();
}
+ return topResults;
+ }
- void loadIndex(const std::string &location, SpaceInterface<dist_t> *s) {
+ void saveIndex(const std::string &location) {
+ std::ofstream output(location, std::ios::binary);
+ std::streampos position;
- std::ifstream input(location, std::ios::binary);
+ writeBinaryPOD(output, maxelements_);
+ writeBinaryPOD(output, size_per_element_);
+ writeBinaryPOD(output, cur_element_count);
- if (!input.is_open())
- throw std::runtime_error("Cannot open file");
+ output.write(data_, maxelements_ * size_per_element_);
- std::streampos position;
+ output.close();
+ }
- readBinaryPOD(input, maxelements_);
- readBinaryPOD(input, size_per_element_);
- readBinaryPOD(input, cur_element_count);
- data_size_ = s->get_data_size();
- fstdistfunc_ = s->get_dist_func();
- dist_func_param_ = s->get_dist_func_param();
- size_per_element_ = data_size_ + sizeof(labeltype);
- data_ = (char *) malloc(maxelements_ * size_per_element_);
- if (data_ == nullptr)
- std::runtime_error("Not enough memory: loadIndex failed to allocate data");
+ void loadIndex(const std::string &location, SpaceInterface<dist_t> *s) {
+ std::ifstream input(location, std::ios::binary);
+ std::streampos position;
- input.read(data_, maxelements_ * size_per_element_);
+ readBinaryPOD(input, maxelements_);
+ readBinaryPOD(input, size_per_element_);
+ readBinaryPOD(input, cur_element_count);
- input.close();
+ data_size_ = s->get_data_size();
+ fstdistfunc_ = s->get_dist_func();
+ dist_func_param_ = s->get_dist_func_param();
+ size_per_element_ = data_size_ + sizeof(labeltype);
+ data_ = (char *) malloc(maxelements_ * size_per_element_);
+ if (data_ == nullptr)
+ throw std::runtime_error("Not enough memory: loadIndex failed to allocate data");
- }
+ input.read(data_, maxelements_ * size_per_element_);
- };
-}
+ input.close();
+ }
+};
+} // namespace hnswlib