lib/hanny/lsh_index.rb in hanny-0.1.0 vs lib/hanny/lsh_index.rb in hanny-0.2.0

- old
+ new

@@ -63,17 +63,12 @@ # Create a new nearest neighbor index. # @param code_length [Integer] The length of binary code for hash key. # @param random_seed [Integer/NilClass] The seed value using to initialize the random generator. def initialize(code_length: 256, random_seed: nil) @code_length = code_length - @n_samples = nil - @n_features = nil - @n_keys = nil @last_id = nil @weight_mat = nil - @hash_table = nil - @hash_codes = nil @random_seed = random_seed @random_seed ||= srand @rng = Random.new(@random_seed) end @@ -84,30 +79,31 @@ x.dot(@weight_mat).ge(0.0) end # Build a search index. # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The dataset for building search index. - # @return [SVC] The search index itself that has constructed the hash table. + # @return [LSHIndex] The search index itself that has constructed the hash table. def build_index(x) # Initialize some variables. - @n_samples, @n_features = x.shape + @n_samples = x.shape[0] + @n_features = x.shape[1] @hash_table = {} - @hash_codes = [] @weight_mat = Utils.rand_normal([@n_features, @code_length], @rng) # Convert samples to binary codes. bin_x = hash_function(x) # Store samples to binary hash table. + codes = [] @n_samples.times do |m| bin_code = bin_x[m, true] hash_key = symbolized_hash_key(bin_code) unless @hash_table.key?(hash_key) - @hash_codes.push(bin_code.to_a) + codes.push(bin_code.to_a) @hash_table[hash_key] = [] end @hash_table[hash_key].push(m) end - @hash_codes = Numo::Bit.cast(@hash_codes) + @hash_codes = Numo::Bit.cast(codes) # Update some variables. @n_keys = @hash_codes.shape[0] @last_id = @n_samples self end @@ -115,11 +111,11 @@ # Append new data to the search index. # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The dataset to append to search index. # @return [Array<Integer>] The indices of appended data in search index def append_data(x) # Initialize some variables. - n_new_samples, = x.shape + n_new_samples = x.shape[0] bin_x = hash_function(x) added_data_ids = [] # Store samples to binary hash table. new_codes = [] n_new_samples.times do |m| @@ -150,16 +146,18 @@ # @return [Array<Integer>] The indices of removed data in search index def remove_data(data_ids) removed_data_ids = [] data_ids.each do |query_id| # Remove data id from hash table. - hash_key = @hash_table.keys.select { |k| @hash_table[k].include?(query_id) }.first + hash_key = @hash_table.keys.find { |k| @hash_table[k].include?(query_id) } next if hash_key.nil? + @hash_table[hash_key].delete(query_id) removed_data_ids.push(query_id) # Remove the hash key if there is no data. next unless @hash_table[hash_key].empty? + target_id = distances_to_hash_codes(decoded_hash_key(hash_key)).index(0) @hash_codes = @hash_codes.delete(target_id, 0) end @n_samples -= removed_data_ids.size removed_data_ids @@ -169,78 +167,50 @@ # @param q [Numo::DFloat] (shape: [n_queries, n_features]) The data for search queries. # @param n_neighbors [Integer] The number of neighbors. # @return [Array<Integer>] The data indices of search result. def search_knn(q, n_neighbors: 10) # Initialize some variables. - n_queries, = q.shape + n_queries = q.shape[0] candidates = Array.new(n_queries) { [] } # Binarize queries. bin_q = hash_function(q) # Find k-nearest neighbors for each query. n_queries.times do |m| - sort_with_index(distances_to_hash_codes(bin_q[m, true])).each do |_, n| + sort_with_index(distances_to_hash_codes(bin_q[m, true])).each do |d, n| candidates[m] = candidates[m] | @hash_table[symbolized_hash_key(@hash_codes[n, true])] - break if candidates[m].size >= n_neighbors + # TODO: Investigate the cause of the steep Ruby::BreakTypeMismatch error. + # break if candidates[m].size >= n_neighbors + break [[d, n]] if candidates[m].size >= n_neighbors end candidates[m] = candidates[m].shift(n_neighbors) end candidates end # Perform hamming radius nearest neighbor search. # @param q [Numo::DFloat] (shape: [n_queries, n_features]) The data for search queries. # @param radius [Float] The hamming radius for search range. # @return [Array<Integer>] The data indices of search result. - def search_radius(q, radius: 1) + def search_radius(q, radius: 1.0) # Initialize some variables. - n_queries, = q.shape + n_queries = q.shape[0] candidates = Array.new(n_queries) { [] } # Binarize queries. bin_q = hash_function(q) # Find k-nearest neighbors for each query. n_queries.times do |m| sort_with_index(distances_to_hash_codes(bin_q[m, true])).each do |d, n| - break if d > radius + # TODO: Investigate the cause of the steep Ruby::BreakTypeMismatch error. + # break if d > radius + break [[d, n]] if d > radius + candidates[m] = candidates[m] | @hash_table[symbolized_hash_key(@hash_codes[n, true])] end end candidates end - # Dump marshal data. - # @return [Hash] The marshal data for search index. - def marshal_dump - { code_length: @code_length, - n_samples: @n_samples, - n_features: @n_features, - n_keys: @n_keys, - last_id: @last_id, - weight_mat: @weight_mat, - bias_vec: @bias_vec, - hash_table: @hash_table, - hash_codes: @hash_codes, - random_seed: @random_seed, - rng: @rng } - end - - # Load marshal data. - # @return [nil] - def marshal_load(obj) - @code_length = obj[:code_length] - @n_samples = obj[:n_samples] - @n_features = obj[:n_features] - @n_keys = obj[:n_keys] - @last_id = obj[:last_id] - @weight_mat = obj[:weight_mat] - @bias_vec = obj[:bias_vec] - @hash_table = obj[:hash_table] - @hash_codes = obj[:hash_codes] - @random_seed = obj[:random_seed] - @rng = obj[:rng] - nil - end - private # Convert binary code to symbol as hash key. # @param bin_code [Numo::Bit] # @return [Symbol] @@ -264,10 +234,10 @@ # Convert hash key symbol to binary code. # @param hash_key [Symbol] # @return [Numo::Bit] def decoded_hash_key(hash_key) - bin_code = Zlib::Inflate.inflate(hash_key.to_s).split('').map(&:to_i) + bin_code = Zlib::Inflate.inflate(hash_key.to_s).chars.map(&:to_i) Numo::Bit[*bin_code] end end end