lib/eps/base_estimator.rb in eps-0.3.2 vs lib/eps/base_estimator.rb in eps-0.3.3
- old
+ new
@@ -2,10 +2,11 @@
class BaseEstimator
# Builds an estimator and, when data is supplied, trains it right away.
#
# data    - training input; when nil (the default) no training happens here
# y       - forwarded to train unchanged (default nil)
# options - keyword options; a copy is kept in @options and everything
#           except :intercept is forwarded to train
def initialize(data = nil, y = nil, **options)
# copy taken before the delete below, so @options still contains :intercept
@options = options.dup
# TODO better pattern - don't pass most options to train
# :intercept is dropped only from the hash that is forwarded to train
options.delete(:intercept)
# start out untrained; summary raises unless @trained is truthy
+ @trained = false
train(data, y, **options) if data
end
def predict(data)
singular = data.is_a?(Hash)
@@ -46,10 +47,12 @@
model.instance_variable_set("@pmml", pmml.respond_to?(:to_xml) ? pmml.to_xml : pmml) # cache data
model
end
def summary(extended: false)
+ raise "Summary not available for loaded models" unless @trained
+
str = String.new("")
if @validation_set
y_true = @validation_set.label
y_pred = predict(@validation_set)
@@ -167,10 +170,12 @@
@evaluator = _train(verbose: verbose, early_stopping: early_stopping)
# reset pmml
@pmml = nil
+ @trained = true
+
nil
end
def prep_data(data, y, target, weight)
data = Eps::DataFrame.new(data)
@@ -203,10 +208,10 @@
# reset vocabulary
v.delete(:vocabulary)
# TODO determine max features automatically
# start based on number of rows
- encoder = Eps::TextEncoder.new(v)
+ encoder = Eps::TextEncoder.new(**v)
counts = encoder.fit(train_set.columns.delete(k))
encoder.vocabulary.each do |word|
train_set.columns[[k, word]] = [0] * counts.size
end
counts.each_with_index do |ci, i|