// Copyright (C) 2011 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_STRUCTURAL_SVM_SEQUENCE_LaBELING_PROBLEM_Hh_ #define DLIB_STRUCTURAL_SVM_SEQUENCE_LaBELING_PROBLEM_Hh_ #include "structural_svm_sequence_labeling_problem_abstract.h" #include "../matrix.h" #include "sequence_labeler.h" #include #include "structural_svm_problem_threaded.h" // ---------------------------------------------------------------------------------------- namespace dlib { namespace fe_helpers { // ---------------------------------------------------------------------------------------- struct get_feats_functor { get_feats_functor(std::vector >& feats_) : feats(feats_) {} inline void operator() ( unsigned long feat_index, double feat_value ) { feats.push_back(std::make_pair(feat_index, feat_value)); } inline void operator() ( unsigned long feat_index ) { feats.push_back(std::make_pair(feat_index, 1)); } std::vector >& feats; }; // ---------------------------------------------------------------------------------------- template void get_feature_vector( std::vector >& feats, const feature_extractor& fe, const sequence_type& sequence, const matrix_exp& candidate_labeling, unsigned long position ) { get_feats_functor funct(feats); fe.get_features(funct, sequence,candidate_labeling, position); } } // ---------------------------------------------------------------------------------------- template < typename feature_extractor > class structural_svm_sequence_labeling_problem : noncopyable, public structural_svm_problem_threaded, std::vector > > { public: typedef matrix matrix_type; typedef std::vector > feature_vector_type; typedef typename feature_extractor::sequence_type sequence_type; structural_svm_sequence_labeling_problem( const std::vector& samples_, const std::vector >& labels_, const feature_extractor& fe_, unsigned long num_threads = 2 ) : structural_svm_problem_threaded(num_threads), samples(samples_), labels(labels_), fe(fe_) { // make sure requires clause is not broken DLIB_ASSERT(is_sequence_labeling_problem(samples,labels) == true && contains_invalid_labeling(fe, samples, labels) == false, "\t structural_svm_sequence_labeling_problem::structural_svm_sequence_labeling_problem()" << "\n\t invalid inputs were given to this function" << "\n\t samples.size(): " << samples.size() << "\n\t is_sequence_labeling_problem(samples,labels): " << is_sequence_labeling_problem(samples,labels) << "\n\t contains_invalid_labeling(fe,samples,labels): " << contains_invalid_labeling(fe,samples,labels) << "\n\t this: " << this ); #ifdef ENABLE_ASSERTS for (unsigned long i = 0; i < labels.size(); ++i) { for (unsigned long j = 0; j < labels[i].size(); ++j) { // make sure requires clause is not broken DLIB_ASSERT(labels[i][j] < fe.num_labels(), "\t structural_svm_sequence_labeling_problem::structural_svm_sequence_labeling_problem()" << "\n\t The given labels in labels are invalid." << "\n\t labels[i][j]: " << labels[i][j] << "\n\t fe.num_labels(): " << fe.num_labels() << "\n\t i: " << i << "\n\t j: " << j << "\n\t this: " << this ); } } #endif loss_values.assign(num_labels(), 1); } unsigned long num_labels ( ) const { return fe.num_labels(); } double get_loss ( unsigned long label ) const { // make sure requires clause is not broken DLIB_ASSERT(label < num_labels(), "\t void structural_svm_sequence_labeling_problem::get_loss()" << "\n\t invalid inputs were given to this function" << "\n\t label: " << label << "\n\t num_labels(): " << num_labels() << "\n\t this: " << this ); return loss_values[label]; } void set_loss ( unsigned long label, double value ) { // make sure requires clause is not broken DLIB_ASSERT(label < num_labels() && value >= 0, "\t void structural_svm_sequence_labeling_problem::set_loss()" << "\n\t invalid inputs were given to this function" << "\n\t label: " << label << "\n\t num_labels(): " << num_labels() << "\n\t value: " << value << "\n\t this: " << this ); loss_values[label] = value; } private: virtual long get_num_dimensions ( ) const { return fe.num_features(); } virtual long get_num_samples ( ) const { return samples.size(); } void get_joint_feature_vector ( const sequence_type& sample, const std::vector& label, feature_vector_type& psi ) const { psi.clear(); const int order = fe.order(); matrix candidate_labeling; for (unsigned long i = 0; i < sample.size(); ++i) { candidate_labeling = rowm(mat(label), range(i, std::max((int)i-order,0))); fe_helpers::get_feature_vector(psi,fe,sample,candidate_labeling, i); } } virtual void get_truth_joint_feature_vector ( long idx, feature_vector_type& psi ) const { get_joint_feature_vector(samples[idx], labels[idx], psi); } class map_prob { public: unsigned long order() const { return fe.order(); } unsigned long num_states() const { return fe.num_labels(); } map_prob( const sequence_type& sequence_, const std::vector& label_, const feature_extractor& fe_, const matrix& weights_, const std::vector& loss_values_ ) : sequence(sequence_), label(label_), fe(fe_), weights(weights_), loss_values(loss_values_) { } unsigned long number_of_nodes( ) const { return sequence.size(); } template < typename EXP > double factor_value ( unsigned long node_id, const matrix_exp& node_states ) const { if (dlib::impl::call_reject_labeling_if_exists(fe, sequence, node_states, node_id)) return -std::numeric_limits::infinity(); double loss = 0; if (node_states(0) != label[node_id]) loss = loss_values[label[node_id]]; return fe_helpers::dot(weights, fe, sequence, node_states, node_id) + loss; } const sequence_type& sequence; const std::vector& label; const feature_extractor& fe; const matrix& weights; const std::vector& loss_values; }; virtual void separation_oracle ( const long idx, const matrix_type& current_solution, scalar_type& loss, feature_vector_type& psi ) const { std::vector y; find_max_factor_graph_viterbi(map_prob(samples[idx],labels[idx],fe,current_solution,loss_values), y); loss = 0; for (unsigned long i = 0; i < y.size(); ++i) { if (y[i] != labels[idx][i]) loss += loss_values[labels[idx][i]]; } get_joint_feature_vector(samples[idx], y, psi); } const std::vector& samples; const std::vector >& labels; const feature_extractor& fe; std::vector loss_values; }; // ---------------------------------------------------------------------------------------- } #endif // DLIB_STRUCTURAL_SVM_SEQUENCE_LaBELING_PROBLEM_Hh_