ext/fasttext/ext.cpp in fasttext-0.1.3 vs ext/fasttext/ext.cpp in fasttext-0.2.0

- old
+ new

@@ -11,14 +11,12 @@ #include <fasttext.h> #include <real.h> #include <vector.h> // rice -#include <rice/Array.hpp> -#include <rice/Constructor.hpp> -#include <rice/Data_Type.hpp> -#include <rice/Hash.hpp> +#include <rice/rice.hpp> +#include <rice/stl.hpp> using fasttext::FastText; using Rice::Array; using Rice::Constructor; @@ -27,22 +25,28 @@ using Rice::Object; using Rice::define_class_under; using Rice::define_module; using Rice::define_module_under; -template<> -inline -Object to_ruby<std::vector<std::pair<fasttext::real, std::string>>>(std::vector<std::pair<fasttext::real, std::string>> const & x) +namespace Rice::detail { - Array ret; - for (const auto& v : x) { - Array a; - a.push(v.first); - a.push(v.second); - ret.push(a); - } - return ret; + template<> + class To_Ruby<std::vector<std::pair<fasttext::real, std::string>>> + { + public: + VALUE convert(std::vector<std::pair<fasttext::real, std::string>> const & x) + { + Array ret; + for (const auto& v : x) { + Array a; + a.push(v.first); + a.push(v.second); + ret.push(a); + } + return ret; + } + }; } fasttext::Args buildArgs(Hash h) { fasttext::Args a; @@ -50,37 +54,37 @@ Hash::iterator it = h.begin(); Hash::iterator end = h.end(); for(; it != end; ++it) { - std::string name = from_ruby<std::string>(it->key.to_s()); - Object value = it->value; + std::string name = it->key.to_s().str(); + VALUE value = (it->value).value(); if (name == "input") { - a.input = from_ruby<std::string>(value); + a.input = Rice::detail::From_Ruby<std::string>().convert(value); } else if (name == "output") { - a.output = from_ruby<std::string>(value); + a.output = Rice::detail::From_Ruby<std::string>().convert(value); } else if (name == "lr") { - a.lr = from_ruby<double>(value); + a.lr = Rice::detail::From_Ruby<double>().convert(value); } else if (name == "lr_update_rate") { - a.lrUpdateRate = from_ruby<int>(value); + a.lrUpdateRate = Rice::detail::From_Ruby<int>().convert(value); } else if (name == "dim") { - a.dim = from_ruby<int>(value); + a.dim = Rice::detail::From_Ruby<int>().convert(value); } else if (name == "ws") { - a.ws = from_ruby<int>(value); + a.ws = Rice::detail::From_Ruby<int>().convert(value); } else if (name == "epoch") { - a.epoch = from_ruby<int>(value); + a.epoch = Rice::detail::From_Ruby<int>().convert(value); } else if (name == "min_count") { - a.minCount = from_ruby<int>(value); + a.minCount = Rice::detail::From_Ruby<int>().convert(value); } else if (name == "min_count_label") { - a.minCountLabel = from_ruby<int>(value); + a.minCountLabel = Rice::detail::From_Ruby<int>().convert(value); } else if (name == "neg") { - a.neg = from_ruby<int>(value); + a.neg = Rice::detail::From_Ruby<int>().convert(value); } else if (name == "word_ngrams") { - a.wordNgrams = from_ruby<int>(value); + a.wordNgrams = Rice::detail::From_Ruby<int>().convert(value); } else if (name == "loss") { - std::string str = from_ruby<std::string>(value); + std::string str = Rice::detail::From_Ruby<std::string>().convert(value); if (str == "softmax") { a.loss = fasttext::loss_name::softmax; } else if (str == "ns") { a.loss = fasttext::loss_name::ns; } else if (str == "hs") { @@ -89,50 +93,50 @@ a.loss = fasttext::loss_name::ova; } else { throw std::invalid_argument("Unknown loss: " + str); } } else if (name == "model") { - std::string str = from_ruby<std::string>(value); + std::string str = Rice::detail::From_Ruby<std::string>().convert(value); if (str == "supervised") { a.model = fasttext::model_name::sup; } else if (str == "skipgram") { a.model = fasttext::model_name::sg; } else if (str == "cbow") { a.model = fasttext::model_name::cbow; } else { throw std::invalid_argument("Unknown model: " + str); } } else if (name == "bucket") { - a.bucket = from_ruby<int>(value); + a.bucket = Rice::detail::From_Ruby<int>().convert(value); } else if (name == "minn") { - a.minn = from_ruby<int>(value); + a.minn = Rice::detail::From_Ruby<int>().convert(value); } else if (name == "maxn") { - a.maxn = from_ruby<int>(value); + a.maxn = Rice::detail::From_Ruby<int>().convert(value); } else if (name == "thread") { - a.thread = from_ruby<int>(value); + a.thread = Rice::detail::From_Ruby<int>().convert(value); } else if (name == "t") { - a.t = from_ruby<double>(value); + a.t = Rice::detail::From_Ruby<double>().convert(value); } else if (name == "label_prefix") { - a.label = from_ruby<std::string>(value); + a.label = Rice::detail::From_Ruby<std::string>().convert(value); } else if (name == "verbose") { - a.verbose = from_ruby<int>(value); + a.verbose = Rice::detail::From_Ruby<int>().convert(value); } else if (name == "pretrained_vectors") { - a.pretrainedVectors = from_ruby<std::string>(value); + a.pretrainedVectors = Rice::detail::From_Ruby<std::string>().convert(value); } else if (name == "save_output") { - a.saveOutput = from_ruby<bool>(value); + a.saveOutput = Rice::detail::From_Ruby<bool>().convert(value); } else if (name == "seed") { - a.seed = from_ruby<int>(value); + a.seed = Rice::detail::From_Ruby<int>().convert(value); } else if (name == "autotune_validation_file") { - a.autotuneValidationFile = from_ruby<std::string>(value); + a.autotuneValidationFile = Rice::detail::From_Ruby<std::string>().convert(value); } else if (name == "autotune_metric") { - a.autotuneMetric = from_ruby<std::string>(value); + a.autotuneMetric = Rice::detail::From_Ruby<std::string>().convert(value); } else if (name == "autotune_predictions") { - a.autotunePredictions = from_ruby<int>(value); + a.autotunePredictions = Rice::detail::From_Ruby<int>().convert(value); } else if (name == "autotune_duration") { - a.autotuneDuration = from_ruby<int>(value); + a.autotuneDuration = Rice::detail::From_Ruby<int>().convert(value); } else if (name == "autotune_model_size") { - a.autotuneModelSize = from_ruby<std::string>(value); + a.autotuneModelSize = Rice::detail::From_Ruby<std::string>().convert(value); } else { throw std::invalid_argument("Unknown argument: " + name); } } return a; @@ -146,11 +150,11 @@ define_class_under<FastText>(rb_mExt, "Model") .define_constructor(Constructor<FastText>()) .define_method( "words", - *[](FastText& m) { + [](FastText& m) { std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary(); std::vector<int64_t> freq = d->getCounts(fasttext::entry_type::word); Array vocab_list; Array vocab_freq; @@ -164,11 +168,11 @@ ret.push(vocab_freq); return ret; }) .define_method( "labels", - *[](FastText& m) { + [](FastText& m) { std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary(); std::vector<int64_t> freq = d->getCounts(fasttext::entry_type::label); Array vocab_list; Array vocab_freq; @@ -182,11 +186,11 @@ ret.push(vocab_freq); return ret; }) .define_method( "test", - *[](FastText& m, const std::string filename, int32_t k) { + [](FastText& m, const std::string& filename, int32_t k) { std::ifstream ifs(filename); if (!ifs.is_open()) { throw std::invalid_argument("Test file cannot be opened!"); } fasttext::Meter meter(false); @@ -199,36 +203,40 @@ ret.push(meter.recall()); return ret; }) .define_method( "load_model", - *[](FastText& m, std::string s) { m.loadModel(s); }) + [](FastText& m, const std::string& s) { + m.loadModel(s); + }) .define_method( "save_model", - *[](FastText& m, std::string s) { m.saveModel(s); }) + [](FastText& m, const std::string& s) { + m.saveModel(s); + }) .define_method("dimension", &FastText::getDimension) .define_method("quantized?", &FastText::isQuant) .define_method("word_id", &FastText::getWordId) .define_method("subword_id", &FastText::getSubwordId) .define_method( "predict", - *[](FastText& m, const std::string text, int32_t k, float threshold) { + [](FastText& m, const std::string& text, int32_t k, float threshold) { std::stringstream ioss(text); std::vector<std::pair<fasttext::real, std::string>> predictions; m.predictLine(ioss, predictions, k, threshold); return predictions; }) .define_method( "nearest_neighbors", - *[](FastText& m, const std::string& word, int32_t k) { + [](FastText& m, const std::string& word, int32_t k) { return m.getNN(word, k); }) .define_method("analogies", &FastText::getAnalogies) - .define_method("ngram_vectors", &FastText::getNgramVectors) + // .define_method("ngram_vectors", &FastText::getNgramVectors) .define_method( "word_vector", - *[](FastText& m, const std::string word) { + [](FastText& m, const std::string& word) { int dimension = m.getDimension(); fasttext::Vector vec = fasttext::Vector(dimension); m.getWordVector(vec, word); float* data = vec.data(); Array ret; @@ -237,11 +245,11 @@ } return ret; }) .define_method( "subwords", - *[](FastText& m, const std::string word) { + [](FastText& m, const std::string& word) { std::vector<std::string> subwords; std::vector<int32_t> ngrams; std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary(); d->getSubwords(word, ngrams, subwords); @@ -251,11 +259,11 @@ } return ret; }) .define_method( "sentence_vector", - *[](FastText& m, const std::string text) { + [](FastText& m, const std::string& text) { std::istringstream in(text); int dimension = m.getDimension(); fasttext::Vector vec = fasttext::Vector(dimension); m.getSentenceVector(in, vec); float* data = vec.data(); @@ -265,30 +273,30 @@ } return ret; }) .define_method( "train", - *[](FastText& m, Hash h) { + [](FastText& m, Hash h) { auto a = buildArgs(h); if (a.hasAutotune()) { fasttext::Autotune autotune(std::shared_ptr<fasttext::FastText>(&m, [](fasttext::FastText*) {})); autotune.train(a); } else { m.train(a); } }) .define_method( "quantize", - *[](FastText& m, Hash h) { + [](FastText& m, Hash h) { m.quantize(buildArgs(h)); }) .define_method( "supervised?", - *[](FastText& m) { + [](FastText& m) { return m.getArgs().model == fasttext::model_name::sup; }) .define_method( "label_prefix", - *[](FastText& m) { + [](FastText& m) { return m.getArgs().label; }); }