Sha256: 386629dcbeee1564f27de3028e2dc13a5d7a76f170a16f399162e8db7b2748a0

Contents?: true

Size: 786 Bytes

Versions: 1

Compression:

Stored size: 786 Bytes

Contents

module LightGBM
  class Regressor
    def initialize(num_leaves: 31, learning_rate: 0.1, n_estimators: 100, objective: nil)
      @params = {
        num_leaves: num_leaves,
        learning_rate: learning_rate
      }
      @params[:objective] = objective if objective
      @n_estimators = n_estimators
    end

    def fit(x, y)
      train_set = Dataset.new(x, label: y)
      @booster = LightGBM.train(@params, train_set, num_boost_round: @n_estimators)
      nil
    end

    def predict(data)
      @booster.predict(data)
    end

    def save_model(fname)
      @booster.save_model(fname)
    end

    def load_model(fname)
      @booster = Booster.new(params: @params, model_file: fname)
    end

    def feature_importances
      @booster.feature_importance
    end
  end
end
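
Every method in the class above delegates to `@booster`, which is only assigned by `fit` or `load_model`, so calling `predict` first would fail with a `NoMethodError` on `nil`. Below is a minimal sketch of one way such a wrapper could guard against that. It is not the gem's code: `StubBooster`, `GuardedRegressor`, and `NotFittedError` are hypothetical names introduced here, and the stub replaces the real booster so the example runs without LightGBM installed.

```ruby
# Hypothetical stand-in for a trained booster; NOT part of the lightgbm gem.
# It predicts the mean of the training labels for every input row.
class StubBooster
  def initialize(labels)
    @mean = labels.sum.to_f / labels.size
  end

  def predict(data)
    data.map { @mean }
  end
end

# Sketch of the same wrapper shape with an explicit "not fitted" check.
class GuardedRegressor
  # Hypothetical error class used only in this illustration.
  class NotFittedError < StandardError; end

  def fit(x, y)
    raise ArgumentError, "x and y must have the same length" unless x.size == y.size

    @booster = StubBooster.new(y)
    nil
  end

  def predict(data)
    booster.predict(data)
  end

  private

  # Raises a descriptive error instead of NoMethodError on nil.
  def booster
    @booster || raise(NotFittedError, "call fit or load_model before predict")
  end
end
```

With this shape, a caller who forgets to train the model sees a message naming the missing step rather than an error about `nil`.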

Version data entries

1 entry across 1 version & 1 rubygem

Version Path
lightgbm-0.1.3 lib/lightgbm/regressor.rb