// Copyright (C) 2013 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #include <sstream> #include "tester.h" #include <dlib/svm_threaded.h> #include <dlib/rand.h> namespace { using namespace test; using namespace dlib; using namespace std; logger dlog("test.sequence_segmenter"); // ---------------------------------------------------------------------------------------- dlib::rand rnd; template <bool use_BIO_model_, bool use_high_order_features_, bool allow_negative_weights_> class unigram_extractor { public: const static bool use_BIO_model = use_BIO_model_; const static bool use_high_order_features = use_high_order_features_; const static bool allow_negative_weights = allow_negative_weights_; typedef std::vector<unsigned long> sequence_type; std::map<unsigned long, matrix<double,0,1> > feats; unigram_extractor() { matrix<double,0,1> v1, v2, v3; v1 = randm(num_features(), 1, rnd); v2 = randm(num_features(), 1, rnd); v3 = randm(num_features(), 1, rnd); v1(0) = 1; v2(1) = 1; v3(2) = 1; v1(3) = -1; v2(4) = -1; v3(5) = -1; for (unsigned long i = 0; i < num_features(); ++i) { if ( i < 3) feats[i] = v1; else if (i < 6) feats[i] = v2; else feats[i] = v3; } } unsigned long num_features() const { return 10; } unsigned long window_size() const { return 3; } template <typename feature_setter> void get_features ( feature_setter& set_feature, const sequence_type& x, unsigned long position ) const { const matrix<double,0,1>& m = feats.find(x[position])->second; for (unsigned long i = 0; i < num_features(); ++i) { set_feature(i, m(i)); } } }; template <bool use_BIO_model_, bool use_high_order_features_, bool neg> void serialize(const unigram_extractor<use_BIO_model_,use_high_order_features_,neg>& item , std::ostream& out ) { serialize(item.feats, out); } template <bool use_BIO_model_, bool use_high_order_features_, bool neg> void deserialize(unigram_extractor<use_BIO_model_,use_high_order_features_,neg>& item, std::istream& in) { deserialize(item.feats, in); } // ---------------------------------------------------------------------------------------- void make_dataset ( std::vector<std::vector<unsigned long> >& samples, std::vector<std::vector<unsigned long> >& labels, unsigned long dataset_size ) { samples.clear(); labels.clear(); samples.resize(dataset_size); labels.resize(dataset_size); unigram_extractor<true,true,true> fe; dlib::rand rnd; for (unsigned long iter = 0; iter < dataset_size; ++iter) { samples[iter].resize(10); labels[iter].resize(10); for (unsigned long i = 0; i < samples[iter].size(); ++i) { samples[iter][i] = rnd.get_random_32bit_number()%fe.num_features(); if (samples[iter][i] < 3) { labels[iter][i] = impl_ss::BEGIN; } else if (samples[iter][i] < 6) { labels[iter][i] = impl_ss::INSIDE; } else { labels[iter][i] = impl_ss::OUTSIDE; } if (i != 0) { // do rejection sampling to avoid impossible labels if (labels[iter][i] == impl_ss::INSIDE && labels[iter][i-1] == impl_ss::OUTSIDE) { --i; } } } } } // ---------------------------------------------------------------------------------------- void make_dataset2 ( std::vector<std::vector<unsigned long> >& samples, std::vector<std::vector<std::pair<unsigned long, unsigned long> > >& segments, unsigned long dataset_size ) { segments.clear(); std::vector<std::vector<unsigned long> > labels; make_dataset(samples, labels, dataset_size); segments.resize(samples.size()); // Convert from BIO tagging to the explicit segments representation. for (unsigned long k = 0; k < labels.size(); ++k) { for (unsigned long i = 0; i < labels[k].size(); ++i) { if (labels[k][i] == impl_ss::BEGIN) { const unsigned long begin = i; ++i; while (i < labels[k].size() && labels[k][i] == impl_ss::INSIDE) ++i; segments[k].push_back(std::make_pair(begin, i)); --i; } } } } // ---------------------------------------------------------------------------------------- template <bool use_BIO_model, bool use_high_order_features, bool allow_negative_weights> void do_test() { dlog << LINFO << "use_BIO_model: "<< use_BIO_model; dlog << LINFO << "use_high_order_features: "<< use_high_order_features; dlog << LINFO << "allow_negative_weights: "<< allow_negative_weights; std::vector<std::vector<unsigned long> > samples; std::vector<std::vector<std::pair<unsigned long,unsigned long> > > segments; make_dataset2( samples, segments, 100); print_spinner(); typedef unigram_extractor<use_BIO_model,use_high_order_features,allow_negative_weights> fe_type; fe_type fe_temp; fe_type fe_temp2; structural_sequence_segmentation_trainer<fe_type> trainer(fe_temp2); trainer.set_c(5); trainer.set_num_threads(1); sequence_segmenter<fe_type> labeler = trainer.train(samples, segments); print_spinner(); const std::vector<std::pair<unsigned long, unsigned long> > predicted_labels = labeler(samples[1]); const std::vector<std::pair<unsigned long, unsigned long> > true_labels = segments[1]; /* for (unsigned long i = 0; i < predicted_labels.size(); ++i) cout << "["<<predicted_labels[i].first<<","<<predicted_labels[i].second<<") "; cout << endl; for (unsigned long i = 0; i < true_labels.size(); ++i) cout << "["<<true_labels[i].first<<","<<true_labels[i].second<<") "; cout << endl; */ DLIB_TEST(predicted_labels.size() > 0); DLIB_TEST(predicted_labels.size() == true_labels.size()); for (unsigned long i = 0; i < predicted_labels.size(); ++i) { DLIB_TEST(predicted_labels[i].first == true_labels[i].first); DLIB_TEST(predicted_labels[i].second == true_labels[i].second); } matrix<double> res; res = cross_validate_sequence_segmenter(trainer, samples, segments, 3); dlog << LINFO << "cv res: "<< res; DLIB_TEST(min(res) > 0.98); make_dataset2( samples, segments, 100); res = test_sequence_segmenter(labeler, samples, segments); dlog << LINFO << "test res: "<< res; DLIB_TEST(min(res) > 0.98); print_spinner(); ostringstream sout; serialize(labeler, sout); istringstream sin(sout.str()); sequence_segmenter<fe_type> labeler2; deserialize(labeler2, sin); res = test_sequence_segmenter(labeler2, samples, segments); dlog << LINFO << "test res2: "<< res; DLIB_TEST(min(res) > 0.98); long N; if (use_BIO_model) N = 3*3+3; else N = 5*5+5; const double min_normal_weight = min(colm(labeler2.get_weights(), 0, labeler2.get_weights().size()-N)); const double min_trans_weight = min(labeler2.get_weights()); dlog << LINFO << "min_normal_weight: " << min_normal_weight; dlog << LINFO << "min_trans_weight: " << min_trans_weight; if (allow_negative_weights) { DLIB_TEST(min_normal_weight < 0); DLIB_TEST(min_trans_weight < 0); } else { DLIB_TEST(min_normal_weight == 0); DLIB_TEST(min_trans_weight < 0); } } // ---------------------------------------------------------------------------------------- class unit_test_sequence_segmenter : public tester { public: unit_test_sequence_segmenter ( ) : tester ("test_sequence_segmenter", "Runs tests on the sequence segmenting code.") {} void perform_test ( ) { do_test<true,true,false>(); do_test<true,false,false>(); do_test<false,true,false>(); do_test<false,false,false>(); do_test<true,true,true>(); do_test<true,false,true>(); do_test<false,true,true>(); do_test<false,false,true>(); } } a; }