// Copyright (C) 2011 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_STRUCTURAL_ASSiGNMENT_TRAINER_Hh_ #define DLIB_STRUCTURAL_ASSiGNMENT_TRAINER_Hh_ #include "structural_assignment_trainer_abstract.h" #include "../algs.h" #include "../optimization.h" #include "structural_svm_assignment_problem.h" #include "num_nonnegative_weights.h" namespace dlib { // ---------------------------------------------------------------------------------------- template < typename feature_extractor > class structural_assignment_trainer { public: typedef typename feature_extractor::lhs_element lhs_element; typedef typename feature_extractor::rhs_element rhs_element; typedef std::pair<std::vector<lhs_element>, std::vector<rhs_element> > sample_type; typedef std::vector<long> label_type; typedef assignment_function<feature_extractor> trained_function_type; structural_assignment_trainer ( ) { set_defaults(); } explicit structural_assignment_trainer ( const feature_extractor& fe_ ) : fe(fe_) { set_defaults(); } const feature_extractor& get_feature_extractor ( ) const { return fe; } void set_num_threads ( unsigned long num ) { num_threads = num; } unsigned long get_num_threads ( ) const { return num_threads; } void set_epsilon ( double eps_ ) { // make sure requires clause is not broken DLIB_ASSERT(eps_ > 0, "\t void structural_assignment_trainer::set_epsilon()" << "\n\t eps_ must be greater than 0" << "\n\t eps_: " << eps_ << "\n\t this: " << this ); eps = eps_; } double get_epsilon ( ) const { return eps; } void set_max_cache_size ( unsigned long max_size ) { max_cache_size = max_size; } unsigned long get_max_cache_size ( ) const { return max_cache_size; } void be_verbose ( ) { verbose = true; } void be_quiet ( ) { verbose = false; } void set_oca ( const oca& item ) { solver = item; } const oca get_oca ( ) const { return solver; } void set_c ( double C_ ) { // make sure requires clause is not broken DLIB_ASSERT(C_ > 0, "\t void structural_assignment_trainer::set_c()" << "\n\t C_ must be greater than 0" << "\n\t C_: " << C_ << "\n\t this: " << this ); C = C_; } double get_c ( ) const { return C; } bool forces_assignment( ) const { return force_assignment; } void set_forces_assignment ( bool new_value ) { force_assignment = new_value; } void set_loss_per_false_association ( double loss ) { // make sure requires clause is not broken DLIB_ASSERT(loss > 0, "\t void structural_assignment_trainer::set_loss_per_false_association(loss)" << "\n\t Invalid inputs were given to this function " << "\n\t loss: " << loss << "\n\t this: " << this ); loss_per_false_association = loss; } double get_loss_per_false_association ( ) const { return loss_per_false_association; } void set_loss_per_missed_association ( double loss ) { // make sure requires clause is not broken DLIB_ASSERT(loss > 0, "\t void structural_assignment_trainer::set_loss_per_missed_association(loss)" << "\n\t Invalid inputs were given to this function " << "\n\t loss: " << loss << "\n\t this: " << this ); loss_per_missed_association = loss; } double get_loss_per_missed_association ( ) const { return loss_per_missed_association; } bool forces_last_weight_to_1 ( ) const { return last_weight_1; } void force_last_weight_to_1 ( bool should_last_weight_be_1 ) { last_weight_1 = should_last_weight_be_1; } const assignment_function<feature_extractor> train ( const std::vector<sample_type>& samples, const std::vector<label_type>& labels ) const { // make sure requires clause is not broken #ifdef ENABLE_ASSERTS if (force_assignment) { DLIB_ASSERT(is_forced_assignment_problem(samples, labels), "\t assignment_function structural_assignment_trainer::train()" << "\n\t invalid inputs were given to this function" << "\n\t is_forced_assignment_problem(samples,labels): " << is_forced_assignment_problem(samples,labels) << "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels) << "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels) ); } else { DLIB_ASSERT(is_assignment_problem(samples, labels), "\t assignment_function structural_assignment_trainer::train()" << "\n\t invalid inputs were given to this function" << "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels) << "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels) ); } #endif structural_svm_assignment_problem<feature_extractor> prob(samples,labels, fe, force_assignment, num_threads, loss_per_false_association, loss_per_missed_association); if (verbose) prob.be_verbose(); prob.set_c(C); prob.set_epsilon(eps); prob.set_max_cache_size(max_cache_size); matrix<double,0,1> weights; // Take the min here because we want to prevent the user from accidentally // forcing the bias term to be non-negative. const unsigned long num_nonneg = std::min(fe.num_features(),num_nonnegative_weights(fe)); if (last_weight_1) solver(prob, weights, num_nonneg, fe.num_features()-1); else solver(prob, weights, num_nonneg); const double bias = weights(weights.size()-1); return assignment_function<feature_extractor>(colm(weights,0,weights.size()-1), bias,fe,force_assignment); } private: bool force_assignment; double C; oca solver; double eps; bool verbose; unsigned long num_threads; unsigned long max_cache_size; double loss_per_false_association; double loss_per_missed_association; bool last_weight_1; void set_defaults () { force_assignment = false; C = 100; verbose = false; eps = 0.01; num_threads = 2; max_cache_size = 5; loss_per_false_association = 1; loss_per_missed_association = 1; last_weight_1 = false; } feature_extractor fe; }; // ---------------------------------------------------------------------------------------- } #endif // DLIB_STRUCTURAL_ASSiGNMENT_TRAINER_Hh_