lib/lightgbm.rb in lightgbm-0.1.7 vs lib/lightgbm.rb in lightgbm-0.1.8
- old
+ new
@@ -34,10 +34,12 @@
valid_sets.zip(valid_names).each_with_index do |(data, name), i|
if data == train_set
booster.train_data_name = name || "training"
valid_contain_train = true
else
+ # ensure the validation set references the training set
+ data.reference = train_set
booster.add_valid(data, name || "valid_#{i}")
end
end
raise ArgumentError, "For early stopping, at least one validation set is required" if early_stopping_rounds && !valid_sets.any? { |v| v != train_set }
@@ -131,10 +133,11 @@
eval_hist = {}
if early_stopping_rounds
best_score = {}
best_iter = {}
+ best_iteration = nil
end
num_boost_round.times do |iteration|
boosters.each(&:update)
@@ -170,14 +173,24 @@
# TODO: support metrics where a higher score is better (the comparison below assumes lower is better)
if best_score[k].nil? || score < best_score[k]
best_score[k] = score
best_iter[k] = iteration
elsif iteration - best_iter[k] >= early_stopping_rounds
+ best_iteration = best_iter[k]
stop_early = true
break
end
end
break if stop_early
+ end
+ end
+
+ if early_stopping_rounds
+ # use best iteration from first metric if not stopped early
+ best_iteration ||= best_iter[best_iter.keys.first]
+ eval_hist.each_key do |k|
+ # TODO uncomment for 0.2.0
+ # eval_hist[k] = eval_hist[k].first(best_iteration + 1)
end
end
eval_hist
end