ext/hnswlib/hnswlibext.hpp in hnswlib-0.5.2 vs ext/hnswlib/hnswlibext.hpp in hnswlib-0.5.3

- old
+ new

@@ -275,11 +275,16 @@ const size_t m = (size_t)NUM2INT(kw_values[2]); const size_t ef_construction = (size_t)NUM2INT(kw_values[3]); const size_t random_seed = (size_t)NUM2INT(kw_values[4]); hnswlib::HierarchicalNSW<float>* ptr = get_hnsw_hierarchicalnsw(self); - new (ptr) hnswlib::HierarchicalNSW<float>(space, max_elements, m, ef_construction, random_seed); + try { + new (ptr) hnswlib::HierarchicalNSW<float>(space, max_elements, m, ef_construction, random_seed); + } catch(const std::runtime_error& e) { + rb_raise(rb_eRuntimeError, "%s", e.what()); + return Qnil; + } return Qnil; }; static VALUE _hnsw_hierarchicalnsw_add_point(VALUE self, VALUE arr, VALUE idx) { @@ -332,12 +337,18 @@ float* vec = (float*)ruby_xmalloc(dim * sizeof(float)); for (int i = 0; i < dim; i++) { vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i)); } - std::priority_queue<std::pair<float, size_t>> result = - get_hnsw_hierarchicalnsw(self)->searchKnn((void *)vec, (size_t)NUM2INT(k)); + std::priority_queue<std::pair<float, size_t>> result; + try { + result = get_hnsw_hierarchicalnsw(self)->searchKnn((void *)vec, (size_t)NUM2INT(k)); + } catch(const std::runtime_error& e) { + ruby_xfree(vec); + rb_raise(rb_eRuntimeError, "%s", e.what()); + return Qnil; + } ruby_xfree(vec); if (result.size() != (size_t)NUM2INT(k)) { rb_raise(rb_eRuntimeError, "Cannot return the results in a contigious 2D array. Probably ef or M is too small."); @@ -374,11 +385,25 @@ if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) { space = RbHnswlibL2Space::get_hnsw_l2space(ivspace); } else { space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace); } - get_hnsw_hierarchicalnsw(self)->loadIndex(filename, space); + hnswlib::HierarchicalNSW<float>* index = get_hnsw_hierarchicalnsw(self); + if (index->data_level0_memory_) free(index->data_level0_memory_); + if (index->linkLists_) { + for (hnswlib::tableint i = 0; i < index->cur_element_count; i++) { + if (index->element_levels_[i] > 0 && index->linkLists_[i]) free(index->linkLists_[i]); + } + free(index->linkLists_); + } + if (index->visited_list_pool_) delete index->visited_list_pool_; + try { + index->loadIndex(filename, space); + } catch(const std::runtime_error& e) { + rb_raise(rb_eRuntimeError, "%s", e.what()); + return Qnil; + } RB_GC_GUARD(_filename); return Qnil; }; static VALUE _hnsw_hierarchicalnsw_get_point(VALUE self, VALUE idx) { @@ -387,12 +412,13 @@ std::vector<float> vec = get_hnsw_hierarchicalnsw(self)->template getDataByLabel<float>((size_t)NUM2INT(idx)); ret = rb_ary_new2(vec.size()); for (size_t i = 0; i < vec.size(); i++) { rb_ary_store(ret, i, DBL2NUM((double)vec[i])); } - } catch(std::runtime_error const& e) { + } catch(const std::runtime_error& e) { rb_raise(rb_eRuntimeError, "%s", e.what()); + return Qnil; } return ret; }; static VALUE _hnsw_hierarchicalnsw_get_ids(VALUE self) { @@ -402,20 +428,30 @@ } return ret; }; static VALUE _hnsw_hierarchicalnsw_mark_deleted(VALUE self, VALUE idx) { - get_hnsw_hierarchicalnsw(self)->markDelete((size_t)NUM2INT(idx)); + try { + get_hnsw_hierarchicalnsw(self)->markDelete((size_t)NUM2INT(idx)); + } catch(const std::runtime_error& e) { + rb_raise(rb_eRuntimeError, "%s", e.what()); + return Qnil; + } return Qnil; }; static VALUE _hnsw_hierarchicalnsw_resize_index(VALUE self, VALUE new_max_elements) { if ((size_t)NUM2INT(new_max_elements) < get_hnsw_hierarchicalnsw(self)->cur_element_count) { rb_raise(rb_eArgError, "Cannot resize, max element is less than the current number of elements."); return Qnil; } - get_hnsw_hierarchicalnsw(self)->resizeIndex((size_t)NUM2INT(new_max_elements)); + try { + get_hnsw_hierarchicalnsw(self)->resizeIndex((size_t)NUM2INT(new_max_elements)); + } catch(const std::runtime_error& e) { + rb_raise(rb_eRuntimeError, "%s", e.what()); + return Qnil; + } return Qnil; }; static VALUE _hnsw_hierarchicalnsw_set_ef(VALUE self, VALUE ef) { get_hnsw_hierarchicalnsw(self)->ef_ = (size_t)NUM2INT(ef); @@ -508,11 +544,16 @@ space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]); } const size_t max_elements = (size_t)NUM2INT(kw_values[1]); hnswlib::BruteforceSearch<float>* ptr = get_hnsw_bruteforcesearch(self); - new (ptr) hnswlib::BruteforceSearch<float>(space, max_elements); + try { + new (ptr) hnswlib::BruteforceSearch<float>(space, max_elements); + } catch(const std::runtime_error& e) { + rb_raise(rb_eRuntimeError, "%s", e.what()); + return Qnil; + } return Qnil; }; static VALUE _hnsw_bruteforcesearch_add_point(VALUE self, VALUE arr, VALUE idx) { @@ -536,11 +577,17 @@ float* vec = (float*)ruby_xmalloc(dim * sizeof(float)); for (int i = 0; i < dim; i++) { vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i)); } - get_hnsw_bruteforcesearch(self)->addPoint((void *)vec, (size_t)NUM2INT(idx)); + try { + get_hnsw_bruteforcesearch(self)->addPoint((void *)vec, (size_t)NUM2INT(idx)); + } catch(const std::runtime_error& e) { + ruby_xfree(vec); + rb_raise(rb_eRuntimeError, "%s", e.what()); + return Qfalse; + } ruby_xfree(vec); return Qtrue; }; @@ -607,10 +654,17 @@ if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) { space = RbHnswlibL2Space::get_hnsw_l2space(ivspace); } else { space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace); } - get_hnsw_bruteforcesearch(self)->loadIndex(filename, space); + hnswlib::BruteforceSearch<float>* index = get_hnsw_bruteforcesearch(self); + if (index->data_) free(index->data_); + try { + index->loadIndex(filename, space); + } catch(const std::runtime_error& e) { + rb_raise(rb_eRuntimeError, "%s", e.what()); + return Qnil; + } RB_GC_GUARD(_filename); return Qnil; }; static VALUE _hnsw_bruteforcesearch_remove_point(VALUE self, VALUE idx) {