ext/hnswlib/hnswlibext.hpp in hnswlib-0.1.1 vs ext/hnswlib/hnswlibext.hpp in hnswlib-0.2.0
- old
+ new
@@ -24,15 +24,17 @@
#include <hnswlib.h>
VALUE rb_cHnswlibL2Space;
VALUE rb_cHnswlibInnerProductSpace;
VALUE rb_cHnswlibHierarchicalNSW;
+VALUE rb_cHnswlibBruteforceSearch;
class RbHnswlibL2Space {
public:
static VALUE hnsw_l2space_alloc(VALUE self) {
hnswlib::L2Space* ptr = (hnswlib::L2Space*)ruby_xmalloc(sizeof(hnswlib::L2Space));
+ new (ptr) hnswlib::L2Space(); // dummy call to constructor for GC.
return TypedData_Wrap_Struct(self, &hnsw_l2space_type, ptr);
};
static void hnsw_l2space_free(void* ptr) {
((hnswlib::L2Space*)ptr)->~L2Space();
@@ -104,10 +106,11 @@
class RbHnswlibInnerProductSpace {
public:
static VALUE hnsw_ipspace_alloc(VALUE self) {
hnswlib::InnerProductSpace* ptr = (hnswlib::InnerProductSpace*)ruby_xmalloc(sizeof(hnswlib::InnerProductSpace));
+ new (ptr) hnswlib::InnerProductSpace(); // dummy call to constructor for GC.
return TypedData_Wrap_Struct(self, &hnsw_ipspace_type, ptr);
};
static void hnsw_ipspace_free(void* ptr) {
((hnswlib::InnerProductSpace*)ptr)->~InnerProductSpace();
@@ -179,10 +182,11 @@
class RbHnswlibHierarchicalNSW {
public:
static VALUE hnsw_hierarchicalnsw_alloc(VALUE self) {
hnswlib::HierarchicalNSW<float>* ptr = (hnswlib::HierarchicalNSW<float>*)ruby_xmalloc(sizeof(hnswlib::HierarchicalNSW<float>));
+ new (ptr) hnswlib::HierarchicalNSW<float>(); // dummy call to constructor for GC.
return TypedData_Wrap_Struct(self, &hnsw_hierarchicalnsw_type, ptr);
};
static void hnsw_hierarchicalnsw_free(void* ptr) {
((hnswlib::HierarchicalNSW<float>*)ptr)->~HierarchicalNSW();
@@ -237,10 +241,31 @@
rb_get_kwargs(kw_args, kw_table, 2, 3, kw_values);
if (kw_values[2] == Qundef) kw_values[2] = INT2NUM(16);
if (kw_values[3] == Qundef) kw_values[3] = INT2NUM(200);
if (kw_values[4] == Qundef) kw_values[4] = INT2NUM(100);
+ if (!(rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space) || rb_obj_is_instance_of(kw_values[0], rb_cHnswlibInnerProductSpace))) {
+ rb_raise(rb_eTypeError, "expected space, Hnswlib::L2Space or Hnswlib::InnerProductSpace");
+ return Qnil;
+ }
+ if (!RB_INTEGER_TYPE_P(kw_values[1])) {
+ rb_raise(rb_eTypeError, "expected max_elements, Integer");
+ return Qnil;
+ }
+ if (!RB_INTEGER_TYPE_P(kw_values[2])) {
+ rb_raise(rb_eTypeError, "expected m, Integer");
+ return Qnil;
+ }
+ if (!RB_INTEGER_TYPE_P(kw_values[3])) {
+ rb_raise(rb_eTypeError, "expected ef_construction, Integer");
+ return Qnil;
+ }
+ if (!RB_INTEGER_TYPE_P(kw_values[4])) {
+ rb_raise(rb_eTypeError, "expected random_seed, Integer");
+ return Qnil;
+ }
+
rb_iv_set(self, "@space", kw_values[0]);
hnswlib::SpaceInterface<float>* space;
if (rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space)) {
space = RbHnswlibL2Space::get_hnsw_l2space(kw_values[0]);
} else {
@@ -410,9 +435,206 @@
"RbHnswlibHierarchicalNSW",
{
NULL,
RbHnswlibHierarchicalNSW::hnsw_hierarchicalnsw_free,
RbHnswlibHierarchicalNSW::hnsw_hierarchicalnsw_size
+ },
+ NULL,
+ NULL,
+ RUBY_TYPED_FREE_IMMEDIATELY
+};
+
+class RbHnswlibBruteforceSearch {
+ public:
+ static VALUE hnsw_bruteforcesearch_alloc(VALUE self) {
+ hnswlib::BruteforceSearch<float>* ptr = (hnswlib::BruteforceSearch<float>*)ruby_xmalloc(sizeof(hnswlib::BruteforceSearch<float>));
+ new (ptr) hnswlib::BruteforceSearch<float>(); // dummy call to constructor for GC.
+ return TypedData_Wrap_Struct(self, &hnsw_bruteforcesearch_type, ptr);
+ };
+
+ static void hnsw_bruteforcesearch_free(void* ptr) {
+ ((hnswlib::BruteforceSearch<float>*)ptr)->~BruteforceSearch();
+ ruby_xfree(ptr);
+ };
+
+ static size_t hnsw_bruteforcesearch_size(const void* ptr) {
+ return sizeof(*((hnswlib::BruteforceSearch<float>*)ptr));
+ };
+
+ static hnswlib::BruteforceSearch<float>* get_hnsw_bruteforcesearch(VALUE self) {
+ hnswlib::BruteforceSearch<float>* ptr;
+ TypedData_Get_Struct(self, hnswlib::BruteforceSearch<float>, &hnsw_bruteforcesearch_type, ptr);
+ return ptr;
+ };
+
+ static VALUE define_class(VALUE rb_mHnswlib) {
+ rb_cHnswlibBruteforceSearch = rb_define_class_under(rb_mHnswlib, "BruteforceSearch", rb_cObject);
+ rb_define_alloc_func(rb_cHnswlibBruteforceSearch, hnsw_bruteforcesearch_alloc);
+ rb_define_method(rb_cHnswlibBruteforceSearch, "initialize", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_init), -1);
+ rb_define_method(rb_cHnswlibBruteforceSearch, "add_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_add_point), 2);
+ rb_define_method(rb_cHnswlibBruteforceSearch, "search_knn", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_search_knn), 2);
+ rb_define_method(rb_cHnswlibBruteforceSearch, "save_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_save_index), 1);
+ rb_define_method(rb_cHnswlibBruteforceSearch, "load_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_load_index), 1);
+ rb_define_method(rb_cHnswlibBruteforceSearch, "remove_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_remove_point), 1);
+ rb_define_method(rb_cHnswlibBruteforceSearch, "max_elements", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_max_elements), 0);
+ rb_define_method(rb_cHnswlibBruteforceSearch, "current_count", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_current_count), 0);
+ rb_define_attr(rb_cHnswlibBruteforceSearch, "space", 1, 0);
+ return rb_cHnswlibBruteforceSearch;
+ };
+
+ private:
+ static const rb_data_type_t hnsw_bruteforcesearch_type;
+
+ static VALUE _hnsw_bruteforcesearch_init(int argc, VALUE* argv, VALUE self) {
+ VALUE kw_args = Qnil;
+ ID kw_table[2] = { rb_intern("space"), rb_intern("max_elements") };
+ VALUE kw_values[2] = { Qundef, Qundef };
+ rb_scan_args(argc, argv, ":", &kw_args);
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
+
+ if (!(rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space) || rb_obj_is_instance_of(kw_values[0], rb_cHnswlibInnerProductSpace))) {
+ rb_raise(rb_eTypeError, "expected space, Hnswlib::L2Space or Hnswlib::InnerProductSpace");
+ return Qnil;
+ }
+ if (!RB_INTEGER_TYPE_P(kw_values[1])) {
+ rb_raise(rb_eTypeError, "expected max_elements, Integer");
+ return Qnil;
+ }
+
+ rb_iv_set(self, "@space", kw_values[0]);
+ hnswlib::SpaceInterface<float>* space;
+ if (rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space)) {
+ space = RbHnswlibL2Space::get_hnsw_l2space(kw_values[0]);
+ } else {
+ space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
+ }
+ const size_t max_elements = (size_t)NUM2INT(kw_values[1]);
+
+ hnswlib::BruteforceSearch<float>* ptr = get_hnsw_bruteforcesearch(self);
+ new (ptr) hnswlib::BruteforceSearch<float>(space, max_elements);
+
+ return Qnil;
+ };
+
+ static VALUE _hnsw_bruteforcesearch_add_point(VALUE self, VALUE arr, VALUE idx) {
+ const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
+
+ if (!RB_TYPE_P(arr, T_ARRAY)) {
+ rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
+ return Qfalse;
+ }
+
+ if (!RB_INTEGER_TYPE_P(idx)) {
+ rb_raise(rb_eArgError, "Expect index to be Ruby Integer.");
+ return Qfalse;
+ }
+
+ if (dim != RARRAY_LEN(arr)) {
+ rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
+ return Qfalse;
+ }
+
+ float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
+ for (int i = 0; i < dim; i++) {
+ vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
+ }
+
+ get_hnsw_bruteforcesearch(self)->addPoint((void *)vec, (size_t)NUM2INT(idx));
+
+ ruby_xfree(vec);
+ return Qtrue;
+ };
+
+ static VALUE _hnsw_bruteforcesearch_search_knn(VALUE self, VALUE arr, VALUE k) {
+ const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
+
+ if (!RB_TYPE_P(arr, T_ARRAY)) {
+ rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
+ return Qnil;
+ }
+
+ if (!RB_INTEGER_TYPE_P(k)) {
+ rb_raise(rb_eArgError, "Expect the number of nearest neighbors to be Ruby Integer.");
+ return Qnil;
+ }
+
+ if (dim != RARRAY_LEN(arr)) {
+ rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
+ return Qnil;
+ }
+
+ float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
+ for (int i = 0; i < dim; i++) {
+ vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
+ }
+
+ std::priority_queue<std::pair<float, size_t>> result =
+ get_hnsw_bruteforcesearch(self)->searchKnn((void *)vec, (size_t)NUM2INT(k));
+
+ ruby_xfree(vec);
+
+ if (result.size() != (size_t)NUM2INT(k)) {
+ rb_raise(rb_eRuntimeError, "Cannot return the results in a contigious 2D array. Probably ef or M is too small.");
+ return Qnil;
+ }
+
+ VALUE distances_arr = rb_ary_new2(result.size());
+ VALUE neighbors_arr = rb_ary_new2(result.size());
+
+ for (int i = NUM2INT(k) - 1; i >= 0; i--) {
+ const std::pair<float, size_t>& result_tuple = result.top();
+ rb_ary_store(distances_arr, i, DBL2NUM((double)result_tuple.first));
+ rb_ary_store(neighbors_arr, i, INT2NUM((int)result_tuple.second));
+ result.pop();
+ }
+
+ VALUE ret = rb_ary_new2(2);
+ rb_ary_store(ret, 0, neighbors_arr);
+ rb_ary_store(ret, 1, distances_arr);
+ return ret;
+ };
+
+ static VALUE _hnsw_bruteforcesearch_save_index(VALUE self, VALUE _filename) {
+ std::string filename(StringValuePtr(_filename));
+ get_hnsw_bruteforcesearch(self)->saveIndex(filename);
+ RB_GC_GUARD(_filename);
+ return Qnil;
+ };
+
+ static VALUE _hnsw_bruteforcesearch_load_index(VALUE self, VALUE _filename) {
+ std::string filename(StringValuePtr(_filename));
+ VALUE ivspace = rb_iv_get(self, "@space");
+ hnswlib::SpaceInterface<float>* space;
+ if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
+ space = RbHnswlibL2Space::get_hnsw_l2space(ivspace);
+ } else {
+ space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
+ }
+ get_hnsw_bruteforcesearch(self)->loadIndex(filename, space);
+ RB_GC_GUARD(_filename);
+ return Qnil;
+ };
+
+ static VALUE _hnsw_bruteforcesearch_remove_point(VALUE self, VALUE idx) {
+ get_hnsw_bruteforcesearch(self)->removePoint((size_t)NUM2INT(idx));
+ return Qnil;
+ };
+
+ static VALUE _hnsw_bruteforcesearch_max_elements(VALUE self) {
+ return INT2NUM((int)(get_hnsw_bruteforcesearch(self)->maxelements_));
+ };
+
+ static VALUE _hnsw_bruteforcesearch_current_count(VALUE self) {
+ return INT2NUM((int)(get_hnsw_bruteforcesearch(self)->cur_element_count));
+ };
+};
+
+const rb_data_type_t RbHnswlibBruteforceSearch::hnsw_bruteforcesearch_type = {
+ "RbHnswlibBruteforceSearch",
+ {
+ NULL,
+ RbHnswlibBruteforceSearch::hnsw_bruteforcesearch_free,
+ RbHnswlibBruteforceSearch::hnsw_bruteforcesearch_size
},
NULL,
NULL,
RUBY_TYPED_FREE_IMMEDIATELY
};