ext/fasttext/ext.cpp in fasttext-0.2.1 vs ext/fasttext/ext.cpp in fasttext-0.2.2

- old
+ new

@@ -14,17 +14,16 @@ // rice #include <rice/rice.hpp> #include <rice/stl.hpp> +using fasttext::Args; using fasttext::FastText; using Rice::Array; using Rice::Constructor; -using Rice::Hash; using Rice::Module; -using Rice::Object; using Rice::define_class_under; using Rice::define_module; using Rice::define_module_under; namespace Rice::detail @@ -45,107 +44,73 @@ return ret; } }; } -fasttext::Args buildArgs(Hash h) { - fasttext::Args a; - - for (const auto& it : h) - { - auto name = it.key.to_s().str(); - auto value = (it.value).value(); - - if (name == "input") { - a.input = Rice::detail::From_Ruby<std::string>().convert(value); - } else if (name == "output") { - a.output = Rice::detail::From_Ruby<std::string>().convert(value); - } else if (name == "lr") { - a.lr = Rice::detail::From_Ruby<double>().convert(value); - } else if (name == "lr_update_rate") { - a.lrUpdateRate = Rice::detail::From_Ruby<int>().convert(value); - } else if (name == "dim") { - a.dim = Rice::detail::From_Ruby<int>().convert(value); - } else if (name == "ws") { - a.ws = Rice::detail::From_Ruby<int>().convert(value); - } else if (name == "epoch") { - a.epoch = Rice::detail::From_Ruby<int>().convert(value); - } else if (name == "min_count") { - a.minCount = Rice::detail::From_Ruby<int>().convert(value); - } else if (name == "min_count_label") { - a.minCountLabel = Rice::detail::From_Ruby<int>().convert(value); - } else if (name == "neg") { - a.neg = Rice::detail::From_Ruby<int>().convert(value); - } else if (name == "word_ngrams") { - a.wordNgrams = Rice::detail::From_Ruby<int>().convert(value); - } else if (name == "loss") { - std::string str = Rice::detail::From_Ruby<std::string>().convert(value); - if (str == "softmax") { - a.loss = fasttext::loss_name::softmax; - } else if (str == "ns") { - a.loss = fasttext::loss_name::ns; - } else if (str == "hs") { - a.loss = fasttext::loss_name::hs; - } else if (str == "ova") { - a.loss = fasttext::loss_name::ova; - } else { - throw std::invalid_argument("Unknown loss: " + str); - } - } else if (name == "model") { - std::string str = Rice::detail::From_Ruby<std::string>().convert(value); - if (str == "supervised") { - a.model = fasttext::model_name::sup; - } else if (str == "skipgram") { - a.model = fasttext::model_name::sg; - } else if (str == "cbow") { - a.model = fasttext::model_name::cbow; - } else { - throw std::invalid_argument("Unknown model: " + str); - } - } else if (name == "bucket") { - a.bucket = Rice::detail::From_Ruby<int>().convert(value); - } else if (name == "minn") { - a.minn = Rice::detail::From_Ruby<int>().convert(value); - } else if (name == "maxn") { - a.maxn = Rice::detail::From_Ruby<int>().convert(value); - } else if (name == "thread") { - a.thread = Rice::detail::From_Ruby<int>().convert(value); - } else if (name == "t") { - a.t = Rice::detail::From_Ruby<double>().convert(value); - } else if (name == "label_prefix") { - a.label = Rice::detail::From_Ruby<std::string>().convert(value); - } else if (name == "verbose") { - a.verbose = Rice::detail::From_Ruby<int>().convert(value); - } else if (name == "pretrained_vectors") { - a.pretrainedVectors = Rice::detail::From_Ruby<std::string>().convert(value); - } else if (name == "save_output") { - a.saveOutput = Rice::detail::From_Ruby<bool>().convert(value); - } else if (name == "seed") { - a.seed = Rice::detail::From_Ruby<int>().convert(value); - } else if (name == "autotune_validation_file") { - a.autotuneValidationFile = Rice::detail::From_Ruby<std::string>().convert(value); - } else if (name == "autotune_metric") { - a.autotuneMetric = Rice::detail::From_Ruby<std::string>().convert(value); - } else if (name == "autotune_predictions") { - a.autotunePredictions = Rice::detail::From_Ruby<int>().convert(value); - } else if (name == "autotune_duration") { - a.autotuneDuration = Rice::detail::From_Ruby<int>().convert(value); - } else if (name == "autotune_model_size") { - a.autotuneModelSize = Rice::detail::From_Ruby<std::string>().convert(value); - } else { - throw std::invalid_argument("Unknown argument: " + name); - } - } - return a; -} - extern "C" void Init_ext() { Module rb_mFastText = define_module("FastText"); Module rb_mExt = define_module_under(rb_mFastText, "Ext"); + define_class_under<Args>(rb_mExt, "Args") + .define_constructor(Constructor<Args>()) + .define_attr("input", &Args::input) + .define_attr("output", &Args::output) + .define_attr("lr", &Args::lr) + .define_attr("lr_update_rate", &Args::lrUpdateRate) + .define_attr("dim", &Args::dim) + .define_attr("ws", &Args::ws) + .define_attr("epoch", &Args::epoch) + .define_attr("min_count", &Args::minCount) + .define_attr("min_count_label", &Args::minCountLabel) + .define_attr("neg", &Args::neg) + .define_attr("word_ngrams", &Args::wordNgrams) + .define_method( + "loss=", + [](Args& a, const std::string& str) { + if (str == "softmax") { + a.loss = fasttext::loss_name::softmax; + } else if (str == "ns") { + a.loss = fasttext::loss_name::ns; + } else if (str == "hs") { + a.loss = fasttext::loss_name::hs; + } else if (str == "ova") { + a.loss = fasttext::loss_name::ova; + } else { + throw std::invalid_argument("Unknown loss: " + str); + } + }) + .define_method( + "model=", + [](Args& a, const std::string& str) { + if (str == "supervised") { + a.model = fasttext::model_name::sup; + } else if (str == "skipgram") { + a.model = fasttext::model_name::sg; + } else if (str == "cbow") { + a.model = fasttext::model_name::cbow; + } else { + throw std::invalid_argument("Unknown model: " + str); + } + }) + .define_attr("bucket", &Args::bucket) + .define_attr("minn", &Args::minn) + .define_attr("maxn", &Args::maxn) + .define_attr("thread", &Args::thread) + .define_attr("t", &Args::t) + .define_attr("label_prefix", &Args::label) + .define_attr("verbose", &Args::verbose) + .define_attr("pretrained_vectors", &Args::pretrainedVectors) + .define_attr("save_output", &Args::saveOutput) + .define_attr("seed", &Args::seed) + .define_attr("autotune_validation_file", &Args::autotuneValidationFile) + .define_attr("autotune_metric", &Args::autotuneMetric) + .define_attr("autotune_predictions", &Args::autotunePredictions) + .define_attr("autotune_duration", &Args::autotuneDuration) + .define_attr("autotune_model_size", &Args::autotuneModelSize); + define_class_under<FastText>(rb_mExt, "Model") .define_constructor(Constructor<FastText>()) .define_method( "words", [](FastText& m) { @@ -229,17 +194,16 @@ .define_method("analogies", &FastText::getAnalogies) // .define_method("ngram_vectors", &FastText::getNgramVectors) .define_method( "word_vector", [](FastText& m, const std::string& word) { - int dimension = m.getDimension(); + auto dimension = m.getDimension(); fasttext::Vector vec = fasttext::Vector(dimension); m.getWordVector(vec, word); - float* data = vec.data(); Array ret; - for (int i = 0; i < dimension; i++) { - ret.push(data[i]); + for (size_t i = 0; i < vec.size(); i++) { + ret.push(vec[i]); } return ret; }) .define_method( "subwords", @@ -257,34 +221,32 @@ }) .define_method( "sentence_vector", [](FastText& m, const std::string& text) { std::istringstream in(text); - int dimension = m.getDimension(); + auto dimension = m.getDimension(); fasttext::Vector vec = fasttext::Vector(dimension); m.getSentenceVector(in, vec); - float* data = vec.data(); Array ret; - for (int i = 0; i < dimension; i++) { - ret.push(data[i]); + for (size_t i = 0; i < vec.size(); i++) { + ret.push(vec[i]); } return ret; }) .define_method( "train", - [](FastText& m, Hash h) { - auto a = buildArgs(h); + [](FastText& m, Args& a) { if (a.hasAutotune()) { fasttext::Autotune autotune(std::shared_ptr<fasttext::FastText>(&m, [](fasttext::FastText*) {})); autotune.train(a); } else { m.train(a); } }) .define_method( "quantize", - [](FastText& m, Hash h) { - m.quantize(buildArgs(h)); + [](FastText& m, Args& a) { + m.quantize(a); }) .define_method( "supervised?", [](FastText& m) { return m.getArgs().model == fasttext::model_name::sup;