lib/ngt/index.rb in ngt-0.4.0 vs lib/ngt/index.rb in ngt-0.4.1

- old
+ new

@@ -32,26 +32,38 @@ end def object_type @object_type ||= begin object_type = ffi(:ngt_get_property_object_type, @property) - FFI.ngt_is_property_object_type_float(object_type) ? :float : :integer + if FFI.ngt_is_property_object_type_float(object_type) + :float + elsif FFI.ngt_is_property_object_type_float16(object_type) + :float16 + else + :integer + end end end def insert(object) - ffi(:ngt_insert_index, @index, c_object(object.to_a), dimensions) + object = object.to_a + ffi(:ngt_insert_index, @index, c_object(object), object.size) end def batch_insert(objects, num_threads: 8) if narray?(objects) + check_dimensions(objects.shape[1]) + objects = objects.cast_to(Numo::SFloat) unless objects.is_a?(Numo::SFloat) count = objects.shape[0] obj = ::FFI::MemoryPointer.new(:char, objects.byte_size) obj.write_bytes(objects.to_binary) else objects = objects.to_a + objects.each do |object| + check_dimensions(object.size) + end count = objects.size flat_objects = objects.flatten obj = ::FFI::MemoryPointer.new(:float, flat_objects.size) obj.write_array_of_float(flat_objects) end @@ -68,26 +80,29 @@ ffi(:ngt_create_index, @index, num_threads) end def object(id) if object_type == :float - res = ffi(:ngt_get_object_as_float, @object_space, id) + res = ffi(:ngt_get_object_as_float, object_space, id) res.read_array_of_float(dimensions) - else - res = ffi(:ngt_get_object_as_integer, @object_space, id) + elsif object_type == :integer + res = ffi(:ngt_get_object_as_integer, object_space, id) res.read_array_of_uint8(dimensions) + else + raise Error, "Method not supported for this object type" end end def remove(id) ffi(:ngt_remove_index, @index, id) end def search(query, size: 20, epsilon: 0.1, radius: nil) radius ||= -1.0 results = ffi(:ngt_create_empty_results) - ffi(:ngt_search_index, @index, c_object(query.to_a), dimensions, size, epsilon, radius, results) + query = query.to_a + ffi(:ngt_search_index, @index, c_object(query), query.size, size, epsilon, radius, results) result_size = ffi(:ngt_get_result_size, results) ret = [] result_size.times do |i| res = ffi(:ngt_get_result, results, i) ret << { @@ -123,10 +138,12 @@ ffi(:ngt_set_property_edge_size_for_search, property, edge_size_for_search, error) case object_type.to_s.downcase when "float" ffi(:ngt_set_property_object_type_float, property, error) + when "float16" + ffi(:ngt_set_property_object_type_float16, property, error) when "integer" ffi(:ngt_set_property_object_type_integer, property, error) else raise ArgumentError, "Unknown object type: #{object_type}" end @@ -194,11 +211,20 @@ def narray?(data) defined?(Numo::NArray) && data.is_a?(Numo::NArray) end def c_object(object) + check_dimensions(object.size) c_object = ::FFI::MemoryPointer.new(:double, object.size) c_object.write_array_of_double(object) c_object + end + + def check_dimensions(d) + raise ArgumentError, "Bad dimensions" if d != dimensions + end + + def object_space + @object_space ||= ffi(:ngt_get_object_space, @index) end end end