lib/hanny/lsh_index.rb in hanny-0.1.0 vs lib/hanny/lsh_index.rb in hanny-0.2.0
- old
+ new
@@ -63,17 +63,12 @@
# Create a new nearest neighbor index.
# @param code_length [Integer] The length of binary code for hash key.
# @param random_seed [Integer/NilClass] The seed value used to initialize the random generator.
def initialize(code_length: 256, random_seed: nil)
# Keep the requested code length; @last_id and @weight_mat start as nil and are assigned by build_index.
@code_length = code_length
- @n_samples = nil
- @n_features = nil
- @n_keys = nil
@last_id = nil
@weight_mat = nil
- @hash_table = nil
- @hash_codes = nil
@random_seed = random_seed
# No seed given: fall back to the value returned by Kernel#srand.
# Note that calling srand without arguments also reseeds the global generator.
@random_seed ||= srand
# Private generator seeded with @random_seed; build_index hands it to Utils.rand_normal.
@rng = Random.new(@random_seed)
end
@@ -84,30 +79,31 @@
x.dot(@weight_mat).ge(0.0)
end
# Build a search index.
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The dataset for building search index.
- # @return [SVC] The search index itself that has constructed the hash table.
+ # @return [LSHIndex] The search index itself that has constructed the hash table.
def build_index(x)
# Record the dataset dimensions and start from an empty hash table.
- @n_samples, @n_features = x.shape
+ @n_samples = x.shape[0]
+ @n_features = x.shape[1]
@hash_table = {}
- @hash_codes = []
# Weight matrix requested with shape [n_features, code_length], generated from @rng.
@weight_mat = Utils.rand_normal([@n_features, @code_length], @rng)
# Convert samples to binary codes.
bin_x = hash_function(x)
# Group sample indices into buckets keyed by their binary code.
+ codes = []
@n_samples.times do |m|
bin_code = bin_x[m, true]
hash_key = symbolized_hash_key(bin_code)
# First occurrence of this code: remember the code once and open an empty bucket for it.
unless @hash_table.key?(hash_key)
- @hash_codes.push(bin_code.to_a)
+ codes.push(bin_code.to_a)
@hash_table[hash_key] = []
end
# Every sample index is stored in the bucket of its code.
@hash_table[hash_key].push(m)
end
# Pack the collected distinct codes into a single Numo::Bit array (one row per code).
- @hash_codes = Numo::Bit.cast(@hash_codes)
+ @hash_codes = Numo::Bit.cast(codes)
# Record the number of distinct codes and the sample count.
# NOTE(review): @last_id presumably is the id given to the next appended sample — confirm in append_data.
@n_keys = @hash_codes.shape[0]
@last_id = @n_samples
self
end
@@ -115,11 +111,11 @@
# Append new data to the search index.
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The dataset to append to search index.
# @return [Array<Integer>] The indices of appended data in search index
def append_data(x)
# Initialize some variables.
- n_new_samples, = x.shape
+ n_new_samples = x.shape[0]
bin_x = hash_function(x)
added_data_ids = []
# Store samples to binary hash table.
new_codes = []
n_new_samples.times do |m|
@@ -150,16 +146,18 @@
# @return [Array<Integer>] The indices of removed data in search index
def remove_data(data_ids)
removed_data_ids = []
data_ids.each do |query_id|
# Remove data id from hash table.
- hash_key = @hash_table.keys.select { |k| @hash_table[k].include?(query_id) }.first
+ hash_key = @hash_table.keys.find { |k| @hash_table[k].include?(query_id) }
next if hash_key.nil?
+
@hash_table[hash_key].delete(query_id)
removed_data_ids.push(query_id)
# Remove the hash key if there is no data.
next unless @hash_table[hash_key].empty?
+
target_id = distances_to_hash_codes(decoded_hash_key(hash_key)).index(0)
@hash_codes = @hash_codes.delete(target_id, 0)
end
@n_samples -= removed_data_ids.size
removed_data_ids
@@ -169,78 +167,50 @@
# @param q [Numo::DFloat] (shape: [n_queries, n_features]) The data for search queries.
# @param n_neighbors [Integer] The number of neighbors.
# @return [Array<Array<Integer>>] The data indices of search result for each query.
def search_knn(q, n_neighbors: 10)
# One (initially empty) candidate list per query row.
- n_queries, = q.shape
+ n_queries = q.shape[0]
candidates = Array.new(n_queries) { [] }
# Binarize queries.
bin_q = hash_function(q)
# Find k-nearest neighbors for each query.
n_queries.times do |m|
# Visit the stored hash codes in the order produced by sort_with_index
# (presumably ascending distance to the query code — confirm in sort_with_index).
- sort_with_index(distances_to_hash_codes(bin_q[m, true])).each do |_, n|
+ sort_with_index(distances_to_hash_codes(bin_q[m, true])).each do |d, n|
# Merge the bucket of the n-th code; Array#| adds only ids that are not already present.
candidates[m] = candidates[m] | @hash_table[symbolized_hash_key(@hash_codes[n, true])]
# Stop once enough candidates are collected. The value given to break is not used:
# the result of each is discarded here.
- break if candidates[m].size >= n_neighbors
+ # TODO: Investigate the cause of the steep Ruby::BreakTypeMismatch error.
+ # break if candidates[m].size >= n_neighbors
+ break [[d, n]] if candidates[m].size >= n_neighbors
end
# The last merged bucket can push the list past n_neighbors, so keep only the first n_neighbors ids.
candidates[m] = candidates[m].shift(n_neighbors)
end
candidates
end
# Perform hamming radius nearest neighbor search.
# @param q [Numo::DFloat] (shape: [n_queries, n_features]) The data for search queries.
# @param radius [Float] The hamming radius for search range.
# @return [Array<Array<Integer>>] The data indices of search result for each query.
- def search_radius(q, radius: 1)
+ def search_radius(q, radius: 1.0)
# One (initially empty) candidate list per query row.
- n_queries, = q.shape
+ n_queries = q.shape[0]
candidates = Array.new(n_queries) { [] }
# Binarize queries.
bin_q = hash_function(q)
# Collect, for each query, the buckets of all hash codes whose distance does not exceed radius.
n_queries.times do |m|
# NOTE(review): stopping at the first d > radius assumes sort_with_index yields
# [distance, index] pairs in ascending distance order — confirm in sort_with_index.
sort_with_index(distances_to_hash_codes(bin_q[m, true])).each do |d, n|
# The value given to break is not used: the result of each is discarded here.
- break if d > radius
+ # TODO: Investigate the cause of the steep Ruby::BreakTypeMismatch error.
+ # break if d > radius
+ break [[d, n]] if d > radius
+
# Merge the bucket of the n-th code; Array#| adds only ids that are not already present.
candidates[m] = candidates[m] | @hash_table[symbolized_hash_key(@hash_codes[n, true])]
end
end
candidates
end
- # Dump marshal data.
- # @return [Hash] The marshal data for search index.
- def marshal_dump
# Present only in the old version (every line of this method carries the "-" marker).
# Gathers the instance variables of the index into a Hash keyed by symbol.
# NOTE(review): @bias_vec is not assigned anywhere in the visible code — confirm whether it is ever set.
- { code_length: @code_length,
- n_samples: @n_samples,
- n_features: @n_features,
- n_keys: @n_keys,
- last_id: @last_id,
- weight_mat: @weight_mat,
- bias_vec: @bias_vec,
- hash_table: @hash_table,
- hash_codes: @hash_codes,
- random_seed: @random_seed,
- rng: @rng }
- end
-
- # Load marshal data.
- # @return [nil]
- def marshal_load(obj)
# Present only in the old version (every line of this method carries the "-" marker).
# Restores each instance variable from obj, reading the same symbol keys that marshal_dump writes.
- @code_length = obj[:code_length]
- @n_samples = obj[:n_samples]
- @n_features = obj[:n_features]
- @n_keys = obj[:n_keys]
- @last_id = obj[:last_id]
- @weight_mat = obj[:weight_mat]
- @bias_vec = obj[:bias_vec]
- @hash_table = obj[:hash_table]
- @hash_codes = obj[:hash_codes]
- @random_seed = obj[:random_seed]
- @rng = obj[:rng]
# Explicit nil return, matching the documented @return [nil].
- nil
- end
-
private
# Convert binary code to symbol as hash key.
# @param bin_code [Numo::Bit]
# @return [Symbol]
@@ -264,10 +234,10 @@
# Convert hash key symbol to binary code.
# @param hash_key [Symbol]
# @return [Numo::Bit]
def decoded_hash_key(hash_key)
# Inflate the zlib-compressed key text, then convert each character to an integer with String#to_i.
# NOTE(review): presumably the inflated text is a run of '0'/'1' characters written by
# symbolized_hash_key — confirm there.
- bin_code = Zlib::Inflate.inflate(hash_key.to_s).split('').map(&:to_i)
+ bin_code = Zlib::Inflate.inflate(hash_key.to_s).chars.map(&:to_i)
# Build a Numo::Bit vector from the integers.
Numo::Bit[*bin_code]
end
end
end