lib/fasttext/classifier.rb in fasttext-0.2.1 vs lib/fasttext/classifier.rb in fasttext-0.2.2

- old
+ new

@@ -28,48 +28,51 @@ autotune_duration: 60 * 5, autotune_model_size: "" } def fit(x, y = nil, autotune_set: nil) - input = input_path(x, y) + input, _ref = input_path(x, y) @m ||= Ext::Model.new - opts = DEFAULT_OPTIONS.merge(@options).merge(input: input, model: "supervised") + a = build_args(DEFAULT_OPTIONS) + a.input = input + a.model = "supervised" if autotune_set x, y = autotune_set - opts.merge!(autotune_validation_file: input_path(x, y)) + a.autotune_validation_file, _autotune_ref = input_path(x, y) end - m.train(opts) + m.train(a) end def predict(text, k: 1, threshold: 0.0) multiple = text.is_a?(Array) text = [text] unless multiple # TODO predict multiple in C++ for performance result = text.map do |t| - m.predict(prep_text(t), k, threshold).map do |v| + m.predict(prep_text(t), k, threshold).to_h do |v| [remove_prefix(v[1]), v[0]] - end.to_h + end end multiple ? result : result.first end def test(x, y = nil, k: 1) - input = input_path(x, y) + input, _ref = input_path(x, y) res = m.test(input, k) { examples: res[0], precision: res[1], recall: res[2] } end # TODO support options def quantize - m.quantize({}) + a = Ext::Args.new + m.quantize(a) end def labels(include_freq: false) labels, freqs = m.labels labels.map! { |v| remove_prefix(v) } @@ -83,20 +86,20 @@ private def input_path(x, y) if x.is_a?(String) raise ArgumentError, "Cannot pass y with file" if y - x + [x, nil] else tempfile = Tempfile.new("fasttext") x.zip(y) do |xi, yi| parts = Array(yi).map { |label| "__label__" + label } parts << xi.gsub("\n", " ") # replace newlines in document tempfile.write(parts.join(" ")) tempfile.write("\n") end tempfile.close - tempfile.path + [tempfile.path, tempfile] end end def remove_prefix(label) label.sub(label_prefix, "")