lib/ngt/index.rb in ngt-0.1.1 vs lib/ngt/index.rb in ngt-0.2.0

- old
+ new

@@ -1,7 +1,9 @@ module Ngt class Index + include Utils + def initialize(path) @path = path @error = FFI.ngt_create_error_object @index = ffi(:ngt_open_index, path) @@ -20,23 +22,30 @@ def insert(object) ffi(:ngt_insert_index, @index, c_object(object.to_a), @dimension) end - # TODO make more performant for Numo def batch_insert(objects, num_threads: 8) - objects = objects.to_a - flat_objects = objects.flatten - obj = ::FFI::MemoryPointer.new(:float, flat_objects.size) - obj.write_array_of_float(flat_objects) + if narray?(objects) + objects = objects.cast_to(Numo::SFloat) unless objects.is_a?(Numo::SFloat) + count = objects.shape[0] + obj = ::FFI::MemoryPointer.new(:char, objects.byte_size) + obj.write_bytes(objects.to_binary) + else + objects = objects.to_a + count = objects.size + flat_objects = objects.flatten + obj = ::FFI::MemoryPointer.new(:float, flat_objects.size) + obj.write_array_of_float(flat_objects) + end - ids = ::FFI::MemoryPointer.new(:uint32, objects.size) - ffi(:ngt_batch_insert_index, @index, obj, objects.size, ids) + ids = ::FFI::MemoryPointer.new(:uint32, count) + ffi(:ngt_batch_insert_index, @index, obj, count, ids) build_index(num_threads: num_threads) - ids.read_array_of_uint32(objects.size) + ids.read_array_of_uint32(count) end def build_index(num_threads: 8) ffi(:ngt_create_index, @index, num_threads) end @@ -125,15 +134,12 @@ FFI.ngt_destroy_property(property) if property FFI.ngt_close_index(index) if index end # private - def self.ffi(method, *args) - res = FFI.send(method, *args) - message = FFI.ngt_get_error_string(args.last) - raise Error, message unless message.empty? - res + def self.ffi(*args) + Utils.ffi(*args) end def self.finalize(error) # must use proc instead of stabby lambda proc do @@ -142,11 +148,11 @@ end end private - def ffi(*args) - self.class.ffi(*args, @error) + def narray?(data) + defined?(Numo::NArray) && data.is_a?(Numo::NArray) end def float? @float end