Sha256: 267e0b4ad90460ceaaf2d51e0fc7b5f98fca2c512660f915876239e15842b959

Contents?: true

Size: 1.38 KB

Versions: 4

Compression:

Stored size: 1.38 KB

Contents

require_relative "base_adapter"

module EasyML
  module Core
    class Tuner
      module Adapters
        # Tuner adapter for XGBoost models. It supplies the default
        # hyperparameter search space and wires Weights & Biases metric logging
        # into the model's training callbacks.
        class XGBoostAdapter < BaseAdapter
          include GlueGun::DSL

          # Default hyperparameter search space for XGBoost tuning.
          #
          # Each key is an XGBoost hyperparameter. Its value carries the `min`
          # and `max` bounds for the search. `log: true` presumably requests
          # log-scale sampling; it is interpreted by the tuner/BaseAdapter, not
          # here (confirm).
          #
          # @return [Hash{Symbol => Hash}] a freshly built hash on every call,
          #   so callers may mutate it without affecting later calls
          def defaults
            {
              learning_rate: { min: 0.001, max: 0.1, log: true },
              n_estimators: { min: 100, max: 1_000 },
              max_depth: { min: 2, max: 20 }
            }
          end

          # Adapts the model's Wandb::XGBoostCallback, if one is configured,
          # for a tuning run. It does two things:
          # * appends the tuning start timestamp (tune_started_at) to the
          #   callback's project name;
          # * installs a custom logger that, on each invocation, predicts with
          #   the current booster on x_true/y_true, evaluates the predictions
          #   via model.evaluate, and sends the metrics to Wandb.log.
          #
          # Does nothing when the model has no callbacks, or when none of them
          # is a Wandb::XGBoostCallback.
          def configure_callbacks
            model.customize_callbacks do |callbacks|
              # Use `next`, not `return`: this is a plain (non-lambda) block, so
              # `return` would try to exit configure_callbacks itself. That would
              # skip any work customize_callbacks performs after yielding, or
              # raise LocalJumpError if the block is stored and run later.
              next unless callbacks.present?

              wandb_callback = callbacks.find { |cb| cb.is_a?(Wandb::XGBoostCallback) }
              next unless wandb_callback.present?

              timestamp = tune_started_at.strftime("%Y_%m_%d_%H_%M_%S")
              # NOTE(review): this appends to the current name, so calling
              # configure_callbacks more than once on the same model stacks
              # several timestamps onto the project name.
              wandb_callback.project_name = "#{wandb_callback.project_name}_#{timestamp}"
              wandb_callback.custom_loggers = [
                lambda do |booster, _epoch, _hist|
                  # preprocess is not part of the model's public API, hence `send`.
                  dtrain = model.send(:preprocess, x_true, y_true)
                  y_pred = booster.predict(dtrain)
                  metrics = model.evaluate(y_pred: y_pred, y_true: y_true, x_true: x_true)
                  Wandb.log(metrics)
                end
              ]
            end
          end
        end
      end
    end
  end
end

Version data entries

4 entries across 4 versions & 1 rubygem

Version Path
easy_ml-0.1.4 lib/easy_ml/core/tuner/adapters/xgboost_adapter.rb
easy_ml-0.1.3 lib/easy_ml/core/tuner/adapters/xgboost_adapter.rb
easy_ml-0.1.2 lib/easy_ml/core/tuner/adapters/xgboost_adapter.rb
easy_ml-0.1.1 lib/easy_ml/core/tuner/adapters/xgboost_adapter.rb