lib/disco/recommender.rb in disco-0.1.2 vs lib/disco/recommender.rb in disco-0.1.3

- old
+ new

@@ -7,18 +7,12 @@ @epochs = epochs @verbose = verbose end def fit(train_set, validation_set: nil) - if defined?(Daru) - if train_set.is_a?(Daru::DataFrame) - train_set = train_set.to_a[0] - end - if validation_set.is_a?(Daru::DataFrame) - validation_set = validation_set.to_a[0] - end - end + train_set = to_dataset(train_set) + validation_set = to_dataset(validation_set) if validation_set @implicit = !train_set.any? { |v| v[:rating] } unless @implicit ratings = train_set.map { |o| o[:rating] } @@ -188,10 +182,13 @@ def create_maps(train_set) user_ids = train_set.map { |v| v[:user_id] }.uniq.sort item_ids = train_set.map { |v| v[:item_id] }.uniq.sort + raise ArgumentError, "Missing user_id" if user_ids.any?(&:nil?) + raise ArgumentError, "Missing item_id" if item_ids.any?(&:nil?) + @user_map = user_ids.zip(user_ids.size.times).to_h @item_map = item_ids.zip(item_ids.size.times).to_h end def check_ratings(ratings) @@ -203,9 +200,28 @@ end end def check_training_set(train_set) raise ArgumentError, "No training data" if train_set.empty? + end + + def to_dataset(dataset) + if defined?(Rover::DataFrame) && dataset.is_a?(Rover::DataFrame) + # convert keys to symbols + dataset = dataset.dup + dataset.keys.each do |k, v| + dataset[k.to_sym] ||= dataset.delete(k) + end + dataset.to_a + elsif defined?(Daru::DataFrame) && dataset.is_a?(Daru::DataFrame) + # convert keys to symbols + dataset = dataset.dup + new_names = dataset.vectors.to_a.map { |k| [k, k.to_sym] }.to_h + dataset.rename_vectors!(new_names) + dataset.to_a[0] + else + dataset + end end def marshal_dump obj = { implicit: @implicit,