vendor/fastText/src/meter.h in fasttext-0.1.2 vs vendor/fastText/src/meter.h in fasttext-0.1.3
- old
+ new
@@ -20,12 +20,13 @@
class Meter {
struct Metrics {
uint64_t gold;
uint64_t predicted;
uint64_t predictedGold;
+ mutable std::vector<std::pair<real, real>> scoreVsTrue;
- Metrics() : gold(0), predicted(0), predictedGold(0) {}
+ Metrics() : gold(0), predicted(0), predictedGold(0), scoreVsTrue() {}
double precision() const {
if (predicted == 0) {
return std::numeric_limits<double>::quiet_NaN();
}
@@ -41,29 +42,50 @@
if (predicted + gold == 0) {
return std::numeric_limits<double>::quiet_NaN();
}
return 2 * predictedGold / double(predicted + gold);
}
+
+ std::vector<std::pair<real, real>> getScoreVsTrue() {
+ return scoreVsTrue;
+ }
};
+ std::vector<std::pair<uint64_t, uint64_t>> getPositiveCounts(
+ int32_t labelId) const;
public:
- Meter() : metrics_(), nexamples_(0), labelMetrics_() {}
+ Meter() = delete;
+ explicit Meter(bool falseNegativeLabels)
+ : metrics_(),
+ nexamples_(0),
+ labelMetrics_(),
+ falseNegativeLabels_(falseNegativeLabels) {}
void log(const std::vector<int32_t>& labels, const Predictions& predictions);
double precision(int32_t);
double recall(int32_t);
double f1Score(int32_t);
+ std::vector<std::pair<real, real>> scoreVsTrue(int32_t labelId) const;
+ double precisionAtRecall(int32_t labelId, double recall) const;
+ double precisionAtRecall(double recall) const;
+ double recallAtPrecision(int32_t labelId, double recall) const;
+ double recallAtPrecision(double recall) const;
+ std::vector<std::pair<double, double>> precisionRecallCurve(
+ int32_t labelId) const;
+ std::vector<std::pair<double, double>> precisionRecallCurve() const;
double precision() const;
double recall() const;
+ double f1Score() const;
uint64_t nexamples() const {
return nexamples_;
}
void writeGeneralMetrics(std::ostream& out, int32_t k) const;
private:
Metrics metrics_{};
uint64_t nexamples_;
std::unordered_map<int32_t, Metrics> labelMetrics_;
+ bool falseNegativeLabels_;
};
} // namespace fasttext