lib/eps/base_estimator.rb in eps-0.3.3 vs lib/eps/base_estimator.rb in eps-0.3.4
- old
+ new
@@ -7,31 +7,15 @@
@trained = false
train(data, y, **options) if data
end
def predict(data)
- singular = data.is_a?(Hash)
- data = [data] if singular
+ _predict(data, false)
+ end
- data = Eps::DataFrame.new(data)
-
- @evaluator.features.each do |k, type|
- values = data.columns[k]
- raise ArgumentError, "Missing column: #{k}" if !values
- column_type = Utils.column_type(values.compact, k) if values
-
- if !column_type.nil?
- if (type == "numeric" && column_type != "numeric") || (type != "numeric" && column_type != "categorical")
- raise ArgumentError, "Bad type for column #{k}: Expected #{type} but got #{column_type}"
- end
- end
- # TODO check for unknown values for categorical features
- end
-
- predictions = @evaluator.predict(data)
-
- singular ? predictions.first : predictions
+ def predict_probability(data)
+ _predict(data, true)
end
def evaluate(data, y = nil, target: nil, weight: nil)
data, target = prep_data(data, y, target || @target, weight)
Eps.metrics(data.label, predict(data), weight: data.weight)
@@ -72,9 +56,33 @@
str << _summary(extended: extended)
str
end
private
+
+ def _predict(data, probabilities)
+ singular = data.is_a?(Hash)
+ data = [data] if singular
+
+ data = Eps::DataFrame.new(data)
+
+ @evaluator.features.each do |k, type|
+ values = data.columns[k]
+ raise ArgumentError, "Missing column: #{k}" if !values
+ column_type = Utils.column_type(values.compact, k) if values
+
+ if !column_type.nil?
+ if (type == "numeric" && column_type != "numeric") || (type != "numeric" && column_type != "categorical")
+ raise ArgumentError, "Bad type for column #{k}: Expected #{type} but got #{column_type}"
+ end
+ end
+ # TODO check for unknown values for categorical features
+ end
+
+ predictions = @evaluator.predict(data, probabilities: probabilities)
+
+ singular ? predictions.first : predictions
+ end
def train(data, y = nil, target: nil, weight: nil, split: nil, validation_set: nil, verbose: nil, text_features: nil, early_stopping: nil)
data, @target = prep_data(data, y, target, weight)
@target_type = Utils.column_type(data.label, @target)