Sha256: 3e90249869d33c9f08402735683eea9bc123fa0321f49b05a1390c4123644f63

Contents?: true

Size: 1.42 KB

Versions: 1

Compression:

Stored size: 1.42 KB

Contents

module Xgb
  # Ruby wrapper around an XGBoost booster handle, driven through the
  # XGBoost C API bindings exposed by Xgb::FFI.
  #
  # NOTE(review): the native handle created in #initialize is never released
  # (there is no XGBoosterFree call here); consider a finalizer once that
  # binding is confirmed to exist in Xgb::FFI.
  class Booster
    # Presumably provides #check_result, which turns a non-zero C API return
    # code into a Ruby exception — TODO confirm against Utils.
    include Utils

    # Creates a new booster, optionally restoring it from a saved model file.
    #
    # @param params [Hash, Enumerable, nil] booster parameters as key/value pairs
    # @param model_file [String, nil] path of a model file to load
    def initialize(params: nil, model_file: nil)
      # Defaults consulted by #predict; #set_param overwrites them when the
      # corresponding parameters are supplied (in any key form, at any time).
      # NOTE(review): a model loaded from file does not update these.
      @num_class = 1
      @objective = nil

      @handle = ::FFI::MemoryPointer.new(:pointer)
      check_result FFI.XGBoosterCreate(nil, 0, @handle)
      check_result FFI.XGBoosterLoadModel(handle_pointer, model_file) if model_file

      set_param(params)
    end

    # Runs one boosting iteration on the training matrix.
    #
    # @param dtrain [Object] training matrix (assumed to expose #handle_pointer)
    # @param iteration [Integer] index of the current iteration
    def update(dtrain, iteration)
      check_result FFI.XGBoosterUpdateOneIter(handle_pointer, iteration, dtrain.handle_pointer)
    end

    # Sets one or more booster parameters.
    #
    # Accepts either a collection of key/value pairs (e.g. a Hash or an array
    # of pairs) or a single key together with +value+. Keys and values are
    # passed to XGBoost as strings. A nil +params+ is a no-op, rather than
    # sending an empty-named parameter to XGBoost.
    #
    # @param params [Hash, Enumerable, String, Symbol, nil]
    # @param value [Object] value used when +params+ is a single key
    def set_param(params, value = nil)
      return if params.nil?

      if params.is_a?(Enumerable)
        params.each { |k, v| set_single_param(k, v) }
      else
        set_single_param(params, value)
      end
    end

    # Predicts on +data+.
    #
    # When num_class > 1, the flat output is grouped into one array of
    # num_class values per row. The exception is the multi:softmax objective,
    # for which XGBoost returns a single class label per row.
    #
    # @param data [Object] matrix to predict on (assumed to expose #handle_pointer)
    # @param ntree_limit [Integer, nil] tree limit; nil is sent as 0, which
    #   XGBoost documents as "use all trees"
    # @return [Array<Float>, Array<Array<Float>>]
    def predict(data, ntree_limit: nil)
      # bst_ulong is a 64-bit unsigned integer in the C API; :long is only
      # 32 bits on some platforms (e.g. 64-bit Windows).
      out_len = ::FFI::MemoryPointer.new(:uint64)
      out_result = ::FFI::MemoryPointer.new(:pointer)
      # Third argument is the option mask; 0 requests plain predictions.
      check_result FFI.XGBoosterPredict(handle_pointer, data.handle_pointer, 0, ntree_limit || 0, out_len, out_result)
      out = out_result.read_pointer.read_array_of_float(out_len.read_uint64)
      grouped_output? ? out.each_slice(@num_class).to_a : out
    end

    # Writes the current model to +fname+.
    #
    # @param fname [String] destination path
    def save_model(fname)
      check_result FFI.XGBoosterSaveModel(handle_pointer, fname)
    end

    private

    # The raw booster handle stored in @handle by XGBoosterCreate.
    def handle_pointer
      @handle.read_pointer
    end

    # Sends one parameter to XGBoost, then records the ones #predict depends on.
    # Recording happens only after check_result, so a rejected value is not kept.
    def set_single_param(key, value)
      name = key.to_s
      result = check_result FFI.XGBoosterSetParam(handle_pointer, name, value.to_s)
      case name
      when "num_class" then @num_class = value.to_s.to_i
      when "objective" then @objective = value.to_s
      end
      result
    end

    # True when prediction output holds num_class values per row.
    def grouped_output?
      @num_class > 1 && @objective != "multi:softmax"
    end
  end
end

Version data entries

1 entry across 1 version & 1 rubygem

Version Path
xgb-0.1.0 lib/xgb/booster.rb