// Copyright (C) 2011 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_CROSS_VALIDATE_SEQUENCE_LABeLER_Hh_ #define DLIB_CROSS_VALIDATE_SEQUENCE_LABeLER_Hh_ #include "cross_validate_sequence_labeler_abstract.h" #include #include "../matrix.h" #include "svm.h" namespace dlib { // ---------------------------------------------------------------------------------------- template < typename sequence_labeler_type, typename sequence_type > const matrix test_sequence_labeler ( const sequence_labeler_type& labeler, const std::vector& samples, const std::vector >& labels ) { // make sure requires clause is not broken DLIB_ASSERT( is_sequence_labeling_problem(samples, labels) == true, "\tmatrix test_sequence_labeler()" << "\n\t invalid inputs were given to this function" << "\n\t is_sequence_labeling_problem(samples, labels): " << is_sequence_labeling_problem(samples, labels)); matrix res(labeler.num_labels(), labeler.num_labels()); res = 0; std::vector pred; for (unsigned long i = 0; i < samples.size(); ++i) { labeler.label_sequence(samples[i], pred); for (unsigned long j = 0; j < pred.size(); ++j) { const unsigned long truth = labels[i][j]; if (truth >= static_cast(res.nr())) { // ignore labels the labeler doesn't know about. continue; } res(truth, pred[j]) += 1; } } return res; } // ---------------------------------------------------------------------------------------- template < typename trainer_type, typename sequence_type > const matrix cross_validate_sequence_labeler ( const trainer_type& trainer, const std::vector& samples, const std::vector >& labels, const long folds ) { // make sure requires clause is not broken DLIB_ASSERT(is_sequence_labeling_problem(samples,labels) == true && 1 < folds && folds <= static_cast(samples.size()), "\tmatrix cross_validate_sequence_labeler()" << "\n\t invalid inputs were given to this function" << "\n\t samples.size(): " << samples.size() << "\n\t folds: " << folds << "\n\t is_sequence_labeling_problem(samples,labels): " << is_sequence_labeling_problem(samples,labels) ); #ifdef ENABLE_ASSERTS for (unsigned long i = 0; i < labels.size(); ++i) { for (unsigned long j = 0; j < labels[i].size(); ++j) { // make sure requires clause is not broken DLIB_ASSERT(labels[i][j] < trainer.num_labels(), "\t matrix cross_validate_sequence_labeler()" << "\n\t The labels are invalid." << "\n\t labels[i][j]: " << labels[i][j] << "\n\t trainer.num_labels(): " << trainer.num_labels() << "\n\t i: " << i << "\n\t j: " << j ); } } #endif const long num_in_test = samples.size()/folds; const long num_in_train = samples.size() - num_in_test; std::vector x_test, x_train; std::vector > y_test, y_train; long next_test_idx = 0; matrix res; for (long i = 0; i < folds; ++i) { x_test.clear(); y_test.clear(); x_train.clear(); y_train.clear(); // load up the test samples for (long cnt = 0; cnt < num_in_test; ++cnt) { x_test.push_back(samples[next_test_idx]); y_test.push_back(labels[next_test_idx]); next_test_idx = (next_test_idx + 1)%samples.size(); } // load up the training samples long next = next_test_idx; for (long cnt = 0; cnt < num_in_train; ++cnt) { x_train.push_back(samples[next]); y_train.push_back(labels[next]); next = (next + 1)%samples.size(); } res += test_sequence_labeler(trainer.train(x_train,y_train), x_test, y_test); } // for (long i = 0; i < folds; ++i) return res; } // ---------------------------------------------------------------------------------------- } #endif // DLIB_CROSS_VALIDATE_SEQUENCE_LABeLER_Hh_