lib/fasttext/classifier.rb in fasttext-0.1.0 vs lib/fasttext/classifier.rb in fasttext-0.1.1

- old
+ new

@@ -28,14 +28,22 @@ input = input_path(x, y) @m ||= Ext::Model.new m.train(DEFAULT_OPTIONS.merge(@options).merge(input: input, model: "supervised")) end - # TODO support array of text def predict(text, k: 1, threshold: 0.0) - m.predict(prep_text(text), k, threshold).map do |v| - [remove_prefix(v[1]), v[0]] - end.to_h + multiple = text.is_a?(Array) + text = [text] unless multiple + + # TODO predict multiple in C++ for performance + result = + text.map do |t| + m.predict(prep_text(t), k, threshold).map do |v| + [remove_prefix(v[1]), v[0]] + end.to_h + end + + multiple ? result : result.first end def test(x, y = nil, k: 1) input = input_path(x, y) res = m.test(input, k)