lib/fasttext/classifier.rb in fasttext-0.1.2 vs lib/fasttext/classifier.rb in fasttext-0.1.3
- old
+ new
@@ -19,16 +19,26 @@
t: 0.0001,
label_prefix: "__label__",
verbose: 2,
pretrained_vectors: "",
save_output: false,
- # seed: 0
+ seed: 0,
+ autotune_validation_file: "",
+ autotune_metric: "f1",
+ autotune_predictions: 1,
+ autotune_duration: 60 * 5,
+ autotune_model_size: ""
}
- def fit(x, y = nil)
+ def fit(x, y = nil, autotune_set: nil)
input = input_path(x, y)
@m ||= Ext::Model.new
- m.train(DEFAULT_OPTIONS.merge(@options).merge(input: input, model: "supervised"))
+ opts = DEFAULT_OPTIONS.merge(@options).merge(input: input, model: "supervised")
+ if autotune_set
+ x, y = autotune_set
+ opts.merge!(autotune_validation_file: input_path(x, y))
+ end
+ m.train(opts)
end
def predict(text, k: 1, threshold: 0.0)
multiple = text.is_a?(Array)
text = [text] unless multiple