// Copyright (C) 2013 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_STRUCTURAL_SEQUENCE_sEGMENTATION_TRAINER_Hh_ #define DLIB_STRUCTURAL_SEQUENCE_sEGMENTATION_TRAINER_Hh_ #include "structural_sequence_segmentation_trainer_abstract.h" #include "structural_sequence_labeling_trainer.h" #include "sequence_segmenter.h" namespace dlib { // ---------------------------------------------------------------------------------------- template < typename feature_extractor > class structural_sequence_segmentation_trainer { public: typedef typename feature_extractor::sequence_type sample_sequence_type; typedef std::vector<std::pair<unsigned long, unsigned long> > segmented_sequence_type; typedef sequence_segmenter<feature_extractor> trained_function_type; explicit structural_sequence_segmentation_trainer ( const feature_extractor& fe_ ) : trainer(impl_ss::feature_extractor<feature_extractor>(fe_)) { loss_per_missed_segment = 1; loss_per_false_alarm = 1; } structural_sequence_segmentation_trainer ( ) { loss_per_missed_segment = 1; loss_per_false_alarm = 1; } const feature_extractor& get_feature_extractor ( ) const { return trainer.get_feature_extractor().fe; } void set_num_threads ( unsigned long num ) { trainer.set_num_threads(num); } unsigned long get_num_threads ( ) const { return trainer.get_num_threads(); } void set_epsilon ( double eps_ ) { // make sure requires clause is not broken DLIB_ASSERT(eps_ > 0, "\t void structural_sequence_segmentation_trainer::set_epsilon()" << "\n\t eps_ must be greater than 0" << "\n\t eps_: " << eps_ << "\n\t this: " << this ); trainer.set_epsilon(eps_); } double get_epsilon ( ) const { return trainer.get_epsilon(); } unsigned long get_max_iterations ( ) const { return trainer.get_max_iterations(); } void set_max_iterations ( unsigned long max_iter ) { trainer.set_max_iterations(max_iter); } void set_max_cache_size ( unsigned long max_size ) { trainer.set_max_cache_size(max_size); } unsigned long get_max_cache_size ( ) const { return trainer.get_max_cache_size(); } void be_verbose ( ) { trainer.be_verbose(); } void be_quiet ( ) { trainer.be_quiet(); } void set_oca ( const oca& item ) { trainer.set_oca(item); } const oca get_oca ( ) const { return trainer.get_oca(); } void set_c ( double C_ ) { // make sure requires clause is not broken DLIB_ASSERT(C_ > 0, "\t void structural_sequence_segmentation_trainer::set_c()" << "\n\t C_ must be greater than 0" << "\n\t C_: " << C_ << "\n\t this: " << this ); trainer.set_c(C_); } double get_c ( ) const { return trainer.get_c(); } void set_loss_per_missed_segment ( double loss ) { // make sure requires clause is not broken DLIB_ASSERT(loss >= 0, "\t void structural_sequence_segmentation_trainer::set_loss_per_missed_segment(loss)" << "\n\t invalid inputs were given to this function" << "\n\t loss: " << loss << "\n\t this: " << this ); loss_per_missed_segment = loss; if (feature_extractor::use_BIO_model) { trainer.set_loss(impl_ss::BEGIN, loss_per_missed_segment); trainer.set_loss(impl_ss::INSIDE, loss_per_missed_segment); } else { trainer.set_loss(impl_ss::BEGIN, loss_per_missed_segment); trainer.set_loss(impl_ss::INSIDE, loss_per_missed_segment); trainer.set_loss(impl_ss::LAST, loss_per_missed_segment); trainer.set_loss(impl_ss::UNIT, loss_per_missed_segment); } } double get_loss_per_missed_segment ( ) const { return loss_per_missed_segment; } void set_loss_per_false_alarm ( double loss ) { // make sure requires clause is not broken DLIB_ASSERT(loss >= 0, "\t void structural_sequence_segmentation_trainer::set_loss_per_false_alarm(loss)" << "\n\t invalid inputs were given to this function" << "\n\t loss: " << loss << "\n\t this: " << this ); loss_per_false_alarm = loss; trainer.set_loss(impl_ss::OUTSIDE, loss_per_false_alarm); } double get_loss_per_false_alarm ( ) const { return loss_per_false_alarm; } const sequence_segmenter<feature_extractor> train( const std::vector<sample_sequence_type>& x, const std::vector<segmented_sequence_type>& y ) const { // make sure requires clause is not broken DLIB_ASSERT(is_sequence_segmentation_problem(x,y) == true, "\t sequence_segmenter structural_sequence_segmentation_trainer::train(x,y)" << "\n\t invalid inputs were given to this function" << "\n\t x.size(): " << x.size() << "\n\t is_sequence_segmentation_problem(x,y): " << is_sequence_segmentation_problem(x,y) << "\n\t this: " << this ); std::vector<std::vector<unsigned long> > labels(y.size()); if (feature_extractor::use_BIO_model) { // convert y into tagged BIO labels for (unsigned long i = 0; i < labels.size(); ++i) { labels[i].resize(x[i].size(), impl_ss::OUTSIDE); for (unsigned long j = 0; j < y[i].size(); ++j) { const unsigned long begin = y[i][j].first; const unsigned long end = y[i][j].second; if (begin != end) { labels[i][begin] = impl_ss::BEGIN; for (unsigned long k = begin+1; k < end; ++k) labels[i][k] = impl_ss::INSIDE; } } } } else { // convert y into tagged BILOU labels for (unsigned long i = 0; i < labels.size(); ++i) { labels[i].resize(x[i].size(), impl_ss::OUTSIDE); for (unsigned long j = 0; j < y[i].size(); ++j) { const unsigned long begin = y[i][j].first; const unsigned long end = y[i][j].second; if (begin != end) { if (begin+1==end) { labels[i][begin] = impl_ss::UNIT; } else { labels[i][begin] = impl_ss::BEGIN; for (unsigned long k = begin+1; k+1 < end; ++k) labels[i][k] = impl_ss::INSIDE; labels[i][end-1] = impl_ss::LAST; } } } } } sequence_labeler<impl_ss::feature_extractor<feature_extractor> > temp; temp = trainer.train(x, labels); return sequence_segmenter<feature_extractor>(temp.get_weights(), trainer.get_feature_extractor().fe); } private: structural_sequence_labeling_trainer<impl_ss::feature_extractor<feature_extractor> > trainer; double loss_per_missed_segment; double loss_per_false_alarm; }; // ---------------------------------------------------------------------------------------- } #endif // DLIB_STRUCTURAL_SEQUENCE_sEGMENTATION_TRAINER_Hh_