lib/eps/base_estimator.rb in eps-0.3.3 vs lib/eps/base_estimator.rb in eps-0.3.4

- old
+ new

@@ -7,31 +7,15 @@
       @trained = false
       train(data, y, **options) if data
     end
 
     def predict(data)
-      singular = data.is_a?(Hash)
-      data = [data] if singular
+      _predict(data, false)
+    end
 
-      data = Eps::DataFrame.new(data)
-
-      @evaluator.features.each do |k, type|
-        values = data.columns[k]
-        raise ArgumentError, "Missing column: #{k}" if !values
-        column_type = Utils.column_type(values.compact, k) if values
-
-        if !column_type.nil?
-          if (type == "numeric" && column_type != "numeric") || (type != "numeric" && column_type != "categorical")
-            raise ArgumentError, "Bad type for column #{k}: Expected #{type} but got #{column_type}"
-          end
-        end
-        # TODO check for unknown values for categorical features
-      end
-
-      predictions = @evaluator.predict(data)
-
-      singular ? predictions.first : predictions
+    def predict_probability(data)
+      _predict(data, true)
     end
 
     def evaluate(data, y = nil, target: nil, weight: nil)
       data, target = prep_data(data, y, target || @target, weight)
       Eps.metrics(data.label, predict(data), weight: data.weight)
@@ -72,9 +56,33 @@
       str << _summary(extended: extended)
       str
     end
 
     private
+
+    def _predict(data, probabilities)
+      singular = data.is_a?(Hash)
+      data = [data] if singular
+
+      data = Eps::DataFrame.new(data)
+
+      @evaluator.features.each do |k, type|
+        values = data.columns[k]
+        raise ArgumentError, "Missing column: #{k}" if !values
+        column_type = Utils.column_type(values.compact, k) if values
+
+        if !column_type.nil?
+          if (type == "numeric" && column_type != "numeric") || (type != "numeric" && column_type != "categorical")
+            raise ArgumentError, "Bad type for column #{k}: Expected #{type} but got #{column_type}"
+          end
+        end
+        # TODO check for unknown values for categorical features
+      end
+
+      predictions = @evaluator.predict(data, probabilities: probabilities)
+
+      singular ? predictions.first : predictions
+    end
 
     def train(data, y = nil, target: nil, weight: nil, split: nil, validation_set: nil, verbose: nil, text_features: nil, early_stopping: nil)
       data, @target = prep_data(data, y, target, weight)
       @target_type = Utils.column_type(data.label, @target)