lib/fasttext/classifier.rb in fasttext-0.1.0 vs lib/fasttext/classifier.rb in fasttext-0.1.1
- old
+ new
@@ -28,14 +28,22 @@
input = input_path(x, y)
@m ||= Ext::Model.new
m.train(DEFAULT_OPTIONS.merge(@options).merge(input: input, model: "supervised"))
end
- # TODO support array of text
def predict(text, k: 1, threshold: 0.0)
- m.predict(prep_text(text), k, threshold).map do |v|
- [remove_prefix(v[1]), v[0]]
- end.to_h
+ multiple = text.is_a?(Array)
+ text = [text] unless multiple
+
+ # TODO predict multiple in C++ for performance
+ result =
+ text.map do |t|
+ m.predict(prep_text(t), k, threshold).map do |v|
+ [remove_prefix(v[1]), v[0]]
+ end.to_h
+ end
+
+ multiple ? result : result.first
end
def test(x, y = nil, k: 1)
input = input_path(x, y)
res = m.test(input, k)