lib/fasttext/classifier.rb in fasttext-0.2.1 vs lib/fasttext/classifier.rb in fasttext-0.2.2
- old
+ new
@@ -28,48 +28,51 @@
autotune_duration: 60 * 5,
autotune_model_size: ""
}
# Trains the model in "supervised" mode.
#
# x, y         - handed to input_path: either a file path String (y must then
#                be omitted) or documents together with their labels.
# autotune_set - optional pair, destructured as x, y and resolved through
#                input_path to supply the autotune validation file.
#
# Returns whatever m.train returns (m is defined outside this excerpt).
def fit(x, y = nil, autotune_set: nil)
- input = input_path(x, y)
# 0.2.2: input_path returns a [path, tempfile-or-nil] pair. _ref is not used
# afterwards; presumably it keeps the Tempfile referenced so the file is not
# removed before training reads it -- TODO confirm.
+ input, _ref = input_path(x, y)
# Create the Ext::Model once and reuse it on later calls.
@m ||= Ext::Model.new
- opts = DEFAULT_OPTIONS.merge(@options).merge(input: input, model: "supervised")
# 0.2.2: options are set on the object returned by build_args (defined outside
# this excerpt) instead of being merged into a Hash.
+ a = build_args(DEFAULT_OPTIONS)
+ a.input = input
+ a.model = "supervised"
if autotune_set
# Note: x and y are reassigned here to the validation data.
x, y = autotune_set
- opts.merge!(autotune_validation_file: input_path(x, y))
# Only the path is assigned; the second element is held in _autotune_ref.
+ a.autotune_validation_file, _autotune_ref = input_path(x, y)
end
- m.train(opts)
+ m.train(a)
end
# Predicts labels for a single text or for an Array of texts.
#
# text      - one document, or an Array of documents.
# k         - forwarded unchanged to m.predict (default 1).
# threshold - forwarded unchanged to m.predict (default 0.0).
#
# Returns one Hash per document, keyed by the label with its prefix removed
# and valued by v[0] (presumably the score -- confirm against Ext). An Array
# of such Hashes is returned when text was an Array, otherwise a single Hash.
def predict(text, k: 1, threshold: 0.0)
# Remember the caller's input shape so the result can mirror it.
multiple = text.is_a?(Array)
text = [text] unless multiple
# TODO predict multiple in C++ for performance
result =
text.map do |t|
- m.predict(prep_text(t), k, threshold).map do |v|
# 0.2.2: the pairs are built by passing the block straight to to_h instead of
# map followed by to_h (block form of to_h needs Ruby 2.6+).
+ m.predict(prep_text(t), k, threshold).to_h do |v|
# Each v is indexed as [value, label]; the pair is flipped to [label, value].
[remove_prefix(v[1]), v[0]]
- end.to_h
+ end
end
multiple ? result : result.first
end
# Evaluates the model on the given data.
#
# x, y - handed to input_path: either a file path String (y must then be
#        omitted) or documents together with their labels.
# k    - forwarded unchanged to m.test (default 1).
#
# Returns a Hash with :examples, :precision and :recall, taken from the first
# three elements of the m.test result, in that order.
def test(x, y = nil, k: 1)
- input = input_path(x, y)
# 0.2.2: input_path returns a [path, tempfile-or-nil] pair; _ref is not used
# afterwards (presumably held to keep the Tempfile alive -- TODO confirm).
+ input, _ref = input_path(x, y)
res = m.test(input, k)
{
examples: res[0],
precision: res[1],
recall: res[2]
}
end
# TODO support options
# Quantizes the model by calling m.quantize; takes no arguments and passes no
# caller-supplied options. Returns whatever m.quantize returns.
def quantize
- m.quantize({})
# 0.2.2: a freshly constructed Ext::Args replaces the empty Hash.
+ a = Ext::Args.new
+ m.quantize(a)
end
def labels(include_freq: false)
labels, freqs = m.labels
labels.map! { |v| remove_prefix(v) }
@@ -83,20 +86,20 @@
private
# Resolves input data to a file path.
#
# x - a String, treated as the path of an existing file, or a collection of
#     document strings.
# y - labels matching x element by element; must be nil when x is a String.
#     Each entry may be a single label or a list of labels.
#
# Returns the path String in 0.2.1; in 0.2.2 returns [path, tempfile], where
# tempfile is nil for String input.
# Raises ArgumentError when x is a String and y is also given.
def input_path(x, y)
if x.is_a?(String)
raise ArgumentError, "Cannot pass y with file" if y
- x
+ [x, nil]
else
tempfile = Tempfile.new("fasttext")
# One line per document: the prefixed labels first, then the text.
x.zip(y) do |xi, yi|
# Array() accepts a single label as well as a list of labels.
parts = Array(yi).map { |label| "__label__" + label }
parts << xi.gsub("\n", " ") # replace newlines in document
tempfile.write(parts.join(" "))
tempfile.write("\n")
end
# Close the file before handing out its path.
tempfile.close
- tempfile.path
# 0.2.2: the Tempfile object is returned alongside its path so the caller can
# hold on to it.
+ [tempfile.path, tempfile]
end
end
def remove_prefix(label)
label.sub(label_prefix, "")