ext/fasttext/ext.cpp in fasttext-0.1.2 vs ext/fasttext/ext.cpp in fasttext-0.1.3

- old
+ new

@@ -1,21 +1,36 @@ +// stdlib +#include <cmath> +#include <iterator> +#include <sstream> +#include <stdexcept> + +// fasttext #include <args.h> +#include <autotune.h> #include <densematrix.h> #include <fasttext.h> -#include <rice/Data_Type.hpp> -#include <rice/Constructor.hpp> -#include <rice/Array.hpp> -#include <rice/Hash.hpp> #include <real.h> #include <vector.h> -#include <cmath> -#include <iterator> -#include <sstream> -#include <stdexcept> -using namespace Rice; +// rice +#include <rice/Array.hpp> +#include <rice/Constructor.hpp> +#include <rice/Data_Type.hpp> +#include <rice/Hash.hpp> +using fasttext::FastText; + +using Rice::Array; +using Rice::Constructor; +using Rice::Hash; +using Rice::Module; +using Rice::Object; +using Rice::define_class_under; +using Rice::define_module; +using Rice::define_module_under; + template<> inline Object to_ruby<std::vector<std::pair<fasttext::real, std::string>>>(std::vector<std::pair<fasttext::real, std::string>> const & x) { Array ret; @@ -102,12 +117,22 @@ a.verbose = from_ruby<int>(value); } else if (name == "pretrained_vectors") { a.pretrainedVectors = from_ruby<std::string>(value); } else if (name == "save_output") { a.saveOutput = from_ruby<bool>(value); - // } else if (name == "seed") { - // a.seed = from_ruby<int>(value); + } else if (name == "seed") { + a.seed = from_ruby<int>(value); + } else if (name == "autotune_validation_file") { + a.autotuneValidationFile = from_ruby<std::string>(value); + } else if (name == "autotune_metric") { + a.autotuneMetric = from_ruby<std::string>(value); + } else if (name == "autotune_predictions") { + a.autotunePredictions = from_ruby<int>(value); + } else if (name == "autotune_duration") { + a.autotuneDuration = from_ruby<int>(value); + } else if (name == "autotune_model_size") { + a.autotuneModelSize = from_ruby<std::string>(value); } else { throw std::invalid_argument("Unknown argument: " + name); } } return a; @@ -117,15 +142,15 @@ void Init_ext() { Module rb_mFastText = define_module("FastText"); Module rb_mExt = define_module_under(rb_mFastText, "Ext"); - define_class_under<fasttext::FastText>(rb_mExt, "Model") - .define_constructor(Constructor<fasttext::FastText>()) + define_class_under<FastText>(rb_mExt, "Model") + .define_constructor(Constructor<FastText>()) .define_method( "words", - *[](fasttext::FastText& m) { + *[](FastText& m) { std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary(); std::vector<int64_t> freq = d->getCounts(fasttext::entry_type::word); Array vocab_list; Array vocab_freq; @@ -139,11 +164,11 @@ ret.push(vocab_freq); return ret; }) .define_method( "labels", - *[](fasttext::FastText& m) { + *[](FastText& m) { std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary(); std::vector<int64_t> freq = d->getCounts(fasttext::entry_type::label); Array vocab_list; Array vocab_freq; @@ -157,16 +182,16 @@ ret.push(vocab_freq); return ret; }) .define_method( "test", - *[](fasttext::FastText& m, const std::string filename, int32_t k) { + *[](FastText& m, const std::string filename, int32_t k) { std::ifstream ifs(filename); if (!ifs.is_open()) { throw std::invalid_argument("Test file cannot be opened!"); } - fasttext::Meter meter; + fasttext::Meter meter(false); m.test(ifs, k, 0.0, meter); ifs.close(); Array ret; ret.push(meter.nexamples()); @@ -174,36 +199,36 @@ ret.push(meter.recall()); return ret; }) .define_method( "load_model", - *[](fasttext::FastText& m, std::string s) { m.loadModel(s); }) + *[](FastText& m, std::string s) { m.loadModel(s); }) .define_method( "save_model", - *[](fasttext::FastText& m, std::string s) { m.saveModel(s); }) - .define_method("dimension", &fasttext::FastText::getDimension) - .define_method("quantized?", &fasttext::FastText::isQuant) - .define_method("word_id", &fasttext::FastText::getWordId) - .define_method("subword_id", &fasttext::FastText::getSubwordId) + *[](FastText& m, std::string s) { m.saveModel(s); }) + .define_method("dimension", &FastText::getDimension) + .define_method("quantized?", &FastText::isQuant) + .define_method("word_id", &FastText::getWordId) + .define_method("subword_id", &FastText::getSubwordId) .define_method( "predict", - *[](fasttext::FastText& m, const std::string text, int32_t k, float threshold) { + *[](FastText& m, const std::string text, int32_t k, float threshold) { std::stringstream ioss(text); std::vector<std::pair<fasttext::real, std::string>> predictions; m.predictLine(ioss, predictions, k, threshold); return predictions; }) .define_method( "nearest_neighbors", - *[](fasttext::FastText& m, const std::string& word, int32_t k) { + *[](FastText& m, const std::string& word, int32_t k) { return m.getNN(word, k); }) - .define_method("analogies", &fasttext::FastText::getAnalogies) - .define_method("ngram_vectors", &fasttext::FastText::getNgramVectors) + .define_method("analogies", &FastText::getAnalogies) + .define_method("ngram_vectors", &FastText::getNgramVectors) .define_method( "word_vector", - *[](fasttext::FastText& m, const std::string word) { + *[](FastText& m, const std::string word) { int dimension = m.getDimension(); fasttext::Vector vec = fasttext::Vector(dimension); m.getWordVector(vec, word); float* data = vec.data(); Array ret; @@ -212,11 +237,11 @@ } return ret; }) .define_method( "subwords", - *[](fasttext::FastText& m, const std::string word) { + *[](FastText& m, const std::string word) { std::vector<std::string> subwords; std::vector<int32_t> ngrams; std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary(); d->getSubwords(word, ngrams, subwords); @@ -226,11 +251,11 @@ } return ret; }) .define_method( "sentence_vector", - *[](fasttext::FastText& m, const std::string text) { + *[](FastText& m, const std::string text) { std::istringstream in(text); int dimension = m.getDimension(); fasttext::Vector vec = fasttext::Vector(dimension); m.getSentenceVector(in, vec); float* data = vec.data(); @@ -240,24 +265,30 @@ } return ret; }) .define_method( "train", - *[](fasttext::FastText& m, Hash h) { - m.train(buildArgs(h)); + *[](FastText& m, Hash h) { + auto a = buildArgs(h); + if (a.hasAutotune()) { + fasttext::Autotune autotune(std::shared_ptr<fasttext::FastText>(&m, [](fasttext::FastText*) {})); + autotune.train(a); + } else { + m.train(a); + } }) .define_method( "quantize", - *[](fasttext::FastText& m, Hash h) { + *[](FastText& m, Hash h) { m.quantize(buildArgs(h)); }) .define_method( "supervised?", - *[](fasttext::FastText& m) { + *[](FastText& m) { return m.getArgs().model == fasttext::model_name::sup; }) .define_method( "label_prefix", - *[](fasttext::FastText& m) { + *[](FastText& m) { return m.getArgs().label; }); }