lib/ngt/index.rb in ngt-0.1.1 vs lib/ngt/index.rb in ngt-0.2.0
- old
+ new
@@ -1,7 +1,9 @@
module Ngt
class Index
+ include Utils
+
def initialize(path)
@path = path
@error = FFI.ngt_create_error_object
@index = ffi(:ngt_open_index, path)
@@ -20,23 +22,30 @@
def insert(object)
ffi(:ngt_insert_index, @index, c_object(object.to_a), @dimension)
end
- # TODO make more performant for Numo
def batch_insert(objects, num_threads: 8)
- objects = objects.to_a
- flat_objects = objects.flatten
- obj = ::FFI::MemoryPointer.new(:float, flat_objects.size)
- obj.write_array_of_float(flat_objects)
+ if narray?(objects)
+ objects = objects.cast_to(Numo::SFloat) unless objects.is_a?(Numo::SFloat)
+ count = objects.shape[0]
+ obj = ::FFI::MemoryPointer.new(:char, objects.byte_size)
+ obj.write_bytes(objects.to_binary)
+ else
+ objects = objects.to_a
+ count = objects.size
+ flat_objects = objects.flatten
+ obj = ::FFI::MemoryPointer.new(:float, flat_objects.size)
+ obj.write_array_of_float(flat_objects)
+ end
- ids = ::FFI::MemoryPointer.new(:uint32, objects.size)
- ffi(:ngt_batch_insert_index, @index, obj, objects.size, ids)
+ ids = ::FFI::MemoryPointer.new(:uint32, count)
+ ffi(:ngt_batch_insert_index, @index, obj, count, ids)
build_index(num_threads: num_threads)
- ids.read_array_of_uint32(objects.size)
+ ids.read_array_of_uint32(count)
end
def build_index(num_threads: 8)
ffi(:ngt_create_index, @index, num_threads)
end
@@ -125,15 +134,12 @@
FFI.ngt_destroy_property(property) if property
FFI.ngt_close_index(index) if index
end
# private
- def self.ffi(method, *args)
- res = FFI.send(method, *args)
- message = FFI.ngt_get_error_string(args.last)
- raise Error, message unless message.empty?
- res
+ def self.ffi(*args)
+ Utils.ffi(*args)
end
def self.finalize(error)
# must use proc instead of stabby lambda
proc do
@@ -142,11 +148,11 @@
end
end
private
- def ffi(*args)
- self.class.ffi(*args, @error)
+ def narray?(data)
+ defined?(Numo::NArray) && data.is_a?(Numo::NArray)
end
def float?
@float
end