lib/eps/lightgbm.rb in eps-0.3.3 vs lib/eps/lightgbm.rb in eps-0.3.4

- old
+ new

@@ -119,21 +119,23 @@ # compare a subset of predictions to check for possible bugs in evaluator # NOTE LightGBM must use double data type for prediction input for these to be consistent def check_evaluator(objective, labels, booster, booster_set, evaluator, evaluator_set) expected = @booster.predict(booster_set.map_rows(&:to_a)) if objective == "multiclass" - expected.map! do |v| - labels[v.map.with_index.max_by { |v2, _| v2 }.last] - end + actual = evaluator.predict(evaluator_set, probabilities: true) + # just compare first for now + expected.map! { |v| v.first } + actual.map! { |v| v.values.first } elsif objective == "binary" - expected.map! { |v| labels[v >= 0.5 ? 1 : 0] } + actual = evaluator.predict(evaluator_set, probabilities: true).map { |v| v.values.last } + else + actual = evaluator.predict(evaluator_set) end - actual = evaluator.predict(evaluator_set) - regression = objective == "regression" + regression = objective == "regression" || objective == "binary" bad_observations = [] expected.zip(actual).each_with_index do |(exp, act), i| - success = regression ? (act - exp).abs < 0.001 : act == exp + success = (act - exp).abs < 0.001 unless success bad_observations << {expected: exp, actual: act, data_point: evaluator_set[i].map(&:itself).first} end end