vendor/fastText/src/meter.cc in fasttext-0.1.2 vs vendor/fastText/src/meter.cc in fasttext-0.1.3
- old
+ new
@@ -14,28 +14,40 @@
#include <iomanip>
#include <limits>
namespace fasttext {
+constexpr int32_t kAllLabels = -1;
+constexpr real falseNegativeScore = -1.0;
+
void Meter::log(
const std::vector<int32_t>& labels,
const Predictions& predictions) {
nexamples_++;
metrics_.gold += labels.size();
metrics_.predicted += predictions.size();
for (const auto& prediction : predictions) {
labelMetrics_[prediction.second].predicted++;
+ real score = std::min(std::exp(prediction.first), 1.0f);
+ real gold = 0.0;
if (utils::contains(labels, prediction.second)) {
labelMetrics_[prediction.second].predictedGold++;
metrics_.predictedGold++;
+ gold = 1.0;
}
+ labelMetrics_[prediction.second].scoreVsTrue.emplace_back(score, gold);
}
- for (const auto& label : labels) {
- labelMetrics_[label].gold++;
+ if (falseNegativeLabels_) {
+ for (const auto& label : labels) {
+ labelMetrics_[label].gold++;
+ if (!utils::containsSecond(predictions, label)) {
+ labelMetrics_[label].scoreVsTrue.emplace_back(falseNegativeScore, 1.0);
+ }
+ }
}
}
double Meter::precision(int32_t i) {
return labelMetrics_[i].precision();
@@ -55,14 +67,148 @@
double Meter::recall() const {
return metrics_.recall();
}
+double Meter::f1Score() const {
+ const double precision = this->precision();
+ const double recall = this->recall();
+ if (precision + recall != 0) {
+ return 2 * precision * recall / (precision + recall);
+ }
+ return std::numeric_limits<double>::quiet_NaN();
+}
+
void Meter::writeGeneralMetrics(std::ostream& out, int32_t k) const {
out << "N"
<< "\t" << nexamples_ << std::endl;
out << std::setprecision(3);
out << "P@" << k << "\t" << metrics_.precision() << std::endl;
out << "R@" << k << "\t" << metrics_.recall() << std::endl;
+}
+
+std::vector<std::pair<uint64_t, uint64_t>> Meter::getPositiveCounts(
+ int32_t labelId) const {
+ std::vector<std::pair<uint64_t, uint64_t>> positiveCounts;
+
+ const auto& v = scoreVsTrue(labelId);
+ uint64_t truePositives = 0;
+ uint64_t falsePositives = 0;
+ double lastScore = falseNegativeScore - 1.0;
+
+ for (auto it = v.rbegin(); it != v.rend(); ++it) {
+ double score = it->first;
+ double gold = it->second;
+ if (score < 0) { // only reachable recall
+ break;
+ }
+ if (gold == 1.0) {
+ truePositives++;
+ } else {
+ falsePositives++;
+ }
+ if (score == lastScore && positiveCounts.size()) { // squeeze tied scores
+ positiveCounts.back() = {truePositives, falsePositives};
+ } else {
+ positiveCounts.emplace_back(truePositives, falsePositives);
+ }
+ lastScore = score;
+ }
+
+ return positiveCounts;
+}
+
+double Meter::precisionAtRecall(double recallQuery) const {
+ return precisionAtRecall(kAllLabels, recallQuery);
+}
+
+double Meter::precisionAtRecall(int32_t labelId, double recallQuery) const {
+ const auto& precisionRecall = precisionRecallCurve(labelId);
+ double bestPrecision = 0.0;
+ std::for_each(
+ precisionRecall.begin(),
+ precisionRecall.end(),
+ [&bestPrecision, recallQuery](const std::pair<double, double>& element) {
+ if (element.second >= recallQuery) {
+ bestPrecision = std::max(bestPrecision, element.first);
+ };
+ });
+ return bestPrecision;
+}
+
+double Meter::recallAtPrecision(double precisionQuery) const {
+ return recallAtPrecision(kAllLabels, precisionQuery);
+}
+
+double Meter::recallAtPrecision(int32_t labelId, double precisionQuery) const {
+ const auto& precisionRecall = precisionRecallCurve(labelId);
+ double bestRecall = 0.0;
+ std::for_each(
+ precisionRecall.begin(),
+ precisionRecall.end(),
+ [&bestRecall, precisionQuery](const std::pair<double, double>& element) {
+ if (element.first >= precisionQuery) {
+ bestRecall = std::max(bestRecall, element.second);
+ };
+ });
+ return bestRecall;
+}
+
+std::vector<std::pair<double, double>> Meter::precisionRecallCurve() const {
+ return precisionRecallCurve(kAllLabels);
+}
+
+std::vector<std::pair<double, double>> Meter::precisionRecallCurve(
+ int32_t labelId) const {
+ std::vector<std::pair<double, double>> precisionRecallCurve;
+ const auto& positiveCounts = getPositiveCounts(labelId);
+ if (positiveCounts.empty()) {
+ return precisionRecallCurve;
+ }
+
+ uint64_t golds =
+ (labelId == kAllLabels) ? metrics_.gold : labelMetrics_.at(labelId).gold;
+
+ auto fullRecall = std::lower_bound(
+ positiveCounts.begin(),
+ positiveCounts.end(),
+ golds,
+ utils::compareFirstLess);
+
+ if (fullRecall != positiveCounts.end()) {
+ fullRecall = std::next(fullRecall);
+ }
+
+ for (auto it = positiveCounts.begin(); it != fullRecall; it++) {
+ double precision = 0.0;
+ double truePositives = it->first;
+ double falsePositives = it->second;
+ if (truePositives + falsePositives != 0.0) {
+ precision = truePositives / (truePositives + falsePositives);
+ }
+ double recall = golds != 0 ? (truePositives / double(golds))
+ : std::numeric_limits<double>::quiet_NaN();
+ precisionRecallCurve.emplace_back(precision, recall);
+ }
+ precisionRecallCurve.emplace_back(1.0, 0.0);
+
+ return precisionRecallCurve;
+}
+
+std::vector<std::pair<real, real>> Meter::scoreVsTrue(int32_t labelId) const {
+ std::vector<std::pair<real, real>> ret;
+ if (labelId == kAllLabels) {
+ for (const auto& k : labelMetrics_) {
+ auto& labelScoreVsTrue = labelMetrics_.at(k.first).scoreVsTrue;
+ ret.insert(ret.end(), labelScoreVsTrue.begin(), labelScoreVsTrue.end());
+ }
+ } else {
+ if (labelMetrics_.count(labelId)) {
+ ret = labelMetrics_.at(labelId).scoreVsTrue;
+ }
+ }
+ sort(ret.begin(), ret.end());
+
+ return ret;
}
} // namespace fasttext