Sha256: 3e90249869d33c9f08402735683eea9bc123fa0321f49b05a1390c4123644f63
Contents?: true
Size: 1.42 KB
Versions: 1
Compression:
Stored size: 1.42 KB
Contents
module Xgb
  # Ruby wrapper around a native XGBoost booster handle, driven through the
  # project's FFI bindings (Xgb::FFI) and the check_result helper from Utils.
  class Booster
    include Utils

    # Creates a native booster, optionally loading a saved model and applying
    # parameters.
    #
    # @param params [Hash, Enumerable, nil] booster parameters; every key/value
    #   is stringified and sent through XGBoosterSetParam
    # @param model_file [String, nil] path passed to XGBoosterLoadModel
    def initialize(params: nil, model_file: nil)
      # Out-parameter that receives the native booster handle.
      @handle = ::FFI::MemoryPointer.new(:pointer)
      check_result FFI.XGBoosterCreate(nil, 0, @handle)
      # NOTE(review): the native booster is never released with XGBoosterFree;
      # consider a finalizer once the FFI bindings expose that function.
      check_result FFI.XGBoosterLoadModel(handle_pointer, model_file) if model_file
      # Without this guard set_param(nil) would send an empty-named
      # parameter ("" => "") to the native library.
      set_param(params) if params
      @num_class = num_class_from(params)
    end

    # Runs one boosting iteration on the given training matrix.
    #
    # @param dtrain [#handle_pointer] training DMatrix
    # @param iteration [Integer] iteration index passed to XGBoosterUpdateOneIter
    def update(dtrain, iteration)
      check_result FFI.XGBoosterUpdateOneIter(handle_pointer, iteration, dtrain.handle_pointer)
    end

    # Sets booster parameters. When +params+ is Enumerable (e.g. a Hash) each
    # yielded key/value pair is set; otherwise +params+ is the parameter name
    # and +value+ its value. Names and values are converted with #to_s.
    #
    # @param params [Hash, Enumerable, #to_s] parameters or a single name
    # @param value [Object, nil] value used when +params+ is a single name
    def set_param(params, value = nil)
      if params.is_a?(Enumerable)
        params.each do |k, v|
          check_result FFI.XGBoosterSetParam(handle_pointer, k.to_s, v.to_s)
        end
      else
        check_result FFI.XGBoosterSetParam(handle_pointer, params.to_s, value.to_s)
      end
    end

    # Scores +data+ with the booster.
    #
    # @param data [#handle_pointer] DMatrix to predict on
    # @param ntree_limit [Integer, nil] tree limit forwarded to XGBoosterPredict;
    #   nil is sent as 0 (per the XGBoost C API docs, 0 means "use all trees")
    # @return [Array<Float>, Array<Array<Float>>] flat predictions, or rows of
    #   num_class values each when num_class > 1
    def predict(data, ntree_limit: nil)
      ntree_limit ||= 0
      # The C API writes the length as bst_ulong (uint64_t); a :long buffer is
      # only 4 bytes on LLP64 platforms such as Windows and would be overrun.
      out_len = ::FFI::MemoryPointer.new(:uint64)
      out_result = ::FFI::MemoryPointer.new(:pointer)
      # The literal 0 is XGBoosterPredict's option_mask argument.
      check_result FFI.XGBoosterPredict(handle_pointer, data.handle_pointer, 0, ntree_limit, out_len, out_result)
      out = out_result.read_pointer.read_array_of_float(out_len.read_uint64)
      out = out.each_slice(@num_class).to_a if @num_class > 1
      out
    end

    # Writes the booster to +fname+ via XGBoosterSaveModel.
    #
    # @param fname [String] destination path
    def save_model(fname)
      check_result FFI.XGBoosterSaveModel(handle_pointer, fname)
    end

    private

    # Dereferences the stored out-parameter to obtain the native handle.
    def handle_pointer
      @handle.read_pointer
    end

    # Reads num_class from +params+ (symbol or string key) and coerces it to an
    # Integer; returns 1 when params is nil or the key is absent/falsy.
    def num_class_from(params)
      options = params.respond_to?(:to_h) ? params.to_h : {}
      value = options.fetch(:num_class) { options["num_class"] }
      value ? Integer(value) : 1
    end
  end
end
Version data entries
1 entry across 1 version & 1 rubygem
Version | Path |
---|---|
xgb-0.1.0 | lib/xgb/booster.rb |