lib/ngt/index.rb in ngt-0.2.3 vs lib/ngt/index.rb in ngt-0.2.4

- old
+ new

@@ -1,29 +1,37 @@ module Ngt class Index include Utils + DISTANCE_TYPES = [:l1, :l2, :hamming, :angle, :cosine, :normalized_angle, :normalized_cosine, :jaccard] + + attr_reader :dimensions, :distance_type, :edge_size_for_creation, :edge_size_for_search, :object_type, :path + def initialize(path) @path = path @error = FFI.ngt_create_error_object @index = ffi(:ngt_open_index, path) property = ffi(:ngt_create_property) ffi(:ngt_get_property, @index, property) - @dimension = ffi(:ngt_get_property_dimension, property) + @dimensions = ffi(:ngt_get_property_dimension, property) + @distance_type = DISTANCE_TYPES[ffi(:ngt_get_property_distance_type, property)] + @edge_size_for_creation = ffi(:ngt_get_property_edge_size_for_creation, property) + @edge_size_for_search = ffi(:ngt_get_property_edge_size_for_search, property) object_type = ffi(:ngt_get_property_object_type, property) @float = FFI.ngt_is_property_object_type_float(object_type) + @object_type = @float ? :float : :integer @object_space = ffi(:ngt_get_object_space, @index) ObjectSpace.define_finalizer(self, self.class.finalize(@error)) end def insert(object) - ffi(:ngt_insert_index, @index, c_object(object.to_a), @dimension) + ffi(:ngt_insert_index, @index, c_object(object.to_a), @dimensions) end def batch_insert(objects, num_threads: 8) if narray?(objects) objects = objects.cast_to(Numo::SFloat) unless objects.is_a?(Numo::SFloat) @@ -51,25 +59,25 @@ end def object(id) if float? res = ffi(:ngt_get_object_as_float, @object_space, id) - res.read_array_of_float(@dimension) + res.read_array_of_float(@dimensions) else res = ffi(:ngt_get_object_as_integer, @object_space, id) - res.read_array_of_uint8(@dimension) + res.read_array_of_uint8(@dimensions) end end def remove(id) ffi(:ngt_remove_index, @index, id) end def search(query, size: 20, epsilon: 0.1, radius: nil) radius ||= -1.0 results = ffi(:ngt_create_empty_results) - ffi(:ngt_search_index, @index, c_object(query.to_a), @dimension, size, epsilon, radius, results) + ffi(:ngt_search_index, @index, c_object(query.to_a), @dimensions, size, epsilon, radius, results) result_size = ffi(:ngt_get_result_size, results) ret = [] result_size.times do |i| res = ffi(:ngt_get_result, results, i) ret << { @@ -88,51 +96,51 @@ def close FFI.ngt_close_index(@index) end - def self.new(dimension, path: nil, edge_size_for_creation: 10, - edge_size_for_search: 40, object_type: "Float", distance_type: "L2") + def self.new(dimensions, path: nil, edge_size_for_creation: 10, + edge_size_for_search: 40, object_type: :float, distance_type: :l2) # called from load - return super(path) if path && dimension.nil? + return super(path) if path && dimensions.nil? # TODO remove in 0.3.0 - create = dimension.is_a?(Integer) || path + create = dimensions.is_a?(Integer) || path unless create warn "[ngt] Passing a path to new is deprecated - use load instead" - return super(dimension) + return super(dimensions) end path ||= Dir.mktmpdir error = FFI.ngt_create_error_object property = ffi(:ngt_create_property, error) - ffi(:ngt_set_property_dimension, property, dimension, error) + ffi(:ngt_set_property_dimension, property, dimensions, error) ffi(:ngt_set_property_edge_size_for_creation, property, edge_size_for_creation, error) ffi(:ngt_set_property_edge_size_for_search, property, edge_size_for_search, error) - case object_type.to_s - when "Float", "float" + case object_type.to_s.downcase + when "float" ffi(:ngt_set_property_object_type_float, property, error) - when "Integer", "integer" + when "integer" ffi(:ngt_set_property_object_type_integer, property, error) else raise ArgumentError, "Unknown object type: #{object_type}" end - case distance_type.to_s - when "L1" + case distance_type.to_s.downcase + when "l1" ffi(:ngt_set_property_distance_type_l1, property, error) - when "L2" + when "l2" ffi(:ngt_set_property_distance_type_l2, property, error) - when "Angle" + when "angle" ffi(:ngt_set_property_distance_type_angle, property, error) - when "Hamming" + when "hamming" ffi(:ngt_set_property_distance_type_hamming, property, error) - when "Jaccard" + when "jaccard" ffi(:ngt_set_property_distance_type_jaccard, property, error) - when "Cosine" + when "cosine" ffi(:ngt_set_property_distance_type_cosine, property, error) else raise ArgumentError, "Unknown distance type: #{distance_type}" end @@ -149,12 +157,12 @@ def self.load(path) new(nil, path: path) end - def self.create(path, dimension, **options) + def self.create(path, dimensions, **options) warn "[ngt] create is deprecated - use new instead" - new(dimension, path: path, **options) + new(dimensions, path: path, **options) end # private def self.ffi(*args) Utils.ffi(*args)