lib/lightgbm/booster.rb in lightgbm-0.1.0 vs lib/lightgbm/booster.rb in lightgbm-0.1.1
- old
+ new
@@ -1,20 +1,32 @@
module LightGBM
class Booster
- def initialize(model_file:)
+ def initialize(params: nil, train_set: nil, model_file: nil, model_str: nil)
@handle = ::FFI::MemoryPointer.new(:pointer)
- if model_file
+ if model_str
out_num_iterations = ::FFI::MemoryPointer.new(:int)
+ check_result FFI.LGBM_BoosterLoadModelFromString(model_str, out_num_iterations, @handle)
+ elsif model_file
+ out_num_iterations = ::FFI::MemoryPointer.new(:int)
check_result FFI.LGBM_BoosterCreateFromModelfile(model_file, out_num_iterations, @handle)
+ else
+ check_result FFI.LGBM_BoosterCreate(train_set.handle_pointer, params_str(params), @handle)
end
- ObjectSpace.define_finalizer(self, self.class.finalize(handle_pointer))
+ # causes "Stack consistency error"
+ # ObjectSpace.define_finalizer(self, self.class.finalize(handle_pointer))
end
def self.finalize(pointer)
-> { FFI.LGBM_BoosterFree(pointer) }
end
+ # TODO handle name
+ def add_valid(data, name)
+ check_result FFI.LGBM_BoosterAddValidData(handle_pointer, data.handle_pointer)
+ self # consistent with Python API
+ end
+
def predict(input)
raise TypeError unless input.is_a?(Array)
singular = input.first.is_a?(Array)
input = [input] unless singular
@@ -29,18 +41,88 @@
out = out_result.read_array_of_double(out_len.read_int64)
singular ? out : out.first
end
- private
+ def save_model(filename)
+ check_result FFI.LGBM_BoosterSaveModel(handle_pointer, 0, 0, filename)
+ self # consistent with Python API
+ end
- def check_result(err)
- if err != 0
- raise FFI.LGBM_GetLastError
+ def update
+ finished = ::FFI::MemoryPointer.new(:int)
+ check_result FFI.LGBM_BoosterUpdateOneIter(handle_pointer, finished)
+ finished.read_int == 1
+ end
+
+ def feature_importance(iteration: nil, importance_type: "split")
+ iteration ||= best_iteration
+ importance_type =
+ case importance_type
+ when "split"
+ 0
+ when "gain"
+ 1
+ else
+ -1
+ end
+
+ num_features = self.num_features
+ out_result = ::FFI::MemoryPointer.new(:double, num_features)
+ check_result FFI.LGBM_BoosterFeatureImportance(handle_pointer, iteration, importance_type, out_result)
+ out_result.read_array_of_double(num_features)
+ end
+
+ def num_features
+ out = ::FFI::MemoryPointer.new(:int)
+ check_result FFI.LGBM_BoosterGetNumFeature(handle_pointer, out)
+ out.read_int
+ end
+
+ def current_iteration
+ out = ::FFI::MemoryPointer.new(:int)
+ check_result FFI::LGBM_BoosterGetCurrentIteration(handle_pointer, out)
+ out.read_int
+ end
+
+ # TODO fix
+ def best_iteration
+ -1
+ end
+
+ def model_to_string(num_iteration: nil, start_iteration: 0)
+ num_iteration ||= best_iteration
+ buffer_len = 1 << 20
+ out_len = ::FFI::MemoryPointer.new(:int64)
+ out_str = ::FFI::MemoryPointer.new(:string, buffer_len)
+ check_result FFI.LGBM_BoosterSaveModelToString(handle_pointer, start_iteration, num_iteration, buffer_len, out_len, out_str)
+ actual_len = out_len.read_int64
+ if actual_len > buffer_len
+ out_str = ::FFI::MemoryPointer.new(:string, actual_len)
+ check_result FFI.LGBM_BoosterSaveModelToString(handle_pointer, start_iteration, num_iteration, actual_len, out_len, out_str)
end
+ out_str.read_string
end
+ def to_json(num_iteration: nil, start_iteration: 0)
+ num_iteration ||= best_iteration
+ buffer_len = 1 << 20
+ out_len = ::FFI::MemoryPointer.new(:int64)
+ out_str = ::FFI::MemoryPointer.new(:string, buffer_len)
+ check_result FFI.LGBM_BoosterDumpModel(handle_pointer, start_iteration, num_iteration, buffer_len, out_len, out_str)
+ actual_len = out_len.read_int64
+ if actual_len > buffer_len
+ out_str = ::FFI::MemoryPointer.new(:string, actual_len)
+ check_result FFI.LGBM_BoosterDumpModel(handle_pointer, start_iteration, num_iteration, actual_len, out_len, out_str)
+ end
+ out_str.read_string
+ end
+
+ private
+
def handle_pointer
@handle.read_pointer
end
+
+ include Utils
end
end