ext/hnswlib/hnswlibext.hpp in hnswlib-0.1.1 vs ext/hnswlib/hnswlibext.hpp in hnswlib-0.2.0

- old
+ new

@@ -24,15 +24,17 @@ #include <hnswlib.h> VALUE rb_cHnswlibL2Space; VALUE rb_cHnswlibInnerProductSpace; VALUE rb_cHnswlibHierarchicalNSW; +VALUE rb_cHnswlibBruteforceSearch; class RbHnswlibL2Space { public: static VALUE hnsw_l2space_alloc(VALUE self) { hnswlib::L2Space* ptr = (hnswlib::L2Space*)ruby_xmalloc(sizeof(hnswlib::L2Space)); + new (ptr) hnswlib::L2Space(); // dummy call to constructor for GC. return TypedData_Wrap_Struct(self, &hnsw_l2space_type, ptr); }; static void hnsw_l2space_free(void* ptr) { ((hnswlib::L2Space*)ptr)->~L2Space(); @@ -104,10 +106,11 @@ class RbHnswlibInnerProductSpace { public: static VALUE hnsw_ipspace_alloc(VALUE self) { hnswlib::InnerProductSpace* ptr = (hnswlib::InnerProductSpace*)ruby_xmalloc(sizeof(hnswlib::InnerProductSpace)); + new (ptr) hnswlib::InnerProductSpace(); // dummy call to constructor for GC. return TypedData_Wrap_Struct(self, &hnsw_ipspace_type, ptr); }; static void hnsw_ipspace_free(void* ptr) { ((hnswlib::InnerProductSpace*)ptr)->~InnerProductSpace(); @@ -179,10 +182,11 @@ class RbHnswlibHierarchicalNSW { public: static VALUE hnsw_hierarchicalnsw_alloc(VALUE self) { hnswlib::HierarchicalNSW<float>* ptr = (hnswlib::HierarchicalNSW<float>*)ruby_xmalloc(sizeof(hnswlib::HierarchicalNSW<float>)); + new (ptr) hnswlib::HierarchicalNSW<float>(); // dummy call to constructor for GC. return TypedData_Wrap_Struct(self, &hnsw_hierarchicalnsw_type, ptr); }; static void hnsw_hierarchicalnsw_free(void* ptr) { ((hnswlib::HierarchicalNSW<float>*)ptr)->~HierarchicalNSW(); @@ -237,10 +241,31 @@ rb_get_kwargs(kw_args, kw_table, 2, 3, kw_values); if (kw_values[2] == Qundef) kw_values[2] = INT2NUM(16); if (kw_values[3] == Qundef) kw_values[3] = INT2NUM(200); if (kw_values[4] == Qundef) kw_values[4] = INT2NUM(100); + if (!(rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space) || rb_obj_is_instance_of(kw_values[0], rb_cHnswlibInnerProductSpace))) { + rb_raise(rb_eTypeError, "expected space, Hnswlib::L2Space or Hnswlib::InnerProductSpace"); + return Qnil; + } + if (!RB_INTEGER_TYPE_P(kw_values[1])) { + rb_raise(rb_eTypeError, "expected max_elements, Integer"); + return Qnil; + } + if (!RB_INTEGER_TYPE_P(kw_values[2])) { + rb_raise(rb_eTypeError, "expected m, Integer"); + return Qnil; + } + if (!RB_INTEGER_TYPE_P(kw_values[3])) { + rb_raise(rb_eTypeError, "expected ef_construction, Integer"); + return Qnil; + } + if (!RB_INTEGER_TYPE_P(kw_values[4])) { + rb_raise(rb_eTypeError, "expected random_seed, Integer"); + return Qnil; + } + rb_iv_set(self, "@space", kw_values[0]); hnswlib::SpaceInterface<float>* space; if (rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space)) { space = RbHnswlibL2Space::get_hnsw_l2space(kw_values[0]); } else { @@ -410,9 +435,206 @@ "RbHnswlibHierarchicalNSW", { NULL, RbHnswlibHierarchicalNSW::hnsw_hierarchicalnsw_free, RbHnswlibHierarchicalNSW::hnsw_hierarchicalnsw_size + }, + NULL, + NULL, + RUBY_TYPED_FREE_IMMEDIATELY +}; + +class RbHnswlibBruteforceSearch { + public: + static VALUE hnsw_bruteforcesearch_alloc(VALUE self) { + hnswlib::BruteforceSearch<float>* ptr = (hnswlib::BruteforceSearch<float>*)ruby_xmalloc(sizeof(hnswlib::BruteforceSearch<float>)); + new (ptr) hnswlib::BruteforceSearch<float>(); // dummy call to constructor for GC. + return TypedData_Wrap_Struct(self, &hnsw_bruteforcesearch_type, ptr); + }; + + static void hnsw_bruteforcesearch_free(void* ptr) { + ((hnswlib::BruteforceSearch<float>*)ptr)->~BruteforceSearch(); + ruby_xfree(ptr); + }; + + static size_t hnsw_bruteforcesearch_size(const void* ptr) { + return sizeof(*((hnswlib::BruteforceSearch<float>*)ptr)); + }; + + static hnswlib::BruteforceSearch<float>* get_hnsw_bruteforcesearch(VALUE self) { + hnswlib::BruteforceSearch<float>* ptr; + TypedData_Get_Struct(self, hnswlib::BruteforceSearch<float>, &hnsw_bruteforcesearch_type, ptr); + return ptr; + }; + + static VALUE define_class(VALUE rb_mHnswlib) { + rb_cHnswlibBruteforceSearch = rb_define_class_under(rb_mHnswlib, "BruteforceSearch", rb_cObject); + rb_define_alloc_func(rb_cHnswlibBruteforceSearch, hnsw_bruteforcesearch_alloc); + rb_define_method(rb_cHnswlibBruteforceSearch, "initialize", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_init), -1); + rb_define_method(rb_cHnswlibBruteforceSearch, "add_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_add_point), 2); + rb_define_method(rb_cHnswlibBruteforceSearch, "search_knn", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_search_knn), 2); + rb_define_method(rb_cHnswlibBruteforceSearch, "save_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_save_index), 1); + rb_define_method(rb_cHnswlibBruteforceSearch, "load_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_load_index), 1); + rb_define_method(rb_cHnswlibBruteforceSearch, "remove_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_remove_point), 1); + rb_define_method(rb_cHnswlibBruteforceSearch, "max_elements", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_max_elements), 0); + rb_define_method(rb_cHnswlibBruteforceSearch, "current_count", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_current_count), 0); + rb_define_attr(rb_cHnswlibBruteforceSearch, "space", 1, 0); + return rb_cHnswlibBruteforceSearch; + }; + + private: + static const rb_data_type_t hnsw_bruteforcesearch_type; + + static VALUE _hnsw_bruteforcesearch_init(int argc, VALUE* argv, VALUE self) { + VALUE kw_args = Qnil; + ID kw_table[2] = { rb_intern("space"), rb_intern("max_elements") }; + VALUE kw_values[2] = { Qundef, Qundef }; + rb_scan_args(argc, argv, ":", &kw_args); + rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values); + + if (!(rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space) || rb_obj_is_instance_of(kw_values[0], rb_cHnswlibInnerProductSpace))) { + rb_raise(rb_eTypeError, "expected space, Hnswlib::L2Space or Hnswlib::InnerProductSpace"); + return Qnil; + } + if (!RB_INTEGER_TYPE_P(kw_values[1])) { + rb_raise(rb_eTypeError, "expected max_elements, Integer"); + return Qnil; + } + + rb_iv_set(self, "@space", kw_values[0]); + hnswlib::SpaceInterface<float>* space; + if (rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space)) { + space = RbHnswlibL2Space::get_hnsw_l2space(kw_values[0]); + } else { + space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]); + } + const size_t max_elements = (size_t)NUM2INT(kw_values[1]); + + hnswlib::BruteforceSearch<float>* ptr = get_hnsw_bruteforcesearch(self); + new (ptr) hnswlib::BruteforceSearch<float>(space, max_elements); + + return Qnil; + }; + + static VALUE _hnsw_bruteforcesearch_add_point(VALUE self, VALUE arr, VALUE idx) { + const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim")); + + if (!RB_TYPE_P(arr, T_ARRAY)) { + rb_raise(rb_eArgError, "Expect point vector to be Ruby Array."); + return Qfalse; + } + + if (!RB_INTEGER_TYPE_P(idx)) { + rb_raise(rb_eArgError, "Expect index to be Ruby Integer."); + return Qfalse; + } + + if (dim != RARRAY_LEN(arr)) { + rb_raise(rb_eArgError, "Array size does not match to index dimensionality."); + return Qfalse; + } + + float* vec = (float*)ruby_xmalloc(dim * sizeof(float)); + for (int i = 0; i < dim; i++) { + vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i)); + } + + get_hnsw_bruteforcesearch(self)->addPoint((void *)vec, (size_t)NUM2INT(idx)); + + ruby_xfree(vec); + return Qtrue; + }; + + static VALUE _hnsw_bruteforcesearch_search_knn(VALUE self, VALUE arr, VALUE k) { + const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim")); + + if (!RB_TYPE_P(arr, T_ARRAY)) { + rb_raise(rb_eArgError, "Expect query vector to be Ruby Array."); + return Qnil; + } + + if (!RB_INTEGER_TYPE_P(k)) { + rb_raise(rb_eArgError, "Expect the number of nearest neighbors to be Ruby Integer."); + return Qnil; + } + + if (dim != RARRAY_LEN(arr)) { + rb_raise(rb_eArgError, "Array size does not match to index dimensionality."); + return Qnil; + } + + float* vec = (float*)ruby_xmalloc(dim * sizeof(float)); + for (int i = 0; i < dim; i++) { + vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i)); + } + + std::priority_queue<std::pair<float, size_t>> result = + get_hnsw_bruteforcesearch(self)->searchKnn((void *)vec, (size_t)NUM2INT(k)); + + ruby_xfree(vec); + + if (result.size() != (size_t)NUM2INT(k)) { + rb_raise(rb_eRuntimeError, "Cannot return the results in a contigious 2D array. Probably ef or M is too small."); + return Qnil; + } + + VALUE distances_arr = rb_ary_new2(result.size()); + VALUE neighbors_arr = rb_ary_new2(result.size()); + + for (int i = NUM2INT(k) - 1; i >= 0; i--) { + const std::pair<float, size_t>& result_tuple = result.top(); + rb_ary_store(distances_arr, i, DBL2NUM((double)result_tuple.first)); + rb_ary_store(neighbors_arr, i, INT2NUM((int)result_tuple.second)); + result.pop(); + } + + VALUE ret = rb_ary_new2(2); + rb_ary_store(ret, 0, neighbors_arr); + rb_ary_store(ret, 1, distances_arr); + return ret; + }; + + static VALUE _hnsw_bruteforcesearch_save_index(VALUE self, VALUE _filename) { + std::string filename(StringValuePtr(_filename)); + get_hnsw_bruteforcesearch(self)->saveIndex(filename); + RB_GC_GUARD(_filename); + return Qnil; + }; + + static VALUE _hnsw_bruteforcesearch_load_index(VALUE self, VALUE _filename) { + std::string filename(StringValuePtr(_filename)); + VALUE ivspace = rb_iv_get(self, "@space"); + hnswlib::SpaceInterface<float>* space; + if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) { + space = RbHnswlibL2Space::get_hnsw_l2space(ivspace); + } else { + space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace); + } + get_hnsw_bruteforcesearch(self)->loadIndex(filename, space); + RB_GC_GUARD(_filename); + return Qnil; + }; + + static VALUE _hnsw_bruteforcesearch_remove_point(VALUE self, VALUE idx) { + get_hnsw_bruteforcesearch(self)->removePoint((size_t)NUM2INT(idx)); + return Qnil; + }; + + static VALUE _hnsw_bruteforcesearch_max_elements(VALUE self) { + return INT2NUM((int)(get_hnsw_bruteforcesearch(self)->maxelements_)); + }; + + static VALUE _hnsw_bruteforcesearch_current_count(VALUE self) { + return INT2NUM((int)(get_hnsw_bruteforcesearch(self)->cur_element_count)); + }; +}; + +const rb_data_type_t RbHnswlibBruteforceSearch::hnsw_bruteforcesearch_type = { + "RbHnswlibBruteforceSearch", + { + NULL, + RbHnswlibBruteforceSearch::hnsw_bruteforcesearch_free, + RbHnswlibBruteforceSearch::hnsw_bruteforcesearch_size }, NULL, NULL, RUBY_TYPED_FREE_IMMEDIATELY };