lib/eps/lightgbm.rb in eps-0.3.3 vs lib/eps/lightgbm.rb in eps-0.3.4
- old
+ new
@@ -119,21 +119,23 @@
# compare a subset of predictions to check for possible bugs in evaluator
# NOTE LightGBM must use double data type for prediction input for these to be consistent
def check_evaluator(objective, labels, booster, booster_set, evaluator, evaluator_set)
expected = @booster.predict(booster_set.map_rows(&:to_a))
if objective == "multiclass"
- expected.map! do |v|
- labels[v.map.with_index.max_by { |v2, _| v2 }.last]
- end
+ actual = evaluator.predict(evaluator_set, probabilities: true)
+ # just compare first for now
+ expected.map! { |v| v.first }
+ actual.map! { |v| v.values.first }
elsif objective == "binary"
- expected.map! { |v| labels[v >= 0.5 ? 1 : 0] }
+ actual = evaluator.predict(evaluator_set, probabilities: true).map { |v| v.values.last }
+ else
+ actual = evaluator.predict(evaluator_set)
end
- actual = evaluator.predict(evaluator_set)
- regression = objective == "regression"
+ regression = objective == "regression" || objective == "binary"
bad_observations = []
expected.zip(actual).each_with_index do |(exp, act), i|
- success = regression ? (act - exp).abs < 0.001 : act == exp
+ success = (act - exp).abs < 0.001
unless success
bad_observations << {expected: exp, actual: act, data_point: evaluator_set[i].map(&:itself).first}
end
end