Sha256: a9b0ed935251b2ed7977a3c1523ef93f0e9a0045911da3f790c9613e3df61a4c
Contents?: true
Size: 901 Bytes
Versions: 18
Compression:
Stored size: 901 Bytes
Contents
module LightGBM
  # Regression estimator: a thin subclass of Model that trains a booster via
  # LightGBM.train and delegates prediction to it.
  class Regressor < Model
    # Sets the regression defaults and hands every keyword to Model#initialize.
    #
    # @param num_leaves [Integer] defaults to 31
    # @param learning_rate [Float] defaults to 0.1
    # @param n_estimators [Integer] defaults to 100
    # @param objective [String] defaults to "regression"
    # @param options [Hash] any further keywords, passed through untouched
    #
    # NOTE(review): @params and @n_estimators (read by #fit) are presumably
    # populated by Model#initialize -- confirm against Model.
    def initialize(num_leaves: 31, learning_rate: 0.1, n_estimators: 100, objective: "regression", **options)
      # Bare `super` forwards all keyword arguments received here, including
      # the defaults filled in by this signature.
      super
    end

    # Trains a booster on x/y and stores it in @booster.
    #
    # @param x training features, passed to Dataset.new as the data argument
    # @param y training labels, passed to Dataset.new as `label:`
    # @param categorical_feature forwarded to the training Dataset ("auto" by default)
    # @param eval_set optional collection of validation entries; each entry is
    #   indexed with [0] for its features and [1] for its labels
    # @param eval_names forwarded to LightGBM.train as `valid_names:`
    # @param early_stopping_rounds forwarded to LightGBM.train
    # @param verbose forwarded to LightGBM.train as `verbose_eval:`
    # @return [nil] always nil; the trained booster is kept in @booster
    def fit(x, y, categorical_feature: "auto", eval_set: nil, eval_names: [], early_stopping_rounds: nil, verbose: true)
      training_data = Dataset.new(x, label: y, categorical_feature: categorical_feature, params: @params)

      # Array(nil) is [], so omitting eval_set yields no validation datasets.
      # Every validation Dataset is created with the training set as reference.
      validation_data = Array(eval_set).map do |entry|
        Dataset.new(entry[0], label: entry[1], reference: training_data, params: @params)
      end

      train_options = {
        num_boost_round: @n_estimators,
        early_stopping_rounds: early_stopping_rounds,
        verbose_eval: verbose,
        valid_sets: validation_data,
        valid_names: eval_names
      }

      @booster = LightGBM.train(@params, training_data, **train_options)
      nil
    end

    # Returns predictions for +data+ by delegating to the stored booster.
    #
    # @param data input passed straight to the booster's #predict
    # @param num_iteration forwarded to the booster's #predict (nil by default)
    # @return whatever the booster's #predict returns
    #
    # NOTE(review): @booster is only assigned in #fit within this class, so
    # calling this before #fit presumably fails on nil -- confirm whether
    # Model assigns @booster elsewhere.
    def predict(data, num_iteration: nil)
      @booster.predict(data, num_iteration: num_iteration)
    end
  end
end
Version data entries
18 entries across 18 versions & 1 rubygems