lib/disco/recommender.rb in disco-0.1.2 vs lib/disco/recommender.rb in disco-0.1.3
- old
+ new
@@ -7,18 +7,12 @@
@epochs = epochs
@verbose = verbose
end
def fit(train_set, validation_set: nil)
- if defined?(Daru)
- if train_set.is_a?(Daru::DataFrame)
- train_set = train_set.to_a[0]
- end
- if validation_set.is_a?(Daru::DataFrame)
- validation_set = validation_set.to_a[0]
- end
- end
+ train_set = to_dataset(train_set)
+ validation_set = to_dataset(validation_set) if validation_set
@implicit = !train_set.any? { |v| v[:rating] }
unless @implicit
ratings = train_set.map { |o| o[:rating] }
@@ -188,10 +182,13 @@
+ # Builds @user_map and @item_map: each distinct user_id / item_id found in
+ # train_set is mapped to its position in the sorted list of distinct ids.
+ # Rows are read with v[:user_id] / v[:item_id] (presumably symbol-keyed hashes).
+ # Raises ArgumentError when any row has a nil user_id or item_id.
def create_maps(train_set)
+ # Validate before sorting: sorting ids that mix nil with real values raises a
+ # generic comparison ArgumentError, which would mask the messages below.
+ raise ArgumentError, "Missing user_id" if train_set.any? { |v| v[:user_id].nil? }
+ raise ArgumentError, "Missing item_id" if train_set.any? { |v| v[:item_id].nil? }
+
user_ids = train_set.map { |v| v[:user_id] }.uniq.sort
item_ids = train_set.map { |v| v[:item_id] }.uniq.sort
@user_map = user_ids.zip(user_ids.size.times).to_h
@item_map = item_ids.zip(item_ids.size.times).to_h
end
def check_ratings(ratings)
@@ -203,9 +200,28 @@
end
end
def check_training_set(train_set)
raise ArgumentError, "No training data" if train_set.empty?
+ end
+
+ # Normalizes a training/validation input before it is iterated.
+ # Rover and Daru data frames are dup'ed, have their column names converted to
+ # symbols, and are converted via to_a; any other object (including nil) is
+ # returned unchanged.
+ def to_dataset(dataset)
+ if defined?(Rover::DataFrame) && dataset.is_a?(Rover::DataFrame)
+ # convert keys to symbols; dup first so the renames are applied to the copy
+ dataset = dataset.dup
+ dataset.keys.each do |k|
+ # ||= only moves the column when nothing truthy is already stored under
+ # the symbol key
+ dataset[k.to_sym] ||= dataset.delete(k)
+ end
+ dataset.to_a
+ elsif defined?(Daru::DataFrame) && dataset.is_a?(Daru::DataFrame)
+ # convert keys to symbols; dup first so the renames are applied to the copy
+ dataset = dataset.dup
+ new_names = dataset.vectors.to_a.map { |k| [k, k.to_sym] }.to_h
+ dataset.rename_vectors!(new_names)
+ # NOTE(review): to_a[0] is presumably the array of row hashes — confirm
+ # against the Daru documentation
+ dataset.to_a[0]
+ else
+ dataset
+ end
end
def marshal_dump
obj = {
implicit: @implicit,