lib/xlearn/model.rb in xlearn-0.1.0 vs lib/xlearn/model.rb in xlearn-0.1.1

- old
+ new

@@ -18,31 +18,34 @@ set_params(options) end def fit(x, y = nil, eval_set: nil) - if x.is_a?(String) - check_call FFI.XLearnSetTrain(@handle, x) - check_call FFI.XLearnSetBool(@handle, "from_file", true) - else - train_set = DMatrix.new(x, label: y) - check_call FFI.XLearnSetDMatrix(@handle, "train", train_set) - check_call FFI.XLearnSetBool(@handle, "from_file", false) - end + @model_path = nil + partial_fit(x, y, eval_set: eval_set) + end + def partial_fit(x, y = nil, eval_set: nil) + check_call FFI.XLearnSetPreModel(@handle, @model_path || "") + + set_train_set(x, y) + if eval_set if eval_set.is_a?(String) check_call FFI.XLearnSetValidate(@handle, eval_set) else valid_set = DMatrix.new(x, label: y) check_call FFI.XLearnSetDMatrix(@handle, "validate", valid_set) end end - # TODO unlink in finalizer - @model_file = Tempfile.new("xlearn") + @txt_file ||= create_tempfile + check_call FFI.XLearnSetTXTModel(@handle, @txt_file.path) + + @model_file ||= create_tempfile check_call FFI.XLearnFit(@handle, @model_file.path) + @model_path = @model_file.path end def predict(x, out_path: nil) if x.is_a?(String) check_call FFI.XLearnSetTest(@handle, x) @@ -61,28 +64,76 @@ check_call FFI.XLearnPredictForMat(@handle, @model_file.path, length, out_arr) out_arr.read_pointer.read_array_of_float(length.read_uint64) end end + def cv(x, y = nil, folds: nil) + set_params(fold: folds) if folds + set_train_set(x, y) + check_call FFI.XLearnCV(@handle) + end + def save_model(path) raise Error, "Not trained" unless @model_file FileUtils.cp(@model_file.path, path) end + def save_txt(path) + raise Error, "Not trained" unless @txt_file + FileUtils.cp(@txt_file.path, path) + end + def load_model(path) - @model_file ||= Tempfile.new("xlearn") + @model_file ||= create_tempfile # TODO ensure tempfile is still cleaned up FileUtils.cp(path, @model_file.path) end + def bias_term + read_txt do |line| + return line.split(":").last.to_f if line.start_with?("bias:") + end + end + + def linear_term + term = [] + read_txt do |line| + if line.start_with?("i_") + term << line.split(":").last.to_f + elsif line.start_with?("v_") + break + end + end + term + end + def self.finalize(pointer) # must use proc instead of stabby lambda proc { FFI.XLearnHandleFree(pointer) } end + def self.finalize_file(file) + # must use proc instead of stabby lambda + proc do + file.close + file.unlink + end + end + private + def set_train_set(x, y) + if x.is_a?(String) + check_call FFI.XLearnSetTrain(@handle, x) + check_call FFI.XLearnSetBool(@handle, "from_file", true) + else + train_set = DMatrix.new(x, label: y) + check_call FFI.XLearnSetDMatrix(@handle, "train", train_set) + check_call FFI.XLearnSetBool(@handle, "from_file", false) + end + end + def set_params(params) params.each do |k, v| k = k.to_s ret = case k @@ -97,8 +148,22 @@ else raise ArgumentError, "Invalid parameter: #{k}" end check_call ret end + end + + def read_txt + if @txt_file + File.foreach(@txt_file.path) do |line| + yield line + end + end + end + + def create_tempfile + file = Tempfile.new("xlearn") + ObjectSpace.define_finalizer(self, self.class.finalize_file(file)) + file end end end