lib/disco/recommender.rb in disco-0.2.6 vs lib/disco/recommender.rb in disco-0.2.7

- old
+ new

@@ -20,10 +20,14 @@ # TODO option to set in initializer to avoid pass # could also just check first few values # but may be confusing if they are all missing and later ones aren't @implicit = !train_set.any? { |v| v[:rating] } + if @implicit && train_set.any? { |v| v[:value] } + warn "[disco] WARNING: Passing `:value` with implicit feedback has no effect on recommendations and can be removed. Earlier versions of the library incorrectly stated this was used." + end + # TODO improve performance # (catch exception instead of checking ahead of time) unless @implicit check_ratings(train_set) @@ -32,19 +36,18 @@ end end @rated = Hash.new { |hash, key| hash[key] = {} } input = [] - value_key = @implicit ? :value : :rating train_set.each do |v| # update maps and build matrix in single pass u = (@user_map[v[:user_id]] ||= @user_map.size) i = (@item_map[v[:item_id]] ||= @item_map.size) @rated[u][i] = true # explicit will always have a value due to check_ratings - input << [u, i, v[value_key] || 1] + input << [u, i, @implicit ? 1 : v[:rating]] end @rated.default = nil # much more efficient than checking every value in another pass raise ArgumentError, "Missing user_id" if @user_map.key?(nil) @@ -59,11 +62,11 @@ @item_count = [0] * @item_map.size @item_sum = [0.0] * @item_map.size train_set.each do |v| i = @item_map[v[:item_id]] @item_count[i] += 1 - @item_sum[i] += (v[value_key] || 1) + @item_sum[i] += (@implicit ? 1 : v[:rating]) end end eval_set = nil if validation_set @@ -74,11 +77,11 @@ # set to non-existent item u ||= -1 i ||= -1 - eval_set << [u, i, v[value_key] || 1] + eval_set << [u, i, @implicit ? 1 : v[:rating]] end end loss = @implicit ? 12 : 0 verbose = @verbose @@ -136,12 +139,11 @@ ids = ids[indexes] elsif @user_recs_index && count predictions, ids = @user_recs_index.search(@user_factors[u, true].expand_dims(0), count + rated.size).map { |v| v[0, true] } else predictions = @item_factors.inner(@user_factors[u, true]) - # TODO make sure reverse isn't hurting performance - indexes = predictions.sort_index.reverse + indexes = predictions.sort_index.reverse # reverse just creates view indexes = indexes[0...[count + rated.size, indexes.size].min] if count predictions = predictions[indexes] ids = indexes end @@ -177,23 +179,36 @@ def top_items(count: 5) check_fit raise "top_items not computed" unless @top_items if @implicit - scores = @item_count + scores = Numo::UInt64.cast(@item_count) else require "wilson_score" range = @min_rating..@max_rating - scores = @item_sum.zip(@item_count).map { |s, c| WilsonScore.rating_lower_bound(s / c, c, range) } + scores = Numo::DFloat.cast(@item_sum.zip(@item_count).map { |s, c| WilsonScore.rating_lower_bound(s / c, c, range) }) + + # TODO uncomment in 0.3.0 + # wilson score with continuity correction + # https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval#Wilson_score_interval_with_continuity_correction + # z = 1.96 # 95% confidence + # range = @max_rating - @min_rating + # n = Numo::DFloat.cast(@item_count) + # phat = (Numo::DFloat.cast(@item_sum) - (@min_rating * n)) / range / n + # phat = (phat - (1 / 2 * n)).clip(0, 100) # continuity correction + # scores = (phat + z**2 / (2 * n) - z * Numo::DFloat::Math.sqrt((phat * (1 - phat) + z**2 / (4 * n)) / n)) / (1 + z**2 / n) + # scores = scores * range + @min_rating end - scores = scores.map.with_index.sort_by { |s, _| -s } - scores = scores.first(count) if count - item_ids = item_ids() - scores.map do |s, i| - {item_id: item_ids[i], score: s} + indexes = scores.sort_index.reverse + indexes = indexes[0...[count, indexes.size].min] if count + scores = scores[indexes] + + keys = @item_map.keys + indexes.size.times.map do |i| + {item_id: keys[indexes[i]], score: scores[i]} end end def user_ids @user_map.keys @@ -253,11 +268,12 @@ require "faiss" # inner product is cosine similarity with normalized vectors # https://github.com/facebookresearch/faiss/issues/95 # - # TODO use non-exact index + # TODO use non-exact index in 0.3.0 # https://github.com/facebookresearch/faiss/wiki/Faiss-indexes + # index = Faiss::IndexHNSWFlat.new(factors.shape[1], 32, :inner_product) index = Faiss::IndexFlatIP.new(factors.shape[1]) # ids are from 0...total # https://github.com/facebookresearch/faiss/blob/96b740abedffc8f67389f29c2a180913941534c6/faiss/Index.h#L89 index.add(factors)