ext/fasttext/ext.cpp in fasttext-0.1.3 vs ext/fasttext/ext.cpp in fasttext-0.2.0
- old
+ new
@@ -11,14 +11,12 @@
#include <fasttext.h>
#include <real.h>
#include <vector.h>
// rice
-#include <rice/Array.hpp>
-#include <rice/Constructor.hpp>
-#include <rice/Data_Type.hpp>
-#include <rice/Hash.hpp>
+#include <rice/rice.hpp>
+#include <rice/stl.hpp>
using fasttext::FastText;
using Rice::Array;
using Rice::Constructor;
@@ -27,22 +25,28 @@
using Rice::Object;
using Rice::define_class_under;
using Rice::define_module;
using Rice::define_module_under;
-template<>
-inline
-Object to_ruby<std::vector<std::pair<fasttext::real, std::string>>>(std::vector<std::pair<fasttext::real, std::string>> const & x)
+namespace Rice::detail // the To_Ruby specialization below is declared inside Rice::detail
{
- Array ret;
- for (const auto& v : x) {
- Array a;
- a.push(v.first);
- a.push(v.second);
- ret.push(a);
- }
- return ret;
+ template<>
+ class To_Ruby<std::vector<std::pair<fasttext::real, std::string>>> // to-Ruby conversion for a vector of (fasttext::real, std::string) pairs
+ {
+ public:
+ VALUE convert(std::vector<std::pair<fasttext::real, std::string>> const & x) // returns an array holding one two-element array per pair in x
+ {
+ Array ret;
+ for (const auto& v : x) { // visits the pairs of x in order
+ Array a;
+ a.push(v.first); // element 0: the pair's fasttext::real member
+ a.push(v.second); // element 1: the pair's std::string member
+ ret.push(a);
+ }
+ return ret; // the Array is handed back through the VALUE return type
+ }
+ };
}
fasttext::Args buildArgs(Hash h) {
fasttext::Args a;
@@ -50,37 +54,37 @@
Hash::iterator it = h.begin();
Hash::iterator end = h.end();
for(; it != end; ++it)
{
- std::string name = from_ruby<std::string>(it->key.to_s());
- Object value = it->value;
+ std::string name = it->key.to_s().str();
+ VALUE value = (it->value).value();
if (name == "input") {
- a.input = from_ruby<std::string>(value);
+ a.input = Rice::detail::From_Ruby<std::string>().convert(value);
} else if (name == "output") {
- a.output = from_ruby<std::string>(value);
+ a.output = Rice::detail::From_Ruby<std::string>().convert(value);
} else if (name == "lr") {
- a.lr = from_ruby<double>(value);
+ a.lr = Rice::detail::From_Ruby<double>().convert(value);
} else if (name == "lr_update_rate") {
- a.lrUpdateRate = from_ruby<int>(value);
+ a.lrUpdateRate = Rice::detail::From_Ruby<int>().convert(value);
} else if (name == "dim") {
- a.dim = from_ruby<int>(value);
+ a.dim = Rice::detail::From_Ruby<int>().convert(value);
} else if (name == "ws") {
- a.ws = from_ruby<int>(value);
+ a.ws = Rice::detail::From_Ruby<int>().convert(value);
} else if (name == "epoch") {
- a.epoch = from_ruby<int>(value);
+ a.epoch = Rice::detail::From_Ruby<int>().convert(value);
} else if (name == "min_count") {
- a.minCount = from_ruby<int>(value);
+ a.minCount = Rice::detail::From_Ruby<int>().convert(value);
} else if (name == "min_count_label") {
- a.minCountLabel = from_ruby<int>(value);
+ a.minCountLabel = Rice::detail::From_Ruby<int>().convert(value);
} else if (name == "neg") {
- a.neg = from_ruby<int>(value);
+ a.neg = Rice::detail::From_Ruby<int>().convert(value);
} else if (name == "word_ngrams") {
- a.wordNgrams = from_ruby<int>(value);
+ a.wordNgrams = Rice::detail::From_Ruby<int>().convert(value);
} else if (name == "loss") {
- std::string str = from_ruby<std::string>(value);
+ std::string str = Rice::detail::From_Ruby<std::string>().convert(value);
if (str == "softmax") {
a.loss = fasttext::loss_name::softmax;
} else if (str == "ns") {
a.loss = fasttext::loss_name::ns;
} else if (str == "hs") {
@@ -89,50 +93,50 @@
a.loss = fasttext::loss_name::ova;
} else {
throw std::invalid_argument("Unknown loss: " + str);
}
} else if (name == "model") {
- std::string str = from_ruby<std::string>(value);
+ std::string str = Rice::detail::From_Ruby<std::string>().convert(value);
if (str == "supervised") {
a.model = fasttext::model_name::sup;
} else if (str == "skipgram") {
a.model = fasttext::model_name::sg;
} else if (str == "cbow") {
a.model = fasttext::model_name::cbow;
} else {
throw std::invalid_argument("Unknown model: " + str);
}
} else if (name == "bucket") {
- a.bucket = from_ruby<int>(value);
+ a.bucket = Rice::detail::From_Ruby<int>().convert(value);
} else if (name == "minn") {
- a.minn = from_ruby<int>(value);
+ a.minn = Rice::detail::From_Ruby<int>().convert(value);
} else if (name == "maxn") {
- a.maxn = from_ruby<int>(value);
+ a.maxn = Rice::detail::From_Ruby<int>().convert(value);
} else if (name == "thread") {
- a.thread = from_ruby<int>(value);
+ a.thread = Rice::detail::From_Ruby<int>().convert(value);
} else if (name == "t") {
- a.t = from_ruby<double>(value);
+ a.t = Rice::detail::From_Ruby<double>().convert(value);
} else if (name == "label_prefix") {
- a.label = from_ruby<std::string>(value);
+ a.label = Rice::detail::From_Ruby<std::string>().convert(value);
} else if (name == "verbose") {
- a.verbose = from_ruby<int>(value);
+ a.verbose = Rice::detail::From_Ruby<int>().convert(value);
} else if (name == "pretrained_vectors") {
- a.pretrainedVectors = from_ruby<std::string>(value);
+ a.pretrainedVectors = Rice::detail::From_Ruby<std::string>().convert(value);
} else if (name == "save_output") {
- a.saveOutput = from_ruby<bool>(value);
+ a.saveOutput = Rice::detail::From_Ruby<bool>().convert(value);
} else if (name == "seed") {
- a.seed = from_ruby<int>(value);
+ a.seed = Rice::detail::From_Ruby<int>().convert(value);
} else if (name == "autotune_validation_file") {
- a.autotuneValidationFile = from_ruby<std::string>(value);
+ a.autotuneValidationFile = Rice::detail::From_Ruby<std::string>().convert(value);
} else if (name == "autotune_metric") {
- a.autotuneMetric = from_ruby<std::string>(value);
+ a.autotuneMetric = Rice::detail::From_Ruby<std::string>().convert(value);
} else if (name == "autotune_predictions") {
- a.autotunePredictions = from_ruby<int>(value);
+ a.autotunePredictions = Rice::detail::From_Ruby<int>().convert(value);
} else if (name == "autotune_duration") {
- a.autotuneDuration = from_ruby<int>(value);
+ a.autotuneDuration = Rice::detail::From_Ruby<int>().convert(value);
} else if (name == "autotune_model_size") {
- a.autotuneModelSize = from_ruby<std::string>(value);
+ a.autotuneModelSize = Rice::detail::From_Ruby<std::string>().convert(value);
} else {
throw std::invalid_argument("Unknown argument: " + name);
}
}
return a;
@@ -146,11 +150,11 @@
define_class_under<FastText>(rb_mExt, "Model")
.define_constructor(Constructor<FastText>())
.define_method(
"words",
- *[](FastText& m) {
+ [](FastText& m) {
std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
std::vector<int64_t> freq = d->getCounts(fasttext::entry_type::word);
Array vocab_list;
Array vocab_freq;
@@ -164,11 +168,11 @@
ret.push(vocab_freq);
return ret;
})
.define_method(
"labels",
- *[](FastText& m) {
+ [](FastText& m) {
std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
std::vector<int64_t> freq = d->getCounts(fasttext::entry_type::label);
Array vocab_list;
Array vocab_freq;
@@ -182,11 +186,11 @@
ret.push(vocab_freq);
return ret;
})
.define_method(
"test",
- *[](FastText& m, const std::string filename, int32_t k) {
+ [](FastText& m, const std::string& filename, int32_t k) {
std::ifstream ifs(filename);
if (!ifs.is_open()) {
throw std::invalid_argument("Test file cannot be opened!");
}
fasttext::Meter meter(false);
@@ -199,36 +203,40 @@
ret.push(meter.recall());
return ret;
})
.define_method(
"load_model",
- *[](FastText& m, std::string s) { m.loadModel(s); })
+ [](FastText& m, const std::string& s) {
+ m.loadModel(s);
+ })
.define_method(
"save_model",
- *[](FastText& m, std::string s) { m.saveModel(s); })
+ [](FastText& m, const std::string& s) {
+ m.saveModel(s);
+ })
.define_method("dimension", &FastText::getDimension)
.define_method("quantized?", &FastText::isQuant)
.define_method("word_id", &FastText::getWordId)
.define_method("subword_id", &FastText::getSubwordId)
.define_method(
"predict",
- *[](FastText& m, const std::string text, int32_t k, float threshold) {
+ [](FastText& m, const std::string& text, int32_t k, float threshold) {
std::stringstream ioss(text);
std::vector<std::pair<fasttext::real, std::string>> predictions;
m.predictLine(ioss, predictions, k, threshold);
return predictions;
})
.define_method(
"nearest_neighbors",
- *[](FastText& m, const std::string& word, int32_t k) {
+ [](FastText& m, const std::string& word, int32_t k) {
return m.getNN(word, k);
})
.define_method("analogies", &FastText::getAnalogies)
- .define_method("ngram_vectors", &FastText::getNgramVectors)
+ // .define_method("ngram_vectors", &FastText::getNgramVectors)
.define_method(
"word_vector",
- *[](FastText& m, const std::string word) {
+ [](FastText& m, const std::string& word) {
int dimension = m.getDimension();
fasttext::Vector vec = fasttext::Vector(dimension);
m.getWordVector(vec, word);
float* data = vec.data();
Array ret;
@@ -237,11 +245,11 @@
}
return ret;
})
.define_method(
"subwords",
- *[](FastText& m, const std::string word) {
+ [](FastText& m, const std::string& word) {
std::vector<std::string> subwords;
std::vector<int32_t> ngrams;
std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
d->getSubwords(word, ngrams, subwords);
@@ -251,11 +259,11 @@
}
return ret;
})
.define_method(
"sentence_vector",
- *[](FastText& m, const std::string text) {
+ [](FastText& m, const std::string& text) {
std::istringstream in(text);
int dimension = m.getDimension();
fasttext::Vector vec = fasttext::Vector(dimension);
m.getSentenceVector(in, vec);
float* data = vec.data();
@@ -265,30 +273,30 @@
}
return ret;
})
.define_method(
"train",
- *[](FastText& m, Hash h) {
+ [](FastText& m, Hash h) {
auto a = buildArgs(h);
if (a.hasAutotune()) {
fasttext::Autotune autotune(std::shared_ptr<fasttext::FastText>(&m, [](fasttext::FastText*) {}));
autotune.train(a);
} else {
m.train(a);
}
})
.define_method(
"quantize",
- *[](FastText& m, Hash h) {
+ [](FastText& m, Hash h) {
m.quantize(buildArgs(h));
})
.define_method(
"supervised?",
- *[](FastText& m) {
+ [](FastText& m) {
return m.getArgs().model == fasttext::model_name::sup;
})
.define_method(
"label_prefix",
- *[](FastText& m) {
+ [](FastText& m) {
return m.getArgs().label;
});
}