// Copyright (C) 2010 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_ONE_VS_ONE_TRAiNER_Hh_ #define DLIB_ONE_VS_ONE_TRAiNER_Hh_ #include "one_vs_one_trainer_abstract.h" #include "one_vs_one_decision_function.h" #include <vector> #include "../unordered_pair.h" #include "multiclass_tools.h" #include <sstream> #include <iostream> #include "../any.h" #include <map> #include <set> #include "../threads.h" namespace dlib { // ---------------------------------------------------------------------------------------- template < typename any_trainer, typename label_type_ = double > class one_vs_one_trainer { public: typedef label_type_ label_type; typedef typename any_trainer::sample_type sample_type; typedef typename any_trainer::scalar_type scalar_type; typedef typename any_trainer::mem_manager_type mem_manager_type; typedef one_vs_one_decision_function<one_vs_one_trainer> trained_function_type; one_vs_one_trainer ( ) : verbose(false), num_threads(4) {} void set_trainer ( const any_trainer& trainer ) { default_trainer = trainer; trainers.clear(); } void set_trainer ( const any_trainer& trainer, const label_type& l1, const label_type& l2 ) { trainers[make_unordered_pair(l1,l2)] = trainer; } void be_verbose ( ) { verbose = true; } void be_quiet ( ) { verbose = false; } void set_num_threads ( unsigned long num ) { num_threads = num; } unsigned long get_num_threads ( ) const { return num_threads; } struct invalid_label : public dlib::error { invalid_label(const std::string& msg, const label_type& l1_, const label_type& l2_ ) : dlib::error(msg), l1(l1_), l2(l2_) {}; virtual ~invalid_label( ) throw() {} label_type l1, l2; }; trained_function_type train ( const std::vector<sample_type>& all_samples, const std::vector<label_type>& all_labels ) const { // make sure requires clause is not broken DLIB_ASSERT(is_learning_problem(all_samples,all_labels), "\t trained_function_type one_vs_one_trainer::train(all_samples,all_labels)" << "\n\t invalid inputs were given to this function" << "\n\t all_samples.size(): " << all_samples.size() << "\n\t all_labels.size(): " << all_labels.size() ); const std::vector<label_type> distinct_labels = select_all_distinct_labels(all_labels); // fill pairs with all the pairs of labels. std::vector<unordered_pair<label_type> > pairs; for (unsigned long i = 0; i < distinct_labels.size(); ++i) { for (unsigned long j = i+1; j < distinct_labels.size(); ++j) { pairs.push_back(unordered_pair<label_type>(distinct_labels[i], distinct_labels[j])); // make sure we have a trainer for this pair const typename binary_function_table::const_iterator itr = trainers.find(pairs.back()); if (itr == trainers.end() && default_trainer.is_empty()) { std::ostringstream sout; sout << "In one_vs_one_trainer, no trainer registered for the (" << pairs.back().first << ", " << pairs.back().second << ") label pair."; throw invalid_label(sout.str(), pairs.back().first, pairs.back().second); } } } // Now train on all the label pairs. parallel_for_helper helper(all_samples,all_labels,default_trainer,trainers,verbose,pairs); parallel_for(num_threads, 0, pairs.size(), helper, 500); if (helper.error_message.size() != 0) { throw dlib::error("binary trainer threw while training one vs. one classifier. Error was: " + helper.error_message); } return trained_function_type(helper.dfs); } private: typedef std::map<unordered_pair<label_type>, any_trainer> binary_function_table; struct parallel_for_helper { parallel_for_helper( const std::vector<sample_type>& all_samples_, const std::vector<label_type>& all_labels_, const any_trainer& default_trainer_, const binary_function_table& trainers_, const bool verbose_, const std::vector<unordered_pair<label_type> >& pairs_ ) : all_samples(all_samples_), all_labels(all_labels_), default_trainer(default_trainer_), trainers(trainers_), verbose(verbose_), pairs(pairs_) {} void operator()(long i) const { try { std::vector<sample_type> samples; std::vector<scalar_type> labels; const unordered_pair<label_type> p = pairs[i]; // pick out the samples corresponding to these two classes for (unsigned long k = 0; k < all_samples.size(); ++k) { if (all_labels[k] == p.first) { samples.push_back(all_samples[k]); labels.push_back(+1); } else if (all_labels[k] == p.second) { samples.push_back(all_samples[k]); labels.push_back(-1); } } if (verbose) { auto_mutex lock(class_mutex); std::cout << "Training classifier for " << p.first << " vs. " << p.second << std::endl; } any_trainer trainer; // now train a binary classifier using the samples we selected { auto_mutex lock(class_mutex); const typename binary_function_table::const_iterator itr = trainers.find(p); if (itr != trainers.end()) trainer = itr->second; else trainer = default_trainer; } any_decision_function<sample_type,scalar_type> binary_df = trainer.train(samples, labels); auto_mutex lock(class_mutex); dfs[p] = binary_df; } catch (std::exception& e) { auto_mutex lock(class_mutex); error_message = e.what(); } } mutable typename trained_function_type::binary_function_table dfs; mutex class_mutex; mutable std::string error_message; const std::vector<sample_type>& all_samples; const std::vector<label_type>& all_labels; const any_trainer& default_trainer; const binary_function_table& trainers; const bool verbose; const std::vector<unordered_pair<label_type> >& pairs; }; any_trainer default_trainer; binary_function_table trainers; bool verbose; unsigned long num_threads; }; // ---------------------------------------------------------------------------------------- } #endif // DLIB_ONE_VS_ONE_TRAiNER_Hh_