// Copyright (C) 2013 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_CROSS_VALIDATE_SEQUENCE_sEGMENTER_Hh_ #define DLIB_CROSS_VALIDATE_SEQUENCE_sEGMENTER_Hh_ #include "cross_validate_sequence_segmenter_abstract.h" #include "sequence_segmenter.h" namespace dlib { // ---------------------------------------------------------------------------------------- namespace impl { template < typename sequence_segmenter_type, typename sequence_type > const matrix<double,1,3> raw_metrics_test_sequence_segmenter ( const sequence_segmenter_type& segmenter, const std::vector<sequence_type>& samples, const std::vector<std::vector<std::pair<unsigned long,unsigned long> > >& segments ) { std::vector<std::pair<unsigned long,unsigned long> > truth; std::vector<std::pair<unsigned long,unsigned long> > pred; double true_hits = 0; double total_detections = 0; double total_true_segments = 0; for (unsigned long i = 0; i < samples.size(); ++i) { segmenter.segment_sequence(samples[i], pred); truth = segments[i]; // sort the segments so they will be in the same orders std::sort(truth.begin(), truth.end()); std::sort(pred.begin(), pred.end()); total_true_segments += truth.size(); total_detections += pred.size(); unsigned long j=0,k=0; while (j < pred.size() && k < truth.size()) { if (pred[j].first == truth[k].first && pred[j].second == truth[k].second) { ++true_hits; ++j; ++k; } else if (pred[j].first < truth[k].first) { ++j; } else { ++k; } } } matrix<double,1,3> res; res = total_detections, total_true_segments, true_hits; return res; } } // ---------------------------------------------------------------------------------------- template < typename sequence_segmenter_type, typename sequence_type > const matrix<double,1,3> test_sequence_segmenter ( const sequence_segmenter_type& segmenter, const std::vector<sequence_type>& samples, const std::vector<std::vector<std::pair<unsigned long,unsigned long> > >& segments ) { // make sure requires clause is not broken DLIB_ASSERT( is_sequence_segmentation_problem(samples, segments) == true, "\tmatrix test_sequence_segmenter()" << "\n\t invalid inputs were given to this function" << "\n\t is_sequence_segmentation_problem(samples, segments): " << is_sequence_segmentation_problem(samples, segments)); const matrix<double,1,3> metrics = impl::raw_metrics_test_sequence_segmenter(segmenter, samples, segments); const double total_detections = metrics(0); const double total_true_segments = metrics(1); const double true_hits = metrics(2); const double precision = (total_detections ==0) ? 1 : true_hits/total_detections; const double recall = (total_true_segments==0) ? 1 : true_hits/total_true_segments; const double f1 = (precision+recall ==0) ? 0 : 2*precision*recall/(precision+recall); matrix<double,1,3> res; res = precision, recall, f1; return res; } // ---------------------------------------------------------------------------------------- template < typename trainer_type, typename sequence_type > const matrix<double,1,3> cross_validate_sequence_segmenter ( const trainer_type& trainer, const std::vector<sequence_type>& samples, const std::vector<std::vector<std::pair<unsigned long,unsigned long> > >& segments, const long folds ) { // make sure requires clause is not broken DLIB_ASSERT( is_sequence_segmentation_problem(samples, segments) == true && 1 < folds && folds <= static_cast<long>(samples.size()), "\tmatrix cross_validate_sequence_segmenter()" << "\n\t invalid inputs were given to this function" << "\n\t folds: " << folds << "\n\t is_sequence_segmentation_problem(samples, segments): " << is_sequence_segmentation_problem(samples, segments)); const long num_in_test = samples.size()/folds; const long num_in_train = samples.size() - num_in_test; std::vector<sequence_type> x_test, x_train; std::vector<std::vector<std::pair<unsigned long,unsigned long> > > y_test, y_train; long next_test_idx = 0; matrix<double,1,3> metrics; metrics = 0; for (long i = 0; i < folds; ++i) { x_test.clear(); y_test.clear(); x_train.clear(); y_train.clear(); // load up the test samples for (long cnt = 0; cnt < num_in_test; ++cnt) { x_test.push_back(samples[next_test_idx]); y_test.push_back(segments[next_test_idx]); next_test_idx = (next_test_idx + 1)%samples.size(); } // load up the training samples long next = next_test_idx; for (long cnt = 0; cnt < num_in_train; ++cnt) { x_train.push_back(samples[next]); y_train.push_back(segments[next]); next = (next + 1)%samples.size(); } metrics += impl::raw_metrics_test_sequence_segmenter(trainer.train(x_train,y_train), x_test, y_test); } // for (long i = 0; i < folds; ++i) const double total_detections = metrics(0); const double total_true_segments = metrics(1); const double true_hits = metrics(2); const double precision = (total_detections ==0) ? 1 : true_hits/total_detections; const double recall = (total_true_segments==0) ? 1 : true_hits/total_true_segments; const double f1 = (precision+recall ==0) ? 0 : 2*precision*recall/(precision+recall); matrix<double,1,3> res; res = precision, recall, f1; return res; } // ---------------------------------------------------------------------------------------- } #endif // DLIB_CROSS_VALIDATE_SEQUENCE_sEGMENTER_Hh_