ext/fasttext/ext.cpp in fasttext-0.1.2 vs ext/fasttext/ext.cpp in fasttext-0.1.3
- old
+ new
@@ -1,21 +1,36 @@
+// stdlib
+#include <cmath>
+#include <iterator>
+#include <sstream>
+#include <stdexcept>
+
+// fasttext
#include <args.h>
+#include <autotune.h>
#include <densematrix.h>
#include <fasttext.h>
-#include <rice/Data_Type.hpp>
-#include <rice/Constructor.hpp>
-#include <rice/Array.hpp>
-#include <rice/Hash.hpp>
#include <real.h>
#include <vector.h>
-#include <cmath>
-#include <iterator>
-#include <sstream>
-#include <stdexcept>
-using namespace Rice;
+// rice
+#include <rice/Array.hpp>
+#include <rice/Constructor.hpp>
+#include <rice/Data_Type.hpp>
+#include <rice/Hash.hpp>
+using fasttext::FastText;
+
+using Rice::Array;
+using Rice::Constructor;
+using Rice::Hash;
+using Rice::Module;
+using Rice::Object;
+using Rice::define_class_under;
+using Rice::define_module;
+using Rice::define_module_under;
+
template<>
inline
Object to_ruby<std::vector<std::pair<fasttext::real, std::string>>>(std::vector<std::pair<fasttext::real, std::string>> const & x)
{
Array ret;
@@ -102,12 +117,22 @@
a.verbose = from_ruby<int>(value);
} else if (name == "pretrained_vectors") {
a.pretrainedVectors = from_ruby<std::string>(value);
} else if (name == "save_output") {
a.saveOutput = from_ruby<bool>(value);
- // } else if (name == "seed") {
- // a.seed = from_ruby<int>(value);
+ } else if (name == "seed") {
+ a.seed = from_ruby<int>(value);
+ } else if (name == "autotune_validation_file") {
+ a.autotuneValidationFile = from_ruby<std::string>(value);
+ } else if (name == "autotune_metric") {
+ a.autotuneMetric = from_ruby<std::string>(value);
+ } else if (name == "autotune_predictions") {
+ a.autotunePredictions = from_ruby<int>(value);
+ } else if (name == "autotune_duration") {
+ a.autotuneDuration = from_ruby<int>(value);
+ } else if (name == "autotune_model_size") {
+ a.autotuneModelSize = from_ruby<std::string>(value);
} else {
throw std::invalid_argument("Unknown argument: " + name);
}
}
return a;
@@ -117,15 +142,15 @@
void Init_ext()
{
Module rb_mFastText = define_module("FastText");
Module rb_mExt = define_module_under(rb_mFastText, "Ext");
- define_class_under<fasttext::FastText>(rb_mExt, "Model")
- .define_constructor(Constructor<fasttext::FastText>())
+ define_class_under<FastText>(rb_mExt, "Model")
+ .define_constructor(Constructor<FastText>())
.define_method(
"words",
- *[](fasttext::FastText& m) {
+ *[](FastText& m) {
std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
std::vector<int64_t> freq = d->getCounts(fasttext::entry_type::word);
Array vocab_list;
Array vocab_freq;
@@ -139,11 +164,11 @@
ret.push(vocab_freq);
return ret;
})
.define_method(
"labels",
- *[](fasttext::FastText& m) {
+ *[](FastText& m) {
std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
std::vector<int64_t> freq = d->getCounts(fasttext::entry_type::label);
Array vocab_list;
Array vocab_freq;
@@ -157,16 +182,16 @@
ret.push(vocab_freq);
return ret;
})
.define_method(
"test",
- *[](fasttext::FastText& m, const std::string filename, int32_t k) {
+ *[](FastText& m, const std::string filename, int32_t k) {
std::ifstream ifs(filename);
if (!ifs.is_open()) {
throw std::invalid_argument("Test file cannot be opened!");
}
- fasttext::Meter meter;
+ fasttext::Meter meter(false);
m.test(ifs, k, 0.0, meter);
ifs.close();
Array ret;
ret.push(meter.nexamples());
@@ -174,36 +199,36 @@
ret.push(meter.recall());
return ret;
})
.define_method(
"load_model",
- *[](fasttext::FastText& m, std::string s) { m.loadModel(s); })
+ *[](FastText& m, std::string s) { m.loadModel(s); })
.define_method(
"save_model",
- *[](fasttext::FastText& m, std::string s) { m.saveModel(s); })
- .define_method("dimension", &fasttext::FastText::getDimension)
- .define_method("quantized?", &fasttext::FastText::isQuant)
- .define_method("word_id", &fasttext::FastText::getWordId)
- .define_method("subword_id", &fasttext::FastText::getSubwordId)
+ *[](FastText& m, std::string s) { m.saveModel(s); })
+ .define_method("dimension", &FastText::getDimension)
+ .define_method("quantized?", &FastText::isQuant)
+ .define_method("word_id", &FastText::getWordId)
+ .define_method("subword_id", &FastText::getSubwordId)
.define_method(
"predict",
- *[](fasttext::FastText& m, const std::string text, int32_t k, float threshold) {
+ *[](FastText& m, const std::string text, int32_t k, float threshold) {
std::stringstream ioss(text);
std::vector<std::pair<fasttext::real, std::string>> predictions;
m.predictLine(ioss, predictions, k, threshold);
return predictions;
})
.define_method(
"nearest_neighbors",
- *[](fasttext::FastText& m, const std::string& word, int32_t k) {
+ *[](FastText& m, const std::string& word, int32_t k) {
return m.getNN(word, k);
})
- .define_method("analogies", &fasttext::FastText::getAnalogies)
- .define_method("ngram_vectors", &fasttext::FastText::getNgramVectors)
+ .define_method("analogies", &FastText::getAnalogies)
+ .define_method("ngram_vectors", &FastText::getNgramVectors)
.define_method(
"word_vector",
- *[](fasttext::FastText& m, const std::string word) {
+ *[](FastText& m, const std::string word) {
int dimension = m.getDimension();
fasttext::Vector vec = fasttext::Vector(dimension);
m.getWordVector(vec, word);
float* data = vec.data();
Array ret;
@@ -212,11 +237,11 @@
}
return ret;
})
.define_method(
"subwords",
- *[](fasttext::FastText& m, const std::string word) {
+ *[](FastText& m, const std::string word) {
std::vector<std::string> subwords;
std::vector<int32_t> ngrams;
std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
d->getSubwords(word, ngrams, subwords);
@@ -226,11 +251,11 @@
}
return ret;
})
.define_method(
"sentence_vector",
- *[](fasttext::FastText& m, const std::string text) {
+ *[](FastText& m, const std::string text) {
std::istringstream in(text);
int dimension = m.getDimension();
fasttext::Vector vec = fasttext::Vector(dimension);
m.getSentenceVector(in, vec);
float* data = vec.data();
@@ -240,24 +265,30 @@
}
return ret;
})
.define_method(
"train",
- *[](fasttext::FastText& m, Hash h) {
- m.train(buildArgs(h));
+ *[](FastText& m, Hash h) {
+ auto a = buildArgs(h);
+ if (a.hasAutotune()) {
+ fasttext::Autotune autotune(std::shared_ptr<fasttext::FastText>(&m, [](fasttext::FastText*) {}));
+ autotune.train(a);
+ } else {
+ m.train(a);
+ }
})
.define_method(
"quantize",
- *[](fasttext::FastText& m, Hash h) {
+ *[](FastText& m, Hash h) {
m.quantize(buildArgs(h));
})
.define_method(
"supervised?",
- *[](fasttext::FastText& m) {
+ *[](FastText& m) {
return m.getArgs().model == fasttext::model_name::sup;
})
.define_method(
"label_prefix",
- *[](fasttext::FastText& m) {
+ *[](FastText& m) {
return m.getArgs().label;
});
}