lib/eps/base_estimator.rb in eps-0.3.4 vs lib/eps/base_estimator.rb in eps-0.3.5
- old
+ new
@@ -1,12 +1,11 @@
module Eps
class BaseEstimator
# Builds an estimator and, when +data+ is supplied, trains it right away.
#
# data    - training data handed to #train; when nil (the default) no
#           training happens in the constructor.
# y       - optional second positional argument, forwarded to #train unchanged.
# options - keyword options; a shallow copy (Hash#dup) is kept in @options.
def initialize(data = nil, y = nil, **options)
@options = options.dup
- # TODO better pattern - don't pass most options to train
- options.delete(:intercept)
@trained = false
+ # TODO better pattern - don't pass most options to train
train(data, y, **options) if data
end
def predict(data)
_predict(data, false)
@@ -81,11 +80,11 @@
predictions = @evaluator.predict(data, probabilities: probabilities)
singular ? predictions.first : predictions
end
- def train(data, y = nil, target: nil, weight: nil, split: nil, validation_set: nil, verbose: nil, text_features: nil, early_stopping: nil)
+ def train(data, y = nil, target: nil, weight: nil, split: nil, validation_set: nil, text_features: nil, **options)
data, @target = prep_data(data, y, target, weight)
@target_type = Utils.column_type(data.label, @target)
if split.nil?
split = data.size >= 30
@@ -173,11 +172,11 @@
raise "No data in training set" if @train_set.empty?
raise "No data in validation set" if validation_set && validation_set.empty?
@validation_set = validation_set
- @evaluator = _train(verbose: verbose, early_stopping: early_stopping)
+ @evaluator = _train(**options)
# reset pmml
@pmml = nil
@trained = true
@@ -244,10 +243,10 @@
raise "Number of data points differs from weight" if data.weight && data.size != data.weight.size
end
# Validates a single column before it is used.
#
# c    - the column to check; a falsy value (nil/false) is treated as a
#        missing column.
# name - column name, used only to build the error message.
#
# Raises ArgumentError when the column is absent or when any of its
# elements is nil.
def check_missing(c, name)
raise ArgumentError, "Missing column: #{name}" if !c
- raise ArgumentError, "Missing values in column #{name}" if c.any?(&:nil?)
+ raise ArgumentError, "Missing values in column #{name}" if c.to_a.any?(&:nil?)
end
def check_missing_value(df)
df.columns.each do |k, v|
check_missing(v, k)