lib/ngt/index.rb in ngt-0.4.0 vs lib/ngt/index.rb in ngt-0.4.1
- old
+ new
@@ -32,26 +32,38 @@
end
def object_type
@object_type ||= begin
object_type = ffi(:ngt_get_property_object_type, @property)
- FFI.ngt_is_property_object_type_float(object_type) ? :float : :integer
+ if FFI.ngt_is_property_object_type_float(object_type)
+ :float
+ elsif FFI.ngt_is_property_object_type_float16(object_type)
+ :float16
+ else
+ :integer
+ end
end
end
def insert(object)
- ffi(:ngt_insert_index, @index, c_object(object.to_a), dimensions)
+ object = object.to_a
+ ffi(:ngt_insert_index, @index, c_object(object), object.size)
end
def batch_insert(objects, num_threads: 8)
if narray?(objects)
+ check_dimensions(objects.shape[1])
+
objects = objects.cast_to(Numo::SFloat) unless objects.is_a?(Numo::SFloat)
count = objects.shape[0]
obj = ::FFI::MemoryPointer.new(:char, objects.byte_size)
obj.write_bytes(objects.to_binary)
else
objects = objects.to_a
+ objects.each do |object|
+ check_dimensions(object.size)
+ end
count = objects.size
flat_objects = objects.flatten
obj = ::FFI::MemoryPointer.new(:float, flat_objects.size)
obj.write_array_of_float(flat_objects)
end
@@ -68,26 +80,29 @@
ffi(:ngt_create_index, @index, num_threads)
end
def object(id)
if object_type == :float
- res = ffi(:ngt_get_object_as_float, @object_space, id)
+ res = ffi(:ngt_get_object_as_float, object_space, id)
res.read_array_of_float(dimensions)
- else
- res = ffi(:ngt_get_object_as_integer, @object_space, id)
+ elsif object_type == :integer
+ res = ffi(:ngt_get_object_as_integer, object_space, id)
res.read_array_of_uint8(dimensions)
+ else
+ raise Error, "Method not supported for this object type"
end
end
def remove(id)
ffi(:ngt_remove_index, @index, id)
end
def search(query, size: 20, epsilon: 0.1, radius: nil)
radius ||= -1.0
results = ffi(:ngt_create_empty_results)
- ffi(:ngt_search_index, @index, c_object(query.to_a), dimensions, size, epsilon, radius, results)
+ query = query.to_a
+ ffi(:ngt_search_index, @index, c_object(query), query.size, size, epsilon, radius, results)
result_size = ffi(:ngt_get_result_size, results)
ret = []
result_size.times do |i|
res = ffi(:ngt_get_result, results, i)
ret << {
@@ -123,10 +138,12 @@
ffi(:ngt_set_property_edge_size_for_search, property, edge_size_for_search, error)
case object_type.to_s.downcase
when "float"
ffi(:ngt_set_property_object_type_float, property, error)
+ when "float16"
+ ffi(:ngt_set_property_object_type_float16, property, error)
when "integer"
ffi(:ngt_set_property_object_type_integer, property, error)
else
raise ArgumentError, "Unknown object type: #{object_type}"
end
@@ -194,11 +211,20 @@
def narray?(data)
defined?(Numo::NArray) && data.is_a?(Numo::NArray)
end
def c_object(object)
+ check_dimensions(object.size)
c_object = ::FFI::MemoryPointer.new(:double, object.size)
c_object.write_array_of_double(object)
c_object
+ end
+
+ def check_dimensions(d)
+ raise ArgumentError, "Bad dimensions" if d != dimensions
+ end
+
+ def object_space
+ @object_space ||= ffi(:ngt_get_object_space, @index)
end
end
end