Sha256: 7f8ba2f838095778da5d75a239dd1adb35ef3a302f362585ff8f3cf67636b998

Contents?: true

Size: 738 Bytes

Versions: 16

Compression:

Stored size: 738 Bytes

Contents

module XGBoost
  # Thin wrapper around a Booster that stores construction options and
  # exposes prediction, persistence, and normalized feature importances.
  class Model
    # The underlying Booster. It is nil until one is assigned, either by
    # load_model here or presumably by a subclass's training code (not shown).
    attr_reader :booster

    # @param n_estimators [Integer] stored for later use; not read in this
    #   class (presumably consumed by subclasses when training — confirm)
    # @param importance_type [String] passed to Booster#score when computing
    #   feature_importances (defaults to "gain")
    # @param options [Hash] booster parameters, forwarded as `params:` when a
    #   Booster is built in load_model
    def initialize(n_estimators: 100, importance_type: "gain", **options)
      @params = options
      @n_estimators = n_estimators
      @importance_type = importance_type
    end

    # Runs the booster on `data`, which is first wrapped in a DMatrix.
    # @param data input accepted by DMatrix.new
    # @return whatever Booster#predict returns for that DMatrix
    def predict(data)
      dmat = DMatrix.new(data)
      @booster.predict(dmat)
    end

    # Persists the current booster to `fname`.
    def save_model(fname)
      @booster.save_model(fname)
    end

    # Replaces the current booster with one loaded from `fname`, built with
    # the options given to the constructor.
    def load_model(fname)
      @booster = Booster.new(params: @params, model_file: fname)
    end

    # Importance of each feature, in Booster#feature_names order, normalized
    # so that the entries sum to 1.0. Features absent from the score hash
    # count as 0.0.
    #
    # If every feature scores zero (e.g. the booster made no splits), the
    # raw all-zero scores are returned instead of dividing by zero. Dividing
    # would otherwise yield NaN for every entry.
    #
    # @return [Array<Float>]
    def feature_importances
      score = @booster.score(importance_type: @importance_type)
      scores = @booster.feature_names.map { |k| score[k] || 0.0 }
      total = scores.sum.to_f
      return scores.map(&:to_f) if total.zero?

      scores.map { |s| s / total }
    end
  end
end

Version data entries

16 entries across 16 versions & 2 rubygems

Version Path
honzasterba_xgb-0.9.0 lib/xgboost/model.rb
xgb-0.9.0 lib/xgboost/model.rb
xgb-0.8.0 lib/xgboost/model.rb
xgb-0.7.3 lib/xgboost/model.rb
xgb-0.7.2 lib/xgboost/model.rb
xgb-0.7.1 lib/xgboost/model.rb
xgb-0.7.0 lib/xgboost/model.rb
xgb-0.6.0 lib/xgboost/model.rb
xgb-0.5.3 lib/xgboost/model.rb
xgb-0.5.2 lib/xgboost/model.rb
xgb-0.5.1 lib/xgboost/model.rb
xgb-0.5.0 lib/xgboost/model.rb
xgb-0.4.1 lib/xgboost/model.rb
xgb-0.4.0 lib/xgboost/model.rb
xgb-0.3.1 lib/xgboost/model.rb
xgb-0.3.0 lib/xgboost/model.rb