Sha256: badef0de62d0c791950c91e3948caef3b1750401c79c374af80a8051bd1f73c4

Contents?: true

Size: 1.42 KB

Versions: 1

Compression:

Stored size: 1.42 KB

Contents

module LightGBM
  # Scikit-learn style classifier built on top of a LightGBM booster.
  #
  # #predict returns class indices (0/1 for binary, argmax position for
  # multiclass), so labels passed to #fit are presumably expected to already be
  # 0-based integer class indices — NOTE(review): confirm this is documented for users.
  class Classifier
    # LightGBM multiclass objectives and their aliases; every one of them
    # requires the num_class parameter to be set before training.
    MULTICLASS_OBJECTIVES = %w[multiclass softmax multiclassova multiclass_ova ova ovr].freeze

    # @param num_leaves [Integer] forwarded to LightGBM as +num_leaves+
    # @param learning_rate [Float] forwarded to LightGBM as +learning_rate+
    # @param n_estimators [Integer] number of boosting rounds used by #fit
    # @param objective [String, Symbol, nil] LightGBM objective; when nil (or false),
    #   #fit chooses "binary" or "multiclass" from the number of distinct labels
    def initialize(num_leaves: 31, learning_rate: 0.1, n_estimators: 100, objective: nil)
      @params = {
        num_leaves: num_leaves,
        learning_rate: learning_rate
      }
      @params[:objective] = objective if objective
      @n_estimators = n_estimators
    end

    # Trains a new booster on +x+ / +y+, replacing any previously held booster.
    #
    # @param x feature rows, passed to Dataset as-is
    # @param y labels; the count of distinct values picks the default objective
    #   and, for multiclass objectives, the +num_class+ parameter
    # @return [nil]
    def fit(x, y)
      n_classes = y.uniq.size

      # Work on a copy so per-fit settings (objective default, num_class) never
      # leak back into the constructor parameters.
      params = @params.dup
      params[:objective] ||= n_classes > 2 ? "multiclass" : "binary"
      # LightGBM needs num_class for any multiclass objective — including one the
      # caller requested explicitly while y happens to contain only two classes.
      params[:num_class] = n_classes if n_classes > 2 || multiclass_objective?(params[:objective])

      train_set = Dataset.new(x, label: y)
      @booster = LightGBM.train(params, train_set, num_boost_round: @n_estimators)
      nil
    end

    # Predicts a class index for each row of +data+.
    #
    # @return [Array<Integer>] argmax of the per-class scores for multiclass
    #   models (first index wins on ties); 1 when the score exceeds 0.5, else 0,
    #   for binary models
    # @raise [RuntimeError] if neither #fit nor #load_model has been called
    def predict(data)
      y_pred = fitted_booster.predict(data)

      if y_pred.first.is_a?(Array)
        # multiple classes: one score array per row
        y_pred.map do |scores|
          scores.each_with_index.max_by { |score, _index| score }.last
        end
      else
        y_pred.map { |score| score > 0.5 ? 1 : 0 }
      end
    end

    # Returns per-class scores for each row of +data+.
    #
    # @return [Array<Array<Float>>] the booster's per-class arrays for multiclass
    #   models, or +[1 - p, p]+ pairs for binary models
    # @raise [RuntimeError] if neither #fit nor #load_model has been called
    def predict_proba(data)
      y_pred = fitted_booster.predict(data)

      if y_pred.first.is_a?(Array)
        # multiple classes: already one array per row
        y_pred
      else
        y_pred.map { |score| [1 - score, score] }
      end
    end

    # Writes the current booster to +fname+.
    #
    # @raise [RuntimeError] if neither #fit nor #load_model has been called
    def save_model(fname)
      fitted_booster.save_model(fname)
    end

    # Replaces the current booster with one loaded from +fname+, constructed
    # with the parameters given to #initialize.
    #
    # @return the newly built Booster
    def load_model(fname)
      @booster = Booster.new(params: @params, model_file: fname)
    end

    # @return whatever the booster's +feature_importance+ returns
    # @raise [RuntimeError] if neither #fit nor #load_model has been called
    def feature_importances
      fitted_booster.feature_importance
    end

    private

    # Returns the trained booster, failing with a descriptive message instead
    # of a NoMethodError on nil when the model was never fitted or loaded.
    def fitted_booster
      @booster || raise("Model not fitted; call fit or load_model first")
    end

    # True when +objective+ (String or Symbol) names a LightGBM multiclass objective.
    def multiclass_objective?(objective)
      MULTICLASS_OBJECTIVES.include?(objective.to_s)
    end
  end
end

Version data entries

1 entries across 1 versions & 1 rubygems

Version Path
lightgbm-0.1.3 lib/lightgbm/classifier.rb