// Copyright (C) 2012 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_CROSS_VALIDATE_GRAPh_LABELING_TRAINER_Hh_ #define DLIB_CROSS_VALIDATE_GRAPh_LABELING_TRAINER_Hh_ #include "../array.h" #include "../graph_cuts/min_cut.h" #include "svm.h" #include "cross_validate_graph_labeling_trainer_abstract.h" namespace dlib { // ---------------------------------------------------------------------------------------- template < typename graph_labeler, typename graph_type > matrix test_graph_labeling_function ( const graph_labeler& labeler, const dlib::array& samples, const std::vector >& labels, const std::vector >& losses ) { #ifdef ENABLE_ASSERTS std::string reason_for_failure; DLIB_ASSERT(is_graph_labeling_problem(samples, labels, reason_for_failure) , "\t matrix test_graph_labeling_function()" << "\n\t invalid inputs were given to this function" << "\n\t samples.size(): " << samples.size() << "\n\t reason_for_failure: " << reason_for_failure ); DLIB_ASSERT((losses.size() == 0 || sizes_match(labels, losses) == true) && all_values_are_nonnegative(losses) == true, "\t matrix test_graph_labeling_function()" << "\n\t Invalid inputs were given to this function." << "\n\t labels.size(): " << labels.size() << "\n\t losses.size(): " << losses.size() << "\n\t sizes_match(labels,losses): " << sizes_match(labels,losses) << "\n\t all_values_are_nonnegative(losses): " << all_values_are_nonnegative(losses) ); #endif std::vector temp; double num_pos_correct = 0; double num_pos = 0; double num_neg_correct = 0; double num_neg = 0; for (unsigned long i = 0; i < samples.size(); ++i) { labeler(samples[i], temp); for (unsigned long j = 0; j < labels[i].size(); ++j) { // What is the loss for this example? It's just 1 unless we have a // per example loss vector. const double loss = (losses.size() == 0) ? 1.0 : losses[i][j]; if (labels[i][j]) { num_pos += loss; if (temp[j]) num_pos_correct += loss; } else { num_neg += loss; if (!temp[j]) num_neg_correct += loss; } } } matrix res; if (num_pos != 0) res(0) = num_pos_correct/num_pos; else res(0) = 1; if (num_neg != 0) res(1) = num_neg_correct/num_neg; else res(1) = 1; return res; } template < typename graph_labeler, typename graph_type > matrix test_graph_labeling_function ( const graph_labeler& labeler, const dlib::array& samples, const std::vector >& labels ) { std::vector > losses; return test_graph_labeling_function(labeler, samples, labels, losses); } // ---------------------------------------------------------------------------------------- template < typename trainer_type, typename graph_type > matrix cross_validate_graph_labeling_trainer ( const trainer_type& trainer, const dlib::array& samples, const std::vector >& labels, const std::vector >& losses, const long folds ) { #ifdef ENABLE_ASSERTS std::string reason_for_failure; DLIB_ASSERT(is_graph_labeling_problem(samples, labels, reason_for_failure), "\t matrix cross_validate_graph_labeling_trainer()" << "\n\t invalid inputs were given to this function" << "\n\t samples.size(): " << samples.size() << "\n\t reason_for_failure: " << reason_for_failure ); DLIB_ASSERT( 1 < folds && folds <= static_cast(samples.size()), "\t matrix cross_validate_graph_labeling_trainer()" << "\n\t invalid inputs were given to this function" << "\n\t folds: " << folds ); DLIB_ASSERT((losses.size() == 0 || sizes_match(labels, losses) == true) && all_values_are_nonnegative(losses) == true, "\t matrix cross_validate_graph_labeling_trainer()" << "\n\t Invalid inputs were given to this function." << "\n\t labels.size(): " << labels.size() << "\n\t losses.size(): " << losses.size() << "\n\t sizes_match(labels,losses): " << sizes_match(labels,losses) << "\n\t all_values_are_nonnegative(losses): " << all_values_are_nonnegative(losses) ); #endif typedef std::vector label_type; const long num_in_test = samples.size()/folds; const long num_in_train = samples.size() - num_in_test; dlib::array samples_test, samples_train; std::vector labels_test, labels_train; std::vector > losses_test, losses_train; long next_test_idx = 0; std::vector temp; double num_pos_correct = 0; double num_pos = 0; double num_neg_correct = 0; double num_neg = 0; graph_type gtemp; for (long i = 0; i < folds; ++i) { samples_test.clear(); labels_test.clear(); losses_test.clear(); samples_train.clear(); labels_train.clear(); losses_train.clear(); // load up the test samples for (long cnt = 0; cnt < num_in_test; ++cnt) { copy_graph(samples[next_test_idx], gtemp); samples_test.push_back(gtemp); labels_test.push_back(labels[next_test_idx]); if (losses.size() != 0) losses_test.push_back(losses[next_test_idx]); next_test_idx = (next_test_idx + 1)%samples.size(); } // load up the training samples long next = next_test_idx; for (long cnt = 0; cnt < num_in_train; ++cnt) { copy_graph(samples[next], gtemp); samples_train.push_back(gtemp); labels_train.push_back(labels[next]); if (losses.size() != 0) losses_train.push_back(losses[next]); next = (next + 1)%samples.size(); } const typename trainer_type::trained_function_type& labeler = trainer.train(samples_train,labels_train,losses_train); // check how good labeler is on the test data for (unsigned long i = 0; i < samples_test.size(); ++i) { labeler(samples_test[i], temp); for (unsigned long j = 0; j < labels_test[i].size(); ++j) { // What is the loss for this example? It's just 1 unless we have a // per example loss vector. const double loss = (losses_test.size() == 0) ? 1.0 : losses_test[i][j]; if (labels_test[i][j]) { num_pos += loss; if (temp[j]) num_pos_correct += loss; } else { num_neg += loss; if (!temp[j]) num_neg_correct += loss; } } } } // for (long i = 0; i < folds; ++i) matrix res; if (num_pos != 0) res(0) = num_pos_correct/num_pos; else res(0) = 1; if (num_neg != 0) res(1) = num_neg_correct/num_neg; else res(1) = 1; return res; } template < typename trainer_type, typename graph_type > matrix cross_validate_graph_labeling_trainer ( const trainer_type& trainer, const dlib::array& samples, const std::vector >& labels, const long folds ) { std::vector > losses; return cross_validate_graph_labeling_trainer(trainer, samples, labels, losses, folds); } // ---------------------------------------------------------------------------------------- } #endif // DLIB_CROSS_VALIDATE_GRAPh_LABELING_TRAINER_Hh_