lib/disco/recommender.rb in disco-0.3.0 vs lib/disco/recommender.rb in disco-0.3.1

- old
+ new

@@ -253,10 +253,50 @@ def inspect to_s # for now end + def to_json + require "base64" + require "json" + + obj = { + implicit: @implicit, + user_ids: @user_map.keys, + item_ids: @item_map.keys, + rated: @user_map.map { |_, u| (@rated[u] || {}).keys }, + global_mean: @global_mean, + user_factors: Base64.strict_encode64(@user_factors.to_binary), + item_factors: Base64.strict_encode64(@item_factors.to_binary), + factors: @factors, + epochs: @epochs, + verbose: @verbose + } + + unless @implicit + obj[:min_rating] = @min_rating + obj[:max_rating] = @max_rating + end + + if @top_items + obj[:item_count] = @item_count + obj[:item_sum] = @item_sum + end + + JSON.generate(obj) + end + + def self.load_json(json) + require "json" + + obj = JSON.parse(json) + + recommender = new + recommender.send(:json_load, obj) + recommender + end + private # factors should already be normalized for similar users/items def create_index(factors, library:) library ||= defined?(Ngt) && !defined?(Faiss) ? "ngt" : "faiss" @@ -430,9 +470,35 @@ @top_items = obj.key?(:item_count) if @top_items @item_count = obj[:item_count] @item_sum = obj[:item_sum] + end + end + + def json_load(obj) + require "base64" + + @implicit = obj["implicit"] + @user_map = obj["user_ids"].map.with_index.to_h + @item_map = obj["item_ids"].map.with_index.to_h + @rated = obj["rated"].map.with_index.to_h { |r, i| [i, r.to_h { |v| [v, true] }] } + @global_mean = obj["global_mean"].to_f + @factors = obj["factors"].to_i + @user_factors = Numo::SFloat.from_binary(Base64.strict_decode64(obj["user_factors"]), [@user_map.size, @factors]) + @item_factors = Numo::SFloat.from_binary(Base64.strict_decode64(obj["item_factors"]), [@item_map.size, @factors]) + @epochs = obj["epochs"].to_i + @verbose = obj["verbose"] + + unless @implicit + @min_rating = obj["min_rating"] + @max_rating = obj["max_rating"] + end + + @top_items = obj.key?("item_count") + if @top_items + @item_count = obj["item_count"] + @item_sum = obj["item_sum"] end end end end