lib/disco/recommender.rb in disco-0.2.6 vs lib/disco/recommender.rb in disco-0.2.7
- old
+ new
@@ -20,10 +20,14 @@
# TODO option to set in initializer to avoid pass
# could also just check first few values
# but may be confusing if they are all missing and later ones aren't
@implicit = !train_set.any? { |v| v[:rating] }
+ if @implicit && train_set.any? { |v| v[:value] }
+ warn "[disco] WARNING: Passing `:value` with implicit feedback has no effect on recommendations and can be removed. Earlier versions of the library incorrectly stated this was used."
+ end
+
# TODO improve performance
# (catch exception instead of checking ahead of time)
unless @implicit
check_ratings(train_set)
@@ -32,19 +36,18 @@
end
end
@rated = Hash.new { |hash, key| hash[key] = {} }
input = []
- value_key = @implicit ? :value : :rating
train_set.each do |v|
# update maps and build matrix in single pass
u = (@user_map[v[:user_id]] ||= @user_map.size)
i = (@item_map[v[:item_id]] ||= @item_map.size)
@rated[u][i] = true
# explicit will always have a value due to check_ratings
- input << [u, i, v[value_key] || 1]
+ input << [u, i, @implicit ? 1 : v[:rating]]
end
@rated.default = nil
# much more efficient than checking every value in another pass
raise ArgumentError, "Missing user_id" if @user_map.key?(nil)
@@ -59,11 +62,11 @@
@item_count = [0] * @item_map.size
@item_sum = [0.0] * @item_map.size
train_set.each do |v|
i = @item_map[v[:item_id]]
@item_count[i] += 1
- @item_sum[i] += (v[value_key] || 1)
+ @item_sum[i] += (@implicit ? 1 : v[:rating])
end
end
eval_set = nil
if validation_set
@@ -74,11 +77,11 @@
# set to non-existent item
u ||= -1
i ||= -1
- eval_set << [u, i, v[value_key] || 1]
+ eval_set << [u, i, @implicit ? 1 : v[:rating]]
end
end
loss = @implicit ? 12 : 0
verbose = @verbose
@@ -136,12 +139,11 @@
ids = ids[indexes]
elsif @user_recs_index && count
predictions, ids = @user_recs_index.search(@user_factors[u, true].expand_dims(0), count + rated.size).map { |v| v[0, true] }
else
predictions = @item_factors.inner(@user_factors[u, true])
- # TODO make sure reverse isn't hurting performance
- indexes = predictions.sort_index.reverse
+ indexes = predictions.sort_index.reverse # reverse just creates view
indexes = indexes[0...[count + rated.size, indexes.size].min] if count
predictions = predictions[indexes]
ids = indexes
end
@@ -177,23 +179,36 @@
def top_items(count: 5)
check_fit
raise "top_items not computed" unless @top_items
if @implicit
- scores = @item_count
+ scores = Numo::UInt64.cast(@item_count)
else
require "wilson_score"
range = @min_rating..@max_rating
- scores = @item_sum.zip(@item_count).map { |s, c| WilsonScore.rating_lower_bound(s / c, c, range) }
+ scores = Numo::DFloat.cast(@item_sum.zip(@item_count).map { |s, c| WilsonScore.rating_lower_bound(s / c, c, range) })
+
+ # TODO uncomment in 0.3.0
+ # wilson score with continuity correction
+ # https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval#Wilson_score_interval_with_continuity_correction
+ # z = 1.96 # 95% confidence
+ # range = @max_rating - @min_rating
+ # n = Numo::DFloat.cast(@item_count)
+ # phat = (Numo::DFloat.cast(@item_sum) - (@min_rating * n)) / range / n
+ # phat = (phat - (1 / (2 * n))).clip(0, 100) # continuity correction
+ # scores = (phat + z**2 / (2 * n) - z * Numo::DFloat::Math.sqrt((phat * (1 - phat) + z**2 / (4 * n)) / n)) / (1 + z**2 / n)
+ # scores = scores * range + @min_rating
end
- scores = scores.map.with_index.sort_by { |s, _| -s }
- scores = scores.first(count) if count
- item_ids = item_ids()
- scores.map do |s, i|
- {item_id: item_ids[i], score: s}
+ indexes = scores.sort_index.reverse
+ indexes = indexes[0...[count, indexes.size].min] if count
+ scores = scores[indexes]
+
+ keys = @item_map.keys
+ indexes.size.times.map do |i|
+ {item_id: keys[indexes[i]], score: scores[i]}
end
end
def user_ids
@user_map.keys
@@ -253,11 +268,12 @@
require "faiss"
# inner product is cosine similarity with normalized vectors
# https://github.com/facebookresearch/faiss/issues/95
#
- # TODO use non-exact index
+ # TODO use non-exact index in 0.3.0
# https://github.com/facebookresearch/faiss/wiki/Faiss-indexes
+ # index = Faiss::IndexHNSWFlat.new(factors.shape[1], 32, :inner_product)
index = Faiss::IndexFlatIP.new(factors.shape[1])
# ids are from 0...total
# https://github.com/facebookresearch/faiss/blob/96b740abedffc8f67389f29c2a180913941534c6/faiss/Index.h#L89
index.add(factors)