Sha256: badef0de62d0c791950c91e3948caef3b1750401c79c374af80a8051bd1f73c4
Contents?: true
Size: 1.42 KB
Versions: 1
Compression:
Stored size: 1.42 KB
Contents
module LightGBM
  # Classifier facade over a LightGBM booster: #fit trains a booster, and
  # #predict / #predict_proba score new rows with it.
  #
  # The booster is only assigned by #fit or #load_model; calling #predict,
  # #predict_proba, #save_model or #feature_importances before either of those
  # raises NoMethodError on nil.
  class Classifier
    # @param num_leaves [Integer] stored in the params hash as :num_leaves
    # @param learning_rate [Float] stored in the params hash as :learning_rate
    # @param n_estimators [Integer] passed to LightGBM.train as num_boost_round
    # @param objective [String, nil] explicit objective; when nil, #fit picks
    #   "binary" or "multiclass" from the number of distinct labels
    def initialize(num_leaves: 31, learning_rate: 0.1, n_estimators: 100, objective: nil)
      @params = { num_leaves: num_leaves, learning_rate: learning_rate }
      @params[:objective] = objective if objective
      @n_estimators = n_estimators
    end

    # Trains a booster on +x+ and +y+.
    #
    # More than two distinct labels selects the "multiclass" objective and sets
    # :num_class to the number of distinct labels; otherwise "binary" is used.
    # An objective given to the constructor is kept (only :num_class is added).
    # The constructor params are copied, so repeated calls start from the same
    # configuration.
    #
    # NOTE(review): labels are handed to Dataset unchanged, while #predict returns
    # class indices (multiclass) or 0/1 (binary). This presumably assumes labels
    # are already encoded as 0...n_classes — confirm with callers.
    #
    # @param x feature rows, passed to Dataset as-is
    # @param y labels; must respond to #uniq
    # @return [nil]
    def fit(x, y)
      n_classes = y.uniq.size

      params = @params.dup
      if n_classes > 2
        params[:objective] ||= "multiclass"
        params[:num_class] = n_classes
      else
        params[:objective] ||= "binary"
      end

      train_set = Dataset.new(x, label: y)
      @booster = LightGBM.train(params, train_set, num_boost_round: @n_estimators)
      nil
    end

    # Predicts a class for each row of +data+.
    #
    # @param data rows to score, passed to the booster unchanged
    # @return [Array<Integer>] for multiclass output, the index of the highest
    #   score in each row (ties resolve to the lowest index); for binary output,
    #   1 when the score is greater than 0.5 and 0 otherwise
    def predict(data)
      y_pred = @booster.predict(data)

      if multiclass_output?(y_pred)
        y_pred.map { |scores| scores.each_with_index.max_by { |score, _index| score }.last }
      else
        y_pred.map { |score| score > 0.5 ? 1 : 0 }
      end
    end

    # Returns class probabilities for each row of +data+ as produced by the booster.
    #
    # @param data rows to score, passed to the booster unchanged
    # @return [Array<Array>] multiclass rows exactly as the booster returns them;
    #   for binary output, +[1 - p, p]+ for each score +p+
    def predict_proba(data)
      y_pred = @booster.predict(data)
      return y_pred if multiclass_output?(y_pred)

      y_pred.map { |prob| [1 - prob, prob] }
    end

    # Writes the current booster to +fname+ via Booster#save_model.
    def save_model(fname)
      @booster.save_model(fname)
    end

    # Replaces the current booster with one loaded from +fname+, constructed with
    # the constructor params (not the objective/num_class chosen during #fit).
    #
    # @return the newly built Booster
    def load_model(fname)
      @booster = Booster.new(params: @params, model_file: fname)
    end

    # @return the value of Booster#feature_importance for the current booster
    def feature_importances
      @booster.feature_importance
    end

    private

    # True when the booster returned one array of scores per row (multiclass),
    # false when it returned one scalar per row (binary). Empty output counts as
    # binary because +first+ is nil.
    def multiclass_output?(y_pred)
      y_pred.first.is_a?(Array)
    end
  end
end
Version data entries
1 entries across 1 versions & 1 rubygems
Version | Path |
---|---|
lightgbm-0.1.3 | lib/lightgbm/classifier.rb |