lib/eps/base_estimator.rb in eps-0.3.5 vs lib/eps/base_estimator.rb in eps-0.3.6
- old
+ new
@@ -1,10 +1,11 @@
module Eps
class BaseEstimator
def initialize(data = nil, y = nil, **options)
@options = options.dup
@trained = false
+ @text_encoders = {}
# TODO better pattern - don't pass most options to train
train(data, y, **options) if data
end
def predict(data)
@@ -207,32 +208,41 @@
check_data(data)
[data, target]
end
- def prep_text_features(train_set)
- @text_encoders = {}
+ def prep_text_features(train_set, fit: true)
@text_features.each do |k, v|
- # reset vocabulary
- v.delete(:vocabulary)
+ if fit
+ # reset vocabulary
+ v.delete(:vocabulary)
- # TODO determine max features automatically
- # start based on number of rows
- encoder = Eps::TextEncoder.new(**v)
- counts = encoder.fit(train_set.columns.delete(k))
+ # TODO determine max features automatically
+ # start based on number of rows
+ encoder = Eps::TextEncoder.new(**v)
+ counts = encoder.fit(train_set.columns.delete(k))
+ else
+ encoder = @text_encoders[k]
+ counts = encoder.transform(train_set.columns.delete(k))
+ end
+
encoder.vocabulary.each do |word|
train_set.columns[[k, word]] = [0] * counts.size
end
+
counts.each_with_index do |ci, i|
ci.each do |word, count|
word_key = [k, word]
train_set.columns[word_key][i] = 1 if train_set.columns.key?(word_key)
end
end
- @text_encoders[k] = encoder
- # update vocabulary
- v[:vocabulary] = encoder.vocabulary
+ if fit
+ @text_encoders[k] = encoder
+
+ # update vocabulary
+ v[:vocabulary] = encoder.vocabulary
+ end
end
raise "No features left" if train_set.columns.empty?
end