lib/eps/base_estimator.rb in eps-0.3.4 vs lib/eps/base_estimator.rb in eps-0.3.5

- old
+ new

@@ -1,12 +1,11 @@
 module Eps
   class BaseEstimator
     def initialize(data = nil, y = nil, **options)
       @options = options.dup
-      # TODO better pattern - don't pass most options to train
-      options.delete(:intercept)
       @trained = false
+      # TODO better pattern - don't pass most options to train
       train(data, y, **options) if data
     end
     def predict(data)
       _predict(data, false)
@@ -81,11 +80,11 @@
       predictions = @evaluator.predict(data, probabilities: probabilities)
       singular ? predictions.first : predictions
     end
-    def train(data, y = nil, target: nil, weight: nil, split: nil, validation_set: nil, verbose: nil, text_features: nil, early_stopping: nil)
+    def train(data, y = nil, target: nil, weight: nil, split: nil, validation_set: nil, text_features: nil, **options)
       data, @target = prep_data(data, y, target, weight)
       @target_type = Utils.column_type(data.label, @target)
       if split.nil?
         split = data.size >= 30
@@ -173,11 +172,11 @@
       raise "No data in training set" if @train_set.empty?
       raise "No data in validation set" if validation_set && validation_set.empty?
       @validation_set = validation_set
-      @evaluator = _train(verbose: verbose, early_stopping: early_stopping)
+      @evaluator = _train(**options)
       # reset pmml
       @pmml = nil
       @trained = true
@@ -244,10 +243,10 @@
       raise "Number of data points differs from weight" if data.weight && data.size != data.weight.size
     end
     def check_missing(c, name)
       raise ArgumentError, "Missing column: #{name}" if !c
-      raise ArgumentError, "Missing values in column #{name}" if c.any?(&:nil?)
+      raise ArgumentError, "Missing values in column #{name}" if c.to_a.any?(&:nil?)
     end
     def check_missing_value(df)
       df.columns.each do |k, v|
         check_missing(v, k)