Sha256: 32842d96ecf5d4addb8992e82eb8046304b2dc970fecca3c3b8828cde0bcdc2e

Contents?: true

Size: 1.57 KB

Versions: 18

Compression:

Stored size: 1.57 KB

Contents

module LightGBM
  # Classifier front end over LightGBM.train and the resulting booster.
  #
  # Chooses a binary or multiclass objective from the training labels, and
  # turns raw booster output into class indices (#predict) or per-class
  # probability rows (#predict_proba).
  class Classifier < Model
    # All hyperparameters are handed to Model#initialize unchanged; the bare
    # `super` re-passes every keyword argument, including **options.
    def initialize(num_leaves: 31, learning_rate: 0.1, n_estimators: 100, objective: nil, **options)
      super
    end

    # Trains a booster on +x+ / +y+ and stores it in @booster.
    #
    # When +y+ has more than two distinct values, the objective defaults to
    # "multiclass" and num_class is set to that count; otherwise the objective
    # defaults to "binary". An objective already present in the model params
    # is kept. The model's own params hash is copied, not modified.
    #
    # @param x training features (passed to Dataset.new)
    # @param y training labels; must respond to #uniq
    # @param eval_set [Array, nil] validation entries, each read as
    #   entry[0] (features) and entry[1] (labels)
    # @param eval_names [Array] names forwarded as valid_names
    # @param categorical_feature forwarded to the training Dataset
    # @param early_stopping_rounds [Integer, nil] forwarded to LightGBM.train
    # @param verbose forwarded to LightGBM.train as verbose_eval
    # @return [nil]
    def fit(x, y, eval_set: nil, eval_names: [], categorical_feature: "auto", early_stopping_rounds: nil, verbose: true)
      class_count = y.uniq.size
      multiclass = class_count > 2

      # Work on a copy so repeated fits do not leak objective/num_class into @params.
      train_params = @params.dup
      default_objective = multiclass ? "multiclass" : "binary"
      train_params[:objective] ||= default_objective
      train_params[:num_class] = class_count if multiclass

      train_set = Dataset.new(x, label: y, categorical_feature: categorical_feature, params: train_params)
      # Array() lets eval_set be nil (no validation) or a list of entries.
      valid_sets = Array(eval_set).map do |entry|
        Dataset.new(entry[0], label: entry[1], reference: train_set, params: train_params)
      end

      @booster = LightGBM.train(
        train_params,
        train_set,
        num_boost_round: @n_estimators,
        early_stopping_rounds: early_stopping_rounds,
        verbose_eval: verbose,
        valid_sets: valid_sets,
        valid_names: eval_names
      )
      nil
    end

    # Predicts a class for each row of +data+.
    #
    # If the booster returns Array rows (multiclass), each row is reduced to
    # the index of its largest score. Otherwise each score is thresholded:
    # values strictly above 0.5 become 1, everything else 0.
    #
    # @param data rows to score (passed to the booster)
    # @param num_iteration forwarded to the booster's predict
    # @return [Array<Integer>]
    def predict(data, num_iteration: nil)
      raw_scores = @booster.predict(data, num_iteration: num_iteration)
      return raw_scores.map { |score| score > 0.5 ? 1 : 0 } unless raw_scores.first.is_a?(Array)

      raw_scores.map do |row|
        row.each_with_index.max_by { |score, _index| score }.last
      end
    end

    # Returns per-class probability rows for +data+.
    #
    # Array rows (multiclass) are returned as produced by the booster. A flat
    # score p becomes the pair [1 - p, p].
    # NOTE(review): assumes the binary booster output is the positive-class
    # probability; confirm against Booster#predict.
    #
    # @param data rows to score (passed to the booster)
    # @param num_iteration forwarded to the booster's predict
    # @return [Array<Array<Float>>]
    def predict_proba(data, num_iteration: nil)
      raw_scores = @booster.predict(data, num_iteration: num_iteration)
      return raw_scores if raw_scores.first.is_a?(Array)

      raw_scores.map { |positive| [1 - positive, positive] }
    end
  end
end

Version data entries

18 entries across 18 versions & 1 rubygem

Version Path
lightgbm-0.3.4 lib/lightgbm/classifier.rb
lightgbm-0.3.3 lib/lightgbm/classifier.rb
lightgbm-0.3.2 lib/lightgbm/classifier.rb
lightgbm-0.3.1 lib/lightgbm/classifier.rb
lightgbm-0.3.0 lib/lightgbm/classifier.rb
lightgbm-0.2.7 lib/lightgbm/classifier.rb
lightgbm-0.2.6 lib/lightgbm/classifier.rb
lightgbm-0.2.5 lib/lightgbm/classifier.rb
lightgbm-0.2.4 lib/lightgbm/classifier.rb
lightgbm-0.2.3 lib/lightgbm/classifier.rb
lightgbm-0.2.2 lib/lightgbm/classifier.rb
lightgbm-0.2.1 lib/lightgbm/classifier.rb
lightgbm-0.2.0 lib/lightgbm/classifier.rb
lightgbm-0.1.9 lib/lightgbm/classifier.rb
lightgbm-0.1.8 lib/lightgbm/classifier.rb
lightgbm-0.1.7 lib/lightgbm/classifier.rb
lightgbm-0.1.6 lib/lightgbm/classifier.rb
lightgbm-0.1.5 lib/lightgbm/classifier.rb