lib/xlearn/model.rb in xlearn-0.1.0 vs lib/xlearn/model.rb in xlearn-0.1.1
- old
+ new
@@ -18,31 +18,34 @@
set_params(options)
end
# Trains a model from scratch.
# In the new version this clears @model_path and delegates to partial_fit, so
# partial_fit passes "" to XLearnSetPreModel (presumably "no pre-trained
# model" — confirm against the xLearn C API).
# The removed lines are the old inline training-set setup; the same logic now
# lives in the private set_train_set helper.
def fit(x, y = nil, eval_set: nil)
- if x.is_a?(String)
- check_call FFI.XLearnSetTrain(@handle, x)
- check_call FFI.XLearnSetBool(@handle, "from_file", true)
- else
- train_set = DMatrix.new(x, label: y)
- check_call FFI.XLearnSetDMatrix(@handle, "train", train_set)
- check_call FFI.XLearnSetBool(@handle, "from_file", false)
- end
+ @model_path = nil
+ partial_fit(x, y, eval_set: eval_set)
+ end
# Trains using the previously recorded model as the pre-model: @model_path
# (assigned at the end of this method) is handed to XLearnSetPreModel, or ""
# when no path has been recorded yet.
# x may be a String (training file path) or in-memory data with labels in y;
# see set_train_set. eval_set may be a String path or in-memory data.
+ def partial_fit(x, y = nil, eval_set: nil)
+ check_call FFI.XLearnSetPreModel(@handle, @model_path || "")
+
+ set_train_set(x, y)
+
if eval_set
if eval_set.is_a?(String)
check_call FFI.XLearnSetValidate(@handle, eval_set)
else
# NOTE(review): the validation matrix is built from the training x / y, not
# from eval_set — looks unintended; confirm the expected shape of eval_set.
valid_set = DMatrix.new(x, label: y)
check_call FFI.XLearnSetDMatrix(@handle, "validate", valid_set)
end
end
- # TODO unlink in finalizer
- @model_file = Tempfile.new("xlearn")
# ||= keeps one text-model tempfile and one model tempfile per instance, so
# repeated calls reuse the same paths; create_tempfile registers the cleanup.
+ @txt_file ||= create_tempfile
+ check_call FFI.XLearnSetTXTModel(@handle, @txt_file.path)
+
+ @model_file ||= create_tempfile
check_call FFI.XLearnFit(@handle, @model_file.path)
# Record the path so the next partial_fit passes it to XLearnSetPreModel.
+ @model_path = @model_file.path
end
def predict(x, out_path: nil)
if x.is_a?(String)
check_call FFI.XLearnSetTest(@handle, x)
@@ -61,28 +64,76 @@
check_call FFI.XLearnPredictForMat(@handle, @model_file.path, length, out_arr)
out_arr.read_pointer.read_array_of_float(length.read_uint64)
end
end
+ def cv(x, y = nil, folds: nil)
+ set_params(fold: folds) if folds
+ set_train_set(x, y)
+ check_call FFI.XLearnCV(@handle)
+ end
+
# Copies the model file held by this instance to +path+.
# Raises Error ("Not trained") when no model file exists yet.
def save_model(path)
source = @model_file
raise Error, "Not trained" unless source
FileUtils.cp(source.path, path)
end
+ def save_txt(path)
+ raise Error, "Not trained" unless @txt_file
+ FileUtils.cp(@txt_file.path, path)
+ end
+
# Copies the model stored at +path+ into this instance's model tempfile,
# creating that tempfile on first use.
# NOTE(review): @model_path is left untouched here, so partial_fit only passes
# the loaded model to XLearnSetPreModel if an earlier fit already recorded the
# path — confirm this is intended.
def load_model(path)
- @model_file ||= Tempfile.new("xlearn")
+ @model_file ||= create_tempfile
# TODO ensure tempfile is still cleaned up
FileUtils.cp(path, @model_file.path)
end
+ def bias_term
+ read_txt do |line|
+ return line.split(":").last.to_f if line.start_with?("bias:")
+ end
+ end
+
+ def linear_term
+ term = []
+ read_txt do |line|
+ if line.start_with?("i_")
+ term << line.split(":").last.to_f
+ elsif line.start_with?("v_")
+ break
+ end
+ end
+ term
+ end
+
# Builds the callback that frees the native handle +pointer+.
# Built in a class-level method, so the closure captures only +pointer+.
def self.finalize(pointer)
# keep this a plain proc, not a stabby lambda: presumably it is registered
# as a finalizer and called with an argument a zero-arity lambda would reject
release = proc do
FFI.XLearnHandleFree(pointer)
end
release
end
+ def self.finalize_file(file)
+ # must use proc instead of stabby lambda
+ proc do
+ file.close
+ file.unlink
+ end
+ end
+
private
+ def set_train_set(x, y)
+ if x.is_a?(String)
+ check_call FFI.XLearnSetTrain(@handle, x)
+ check_call FFI.XLearnSetBool(@handle, "from_file", true)
+ else
+ train_set = DMatrix.new(x, label: y)
+ check_call FFI.XLearnSetDMatrix(@handle, "train", train_set)
+ check_call FFI.XLearnSetBool(@handle, "from_file", false)
+ end
+ end
+
def set_params(params)
params.each do |k, v|
k = k.to_s
ret =
case k
@@ -97,8 +148,22 @@
else
raise ArgumentError, "Invalid parameter: #{k}"
end
check_call ret
end
+ end
+
+ def read_txt
+ if @txt_file
+ File.foreach(@txt_file.path) do |line|
+ yield line
+ end
+ end
+ end
+
+ def create_tempfile
+ file = Tempfile.new("xlearn")
+ ObjectSpace.define_finalizer(self, self.class.finalize_file(file))
+ file
end
end
end