lib/lightgbm.rb in lightgbm-0.1.1 vs lib/lightgbm.rb in lightgbm-0.1.2
- old
+ new
@@ -9,16 +9,176 @@
require "lightgbm/version"
module LightGBM
class Error < StandardError; end
- def self.train(params, train_set, num_boost_round: 100, valid_sets: [], valid_names: [])
- booster = Booster.new(params: params, train_set: train_set)
- valid_sets.zip(valid_names) do |data, name|
- booster.add_valid(data, name)
+ class << self
+ def train(params, train_set,num_boost_round: 100, valid_sets: [], valid_names: [], early_stopping_rounds: nil, verbose_eval: true)
+ booster = Booster.new(params: params, train_set: train_set)
+
+ valid_contain_train = false
+ valid_sets.zip(valid_names).each_with_index do |(data, name), i|
+ if data == train_set
+ booster.train_data_name = name || "training"
+ valid_contain_train = true
+ else
+ booster.add_valid(data, name || "valid_#{i}")
+ end
+ end
+
+ booster.best_iteration = 0
+
+ if early_stopping_rounds
+ best_score = []
+ best_iter = []
+ best_message = []
+
+ puts "Training until validation scores don't improve for #{early_stopping_rounds.to_i} rounds." if verbose_eval
+ end
+
+ num_boost_round.times do |iteration|
+ booster.update
+
+ if valid_sets.any?
+ # print results
+ messages = []
+
+ if valid_contain_train
+ # not sure why reversed in output
+ booster.eval_train.reverse.each do |res|
+ messages << "%s's %s: %g" % [res[0], res[1], res[2]]
+ end
+ end
+
+ eval_valid = booster.eval_valid
+ # not sure why reversed in output
+ eval_valid.reverse.each do |res|
+ messages << "%s's %s: %g" % [res[0], res[1], res[2]]
+ end
+
+ message = "[#{iteration + 1}]\t#{messages.join("\t")}"
+
+ puts message if verbose_eval
+
+ if early_stopping_rounds
+ stop_early = false
+ eval_valid.each_with_index do |(_, _, score, higher_better), i|
+ op = higher_better ? :> : :<
+ if best_score[i].nil? || score.send(op, best_score[i])
+ best_score[i] = score
+ best_iter[i] = iteration
+ best_message[i] = message
+ elsif iteration - best_iter[i] >= early_stopping_rounds
+ booster.best_iteration = best_iter[i] + 1
+ puts "Early stopping, best iteration is:\n#{best_message[i]}" if verbose_eval
+ stop_early = true
+ break
+ end
+ end
+
+ break if stop_early
+
+ if iteration == num_boost_round - 1
+ booster.best_iteration = best_iter[0] + 1
+ puts "Did not meet early stopping. Best iteration is: #{best_message[0]}" if verbose_eval
+ end
+ end
+ end
+ end
+
+ booster
end
- num_boost_round.times do
- booster.update
+
+ def cv(params, train_set, num_boost_round: 100, nfold: 5, seed: 0, shuffle: true, early_stopping_rounds: nil, verbose_eval: nil, show_stdv: true)
+ rand_idx = (0...train_set.num_data).to_a
+ rand_idx.shuffle!(random: Random.new(seed)) if shuffle
+
+ kstep = rand_idx.size / nfold
+ test_id = rand_idx.each_slice(kstep).to_a[0...nfold]
+ train_id = []
+ nfold.times do |i|
+ idx = test_id.dup
+ idx.delete_at(i)
+ train_id << idx.flatten
+ end
+
+ boosters = []
+ folds = train_id.zip(test_id)
+ folds.each do |(train_idx, test_idx)|
+ fold_train_set = train_set.subset(train_idx)
+ fold_valid_set = train_set.subset(test_idx)
+ booster = Booster.new(params: params, train_set: fold_train_set)
+ booster.add_valid(fold_valid_set, "valid")
+ boosters << booster
+ end
+
+ eval_hist = {}
+
+ if early_stopping_rounds
+ best_score = {}
+ best_iter = {}
+ end
+
+ num_boost_round.times do |iteration|
+ boosters.each(&:update)
+
+ scores = {}
+ boosters.map(&:eval_valid).map(&:reverse).flatten(1).each do |r|
+ (scores[r[1]] ||= []) << r[2]
+ end
+
+ message_parts = ["[#{iteration + 1}]"]
+
+ means = {}
+ scores.each do |eval_name, vals|
+ mean = mean(vals)
+ stdev = stdev(vals)
+
+ (eval_hist["#{eval_name}-mean"] ||= []) << mean
+ (eval_hist["#{eval_name}-stdv"] ||= []) << stdev
+
+ means[eval_name] = mean
+
+ if show_stdv
+ message_parts << "cv_agg's %s: %g + %g" % [eval_name, mean, stdev]
+ else
+ message_parts << "cv_agg's %s: %g" % [eval_name, mean]
+ end
+ end
+
+ puts message_parts.join("\t") if verbose_eval
+
+ if early_stopping_rounds
+ stop_early = false
+ means.each do |k, score|
+ if best_score[k].nil? || score < best_score[k]
+ best_score[k] = score
+ best_iter[k] = iteration
+ elsif iteration - best_iter[k] >= early_stopping_rounds
+ stop_early = true
+ break
+ end
+ end
+ break if stop_early
+ end
+ end
+
+ eval_hist
end
- booster
+
+ private
+
+ def mean(arr)
+ arr.sum / arr.size.to_f
+ end
+
+ # don't subtract one from arr.size
+ def stdev(arr)
+ m = mean(arr)
+ sum = 0
+ arr.each do |v|
+ sum += (v - m) ** 2
+ end
+ Math.sqrt(sum / arr.size)
+ end
end
end