ext/fasttext/ext.cpp in fasttext-0.2.1 vs ext/fasttext/ext.cpp in fasttext-0.2.2
- old
+ new
@@ -14,17 +14,16 @@
// rice
#include <rice/rice.hpp>
#include <rice/stl.hpp>
+using fasttext::Args;
using fasttext::FastText;
using Rice::Array;
using Rice::Constructor;
-using Rice::Hash;
using Rice::Module;
-using Rice::Object;
using Rice::define_class_under;
using Rice::define_module;
using Rice::define_module_under;
namespace Rice::detail
@@ -45,107 +44,73 @@
return ret;
}
};
}
-fasttext::Args buildArgs(Hash h) {
- fasttext::Args a;
-
- for (const auto& it : h)
- {
- auto name = it.key.to_s().str();
- auto value = (it.value).value();
-
- if (name == "input") {
- a.input = Rice::detail::From_Ruby<std::string>().convert(value);
- } else if (name == "output") {
- a.output = Rice::detail::From_Ruby<std::string>().convert(value);
- } else if (name == "lr") {
- a.lr = Rice::detail::From_Ruby<double>().convert(value);
- } else if (name == "lr_update_rate") {
- a.lrUpdateRate = Rice::detail::From_Ruby<int>().convert(value);
- } else if (name == "dim") {
- a.dim = Rice::detail::From_Ruby<int>().convert(value);
- } else if (name == "ws") {
- a.ws = Rice::detail::From_Ruby<int>().convert(value);
- } else if (name == "epoch") {
- a.epoch = Rice::detail::From_Ruby<int>().convert(value);
- } else if (name == "min_count") {
- a.minCount = Rice::detail::From_Ruby<int>().convert(value);
- } else if (name == "min_count_label") {
- a.minCountLabel = Rice::detail::From_Ruby<int>().convert(value);
- } else if (name == "neg") {
- a.neg = Rice::detail::From_Ruby<int>().convert(value);
- } else if (name == "word_ngrams") {
- a.wordNgrams = Rice::detail::From_Ruby<int>().convert(value);
- } else if (name == "loss") {
- std::string str = Rice::detail::From_Ruby<std::string>().convert(value);
- if (str == "softmax") {
- a.loss = fasttext::loss_name::softmax;
- } else if (str == "ns") {
- a.loss = fasttext::loss_name::ns;
- } else if (str == "hs") {
- a.loss = fasttext::loss_name::hs;
- } else if (str == "ova") {
- a.loss = fasttext::loss_name::ova;
- } else {
- throw std::invalid_argument("Unknown loss: " + str);
- }
- } else if (name == "model") {
- std::string str = Rice::detail::From_Ruby<std::string>().convert(value);
- if (str == "supervised") {
- a.model = fasttext::model_name::sup;
- } else if (str == "skipgram") {
- a.model = fasttext::model_name::sg;
- } else if (str == "cbow") {
- a.model = fasttext::model_name::cbow;
- } else {
- throw std::invalid_argument("Unknown model: " + str);
- }
- } else if (name == "bucket") {
- a.bucket = Rice::detail::From_Ruby<int>().convert(value);
- } else if (name == "minn") {
- a.minn = Rice::detail::From_Ruby<int>().convert(value);
- } else if (name == "maxn") {
- a.maxn = Rice::detail::From_Ruby<int>().convert(value);
- } else if (name == "thread") {
- a.thread = Rice::detail::From_Ruby<int>().convert(value);
- } else if (name == "t") {
- a.t = Rice::detail::From_Ruby<double>().convert(value);
- } else if (name == "label_prefix") {
- a.label = Rice::detail::From_Ruby<std::string>().convert(value);
- } else if (name == "verbose") {
- a.verbose = Rice::detail::From_Ruby<int>().convert(value);
- } else if (name == "pretrained_vectors") {
- a.pretrainedVectors = Rice::detail::From_Ruby<std::string>().convert(value);
- } else if (name == "save_output") {
- a.saveOutput = Rice::detail::From_Ruby<bool>().convert(value);
- } else if (name == "seed") {
- a.seed = Rice::detail::From_Ruby<int>().convert(value);
- } else if (name == "autotune_validation_file") {
- a.autotuneValidationFile = Rice::detail::From_Ruby<std::string>().convert(value);
- } else if (name == "autotune_metric") {
- a.autotuneMetric = Rice::detail::From_Ruby<std::string>().convert(value);
- } else if (name == "autotune_predictions") {
- a.autotunePredictions = Rice::detail::From_Ruby<int>().convert(value);
- } else if (name == "autotune_duration") {
- a.autotuneDuration = Rice::detail::From_Ruby<int>().convert(value);
- } else if (name == "autotune_model_size") {
- a.autotuneModelSize = Rice::detail::From_Ruby<std::string>().convert(value);
- } else {
- throw std::invalid_argument("Unknown argument: " + name);
- }
- }
- return a;
-}
-
extern "C"
void Init_ext()
{
Module rb_mFastText = define_module("FastText");
Module rb_mExt = define_module_under(rb_mFastText, "Ext");
+ define_class_under<Args>(rb_mExt, "Args")
+ .define_constructor(Constructor<Args>())
+ .define_attr("input", &Args::input)
+ .define_attr("output", &Args::output)
+ .define_attr("lr", &Args::lr)
+ .define_attr("lr_update_rate", &Args::lrUpdateRate)
+ .define_attr("dim", &Args::dim)
+ .define_attr("ws", &Args::ws)
+ .define_attr("epoch", &Args::epoch)
+ .define_attr("min_count", &Args::minCount)
+ .define_attr("min_count_label", &Args::minCountLabel)
+ .define_attr("neg", &Args::neg)
+ .define_attr("word_ngrams", &Args::wordNgrams)
+ .define_method(
+ "loss=",
+ [](Args& a, const std::string& str) {
+ if (str == "softmax") {
+ a.loss = fasttext::loss_name::softmax;
+ } else if (str == "ns") {
+ a.loss = fasttext::loss_name::ns;
+ } else if (str == "hs") {
+ a.loss = fasttext::loss_name::hs;
+ } else if (str == "ova") {
+ a.loss = fasttext::loss_name::ova;
+ } else {
+ throw std::invalid_argument("Unknown loss: " + str);
+ }
+ })
+ .define_method(
+ "model=",
+ [](Args& a, const std::string& str) {
+ if (str == "supervised") {
+ a.model = fasttext::model_name::sup;
+ } else if (str == "skipgram") {
+ a.model = fasttext::model_name::sg;
+ } else if (str == "cbow") {
+ a.model = fasttext::model_name::cbow;
+ } else {
+ throw std::invalid_argument("Unknown model: " + str);
+ }
+ })
+ .define_attr("bucket", &Args::bucket)
+ .define_attr("minn", &Args::minn)
+ .define_attr("maxn", &Args::maxn)
+ .define_attr("thread", &Args::thread)
+ .define_attr("t", &Args::t)
+ .define_attr("label_prefix", &Args::label)
+ .define_attr("verbose", &Args::verbose)
+ .define_attr("pretrained_vectors", &Args::pretrainedVectors)
+ .define_attr("save_output", &Args::saveOutput)
+ .define_attr("seed", &Args::seed)
+ .define_attr("autotune_validation_file", &Args::autotuneValidationFile)
+ .define_attr("autotune_metric", &Args::autotuneMetric)
+ .define_attr("autotune_predictions", &Args::autotunePredictions)
+ .define_attr("autotune_duration", &Args::autotuneDuration)
+ .define_attr("autotune_model_size", &Args::autotuneModelSize);
+
define_class_under<FastText>(rb_mExt, "Model")
.define_constructor(Constructor<FastText>())
.define_method(
"words",
[](FastText& m) {
@@ -229,17 +194,16 @@
.define_method("analogies", &FastText::getAnalogies)
// .define_method("ngram_vectors", &FastText::getNgramVectors)
.define_method(
"word_vector",
[](FastText& m, const std::string& word) {
- int dimension = m.getDimension();
+ auto dimension = m.getDimension();
fasttext::Vector vec = fasttext::Vector(dimension);
m.getWordVector(vec, word);
- float* data = vec.data();
Array ret;
- for (int i = 0; i < dimension; i++) {
- ret.push(data[i]);
+ for (size_t i = 0; i < vec.size(); i++) {
+ ret.push(vec[i]);
}
return ret;
})
.define_method(
"subwords",
@@ -257,34 +221,32 @@
})
.define_method(
"sentence_vector",
[](FastText& m, const std::string& text) {
std::istringstream in(text);
- int dimension = m.getDimension();
+ auto dimension = m.getDimension();
fasttext::Vector vec = fasttext::Vector(dimension);
m.getSentenceVector(in, vec);
- float* data = vec.data();
Array ret;
- for (int i = 0; i < dimension; i++) {
- ret.push(data[i]);
+ for (size_t i = 0; i < vec.size(); i++) {
+ ret.push(vec[i]);
}
return ret;
})
.define_method(
"train",
- [](FastText& m, Hash h) {
- auto a = buildArgs(h);
+ [](FastText& m, Args& a) {
if (a.hasAutotune()) {
fasttext::Autotune autotune(std::shared_ptr<fasttext::FastText>(&m, [](fasttext::FastText*) {}));
autotune.train(a);
} else {
m.train(a);
}
})
.define_method(
"quantize",
- [](FastText& m, Hash h) {
- m.quantize(buildArgs(h));
+ [](FastText& m, Args& a) {
+ m.quantize(a);
})
.define_method(
"supervised?",
[](FastText& m) {
return m.getArgs().model == fasttext::model_name::sup;