#pragma once #include #include #include #include namespace hnswlib { template class BruteforceSearch : public AlgorithmInterface { public: BruteforceSearch() : data_(nullptr) { } BruteforceSearch(SpaceInterface *s) { } BruteforceSearch(SpaceInterface *s, const std::string &location) { loadIndex(location, s); } BruteforceSearch(SpaceInterface *s, size_t maxElements) { maxelements_ = maxElements; data_size_ = s->get_data_size(); fstdistfunc_ = s->get_dist_func(); dist_func_param_ = s->get_dist_func_param(); size_per_element_ = data_size_ + sizeof(labeltype); data_ = (char *) malloc(maxElements * size_per_element_); if (data_ == nullptr) std::runtime_error("Not enough memory: BruteforceSearch failed to allocate data"); cur_element_count = 0; } ~BruteforceSearch() { free(data_); } char *data_; size_t maxelements_; size_t cur_element_count; size_t size_per_element_; size_t data_size_; DISTFUNC fstdistfunc_; void *dist_func_param_; std::mutex index_lock; std::unordered_map dict_external_to_internal; void addPoint(const void *datapoint, labeltype label) { int idx; { std::unique_lock lock(index_lock); auto search=dict_external_to_internal.find(label); if (search != dict_external_to_internal.end()) { idx=search->second; } else{ if (cur_element_count >= maxelements_) { throw std::runtime_error("The number of elements exceeds the specified limit\n"); } idx=cur_element_count; dict_external_to_internal[label] = idx; cur_element_count++; } } memcpy(data_ + size_per_element_ * idx + data_size_, &label, sizeof(labeltype)); memcpy(data_ + size_per_element_ * idx, datapoint, data_size_); }; void removePoint(labeltype cur_external) { size_t cur_c=dict_external_to_internal[cur_external]; dict_external_to_internal.erase(cur_external); labeltype label=*((labeltype*)(data_ + size_per_element_ * (cur_element_count-1) + data_size_)); dict_external_to_internal[label]=cur_c; memcpy(data_ + size_per_element_ * cur_c, data_ + size_per_element_ * (cur_element_count-1), data_size_+sizeof(labeltype)); cur_element_count--; } std::priority_queue> searchKnn(const void *query_data, size_t k) const { std::priority_queue> topResults; if (cur_element_count == 0) return topResults; for (size_t i = 0; i < k; i++) { dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); topResults.push(std::pair(dist, *((labeltype *) (data_ + size_per_element_ * i + data_size_)))); } dist_t lastdist = topResults.top().first; for (size_t i = k; i < cur_element_count; i++) { dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); if (dist <= lastdist) { topResults.push(std::pair(dist, *((labeltype *) (data_ + size_per_element_ * i + data_size_)))); if (topResults.size() > k) topResults.pop(); lastdist = topResults.top().first; } } return topResults; }; void saveIndex(const std::string &location) { std::ofstream output(location, std::ios::binary); std::streampos position; writeBinaryPOD(output, maxelements_); writeBinaryPOD(output, size_per_element_); writeBinaryPOD(output, cur_element_count); output.write(data_, maxelements_ * size_per_element_); output.close(); } void loadIndex(const std::string &location, SpaceInterface *s) { std::ifstream input(location, std::ios::binary); std::streampos position; readBinaryPOD(input, maxelements_); readBinaryPOD(input, size_per_element_); readBinaryPOD(input, cur_element_count); data_size_ = s->get_data_size(); fstdistfunc_ = s->get_dist_func(); dist_func_param_ = s->get_dist_func_param(); size_per_element_ = data_size_ + sizeof(labeltype); data_ = (char *) malloc(maxelements_ * size_per_element_); if (data_ == nullptr) std::runtime_error("Not enough memory: loadIndex failed to allocate data"); input.read(data_, maxelements_ * size_per_element_); input.close(); } }; }