ext/hnswlib/hnswlibext.hpp in hnswlib-0.5.2 vs ext/hnswlib/hnswlibext.hpp in hnswlib-0.5.3
- old
+ new
@@ -275,11 +275,16 @@
const size_t m = (size_t)NUM2INT(kw_values[2]);
const size_t ef_construction = (size_t)NUM2INT(kw_values[3]);
const size_t random_seed = (size_t)NUM2INT(kw_values[4]);
hnswlib::HierarchicalNSW<float>* ptr = get_hnsw_hierarchicalnsw(self);
- new (ptr) hnswlib::HierarchicalNSW<float>(space, max_elements, m, ef_construction, random_seed);
+ try {
+ new (ptr) hnswlib::HierarchicalNSW<float>(space, max_elements, m, ef_construction, random_seed);
+ } catch(const std::runtime_error& e) {
+ rb_raise(rb_eRuntimeError, "%s", e.what());
+ return Qnil;
+ }
return Qnil;
};
static VALUE _hnsw_hierarchicalnsw_add_point(VALUE self, VALUE arr, VALUE idx) {
@@ -332,12 +337,18 @@
float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
for (int i = 0; i < dim; i++) {
vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
}
- std::priority_queue<std::pair<float, size_t>> result =
- get_hnsw_hierarchicalnsw(self)->searchKnn((void *)vec, (size_t)NUM2INT(k));
+ std::priority_queue<std::pair<float, size_t>> result;
+ try {
+ result = get_hnsw_hierarchicalnsw(self)->searchKnn((void *)vec, (size_t)NUM2INT(k));
+ } catch(const std::runtime_error& e) {
+ ruby_xfree(vec);
+ rb_raise(rb_eRuntimeError, "%s", e.what());
+ return Qnil;
+ }
ruby_xfree(vec);
if (result.size() != (size_t)NUM2INT(k)) {
rb_raise(rb_eRuntimeError, "Cannot return the results in a contigious 2D array. Probably ef or M is too small.");
@@ -374,11 +385,25 @@
if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
space = RbHnswlibL2Space::get_hnsw_l2space(ivspace);
} else {
space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
}
- get_hnsw_hierarchicalnsw(self)->loadIndex(filename, space);
+ hnswlib::HierarchicalNSW<float>* index = get_hnsw_hierarchicalnsw(self);
+ if (index->data_level0_memory_) free(index->data_level0_memory_);
+ if (index->linkLists_) {
+ for (hnswlib::tableint i = 0; i < index->cur_element_count; i++) {
+ if (index->element_levels_[i] > 0 && index->linkLists_[i]) free(index->linkLists_[i]);
+ }
+ free(index->linkLists_);
+ }
+ if (index->visited_list_pool_) delete index->visited_list_pool_;
+ try {
+ index->loadIndex(filename, space);
+ } catch(const std::runtime_error& e) {
+ rb_raise(rb_eRuntimeError, "%s", e.what());
+ return Qnil;
+ }
RB_GC_GUARD(_filename);
return Qnil;
};
static VALUE _hnsw_hierarchicalnsw_get_point(VALUE self, VALUE idx) {
@@ -387,12 +412,13 @@
std::vector<float> vec = get_hnsw_hierarchicalnsw(self)->template getDataByLabel<float>((size_t)NUM2INT(idx));
ret = rb_ary_new2(vec.size());
for (size_t i = 0; i < vec.size(); i++) {
rb_ary_store(ret, i, DBL2NUM((double)vec[i]));
}
- } catch(std::runtime_error const& e) {
+ } catch(const std::runtime_error& e) {
rb_raise(rb_eRuntimeError, "%s", e.what());
+ return Qnil;
}
return ret;
};
static VALUE _hnsw_hierarchicalnsw_get_ids(VALUE self) {
@@ -402,20 +428,30 @@
}
return ret;
};
static VALUE _hnsw_hierarchicalnsw_mark_deleted(VALUE self, VALUE idx) {
- get_hnsw_hierarchicalnsw(self)->markDelete((size_t)NUM2INT(idx));
+ try {
+ get_hnsw_hierarchicalnsw(self)->markDelete((size_t)NUM2INT(idx));
+ } catch(const std::runtime_error& e) {
+ rb_raise(rb_eRuntimeError, "%s", e.what());
+ return Qnil;
+ }
return Qnil;
};
static VALUE _hnsw_hierarchicalnsw_resize_index(VALUE self, VALUE new_max_elements) {
if ((size_t)NUM2INT(new_max_elements) < get_hnsw_hierarchicalnsw(self)->cur_element_count) {
rb_raise(rb_eArgError, "Cannot resize, max element is less than the current number of elements.");
return Qnil;
}
- get_hnsw_hierarchicalnsw(self)->resizeIndex((size_t)NUM2INT(new_max_elements));
+ try {
+ get_hnsw_hierarchicalnsw(self)->resizeIndex((size_t)NUM2INT(new_max_elements));
+ } catch(const std::runtime_error& e) {
+ rb_raise(rb_eRuntimeError, "%s", e.what());
+ return Qnil;
+ }
return Qnil;
};
static VALUE _hnsw_hierarchicalnsw_set_ef(VALUE self, VALUE ef) {
get_hnsw_hierarchicalnsw(self)->ef_ = (size_t)NUM2INT(ef);
@@ -508,11 +544,16 @@
space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
}
const size_t max_elements = (size_t)NUM2INT(kw_values[1]);
hnswlib::BruteforceSearch<float>* ptr = get_hnsw_bruteforcesearch(self);
- new (ptr) hnswlib::BruteforceSearch<float>(space, max_elements);
+ try {
+ new (ptr) hnswlib::BruteforceSearch<float>(space, max_elements);
+ } catch(const std::runtime_error& e) {
+ rb_raise(rb_eRuntimeError, "%s", e.what());
+ return Qnil;
+ }
return Qnil;
};
static VALUE _hnsw_bruteforcesearch_add_point(VALUE self, VALUE arr, VALUE idx) {
@@ -536,11 +577,17 @@
float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
for (int i = 0; i < dim; i++) {
vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
}
- get_hnsw_bruteforcesearch(self)->addPoint((void *)vec, (size_t)NUM2INT(idx));
+ try {
+ get_hnsw_bruteforcesearch(self)->addPoint((void *)vec, (size_t)NUM2INT(idx));
+ } catch(const std::runtime_error& e) {
+ ruby_xfree(vec);
+ rb_raise(rb_eRuntimeError, "%s", e.what());
+ return Qfalse;
+ }
ruby_xfree(vec);
return Qtrue;
};
@@ -607,10 +654,17 @@
if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
space = RbHnswlibL2Space::get_hnsw_l2space(ivspace);
} else {
space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
}
- get_hnsw_bruteforcesearch(self)->loadIndex(filename, space);
+ hnswlib::BruteforceSearch<float>* index = get_hnsw_bruteforcesearch(self);
+ if (index->data_) free(index->data_);
+ try {
+ index->loadIndex(filename, space);
+ } catch(const std::runtime_error& e) {
+ rb_raise(rb_eRuntimeError, "%s", e.what());
+ return Qnil;
+ }
RB_GC_GUARD(_filename);
return Qnil;
};
static VALUE _hnsw_bruteforcesearch_remove_point(VALUE self, VALUE idx) {