lib/eps/base_estimator.rb in eps-0.3.5 vs lib/eps/base_estimator.rb in eps-0.3.6

- old
+ new

@@ -1,10 +1,11 @@ module Eps class BaseEstimator def initialize(data = nil, y = nil, **options) @options = options.dup @trained = false + @text_encoders = {} # TODO better pattern - don't pass most options to train train(data, y, **options) if data end def predict(data) @@ -207,32 +208,41 @@ check_data(data) [data, target] end - def prep_text_features(train_set) - @text_encoders = {} + def prep_text_features(train_set, fit: true) @text_features.each do |k, v| - # reset vocabulary - v.delete(:vocabulary) + if fit + # reset vocabulary + v.delete(:vocabulary) - # TODO determine max features automatically - # start based on number of rows - encoder = Eps::TextEncoder.new(**v) - counts = encoder.fit(train_set.columns.delete(k)) + # TODO determine max features automatically + # start based on number of rows + encoder = Eps::TextEncoder.new(**v) + counts = encoder.fit(train_set.columns.delete(k)) + else + encoder = @text_encoders[k] + counts = encoder.transform(train_set.columns.delete(k)) + end + encoder.vocabulary.each do |word| train_set.columns[[k, word]] = [0] * counts.size end + counts.each_with_index do |ci, i| ci.each do |word, count| word_key = [k, word] train_set.columns[word_key][i] = 1 if train_set.columns.key?(word_key) end end - @text_encoders[k] = encoder - # update vocabulary - v[:vocabulary] = encoder.vocabulary + if fit + @text_encoders[k] = encoder + + # update vocabulary + v[:vocabulary] = encoder.vocabulary + end end raise "No features left" if train_set.columns.empty? end