// Copyright (C) 2011 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #include #include #include #include #include "tester.h" #include #include typedef dlib::matrix lhs_element; typedef dlib::matrix rhs_element; namespace { using namespace test; using namespace dlib; using namespace std; logger dlog("test.assignment_learning"); // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- struct feature_extractor_dense { typedef matrix feature_vector_type; typedef ::lhs_element lhs_element; typedef ::rhs_element rhs_element; unsigned long num_features() const { return 3; } void get_features ( const lhs_element& left, const rhs_element& right, feature_vector_type& feats ) const { feats = squared(left - right); } }; void serialize (const feature_extractor_dense& , std::ostream& ) {} void deserialize (feature_extractor_dense& , std::istream& ) {} // ---------------------------------------------------------------------------------------- struct feature_extractor_sparse { typedef std::vector > feature_vector_type; typedef ::lhs_element lhs_element; typedef ::rhs_element rhs_element; unsigned long num_features() const { return 3; } void get_features ( const lhs_element& left, const rhs_element& right, feature_vector_type& feats ) const { feats.clear(); feats.push_back(make_pair(0,squared(left-right)(0))); feats.push_back(make_pair(1,squared(left-right)(1))); feats.push_back(make_pair(2,squared(left-right)(2))); } }; void serialize (const feature_extractor_sparse& , std::ostream& ) {} void deserialize (feature_extractor_sparse& , std::istream& ) {} // ---------------------------------------------------------------------------------------- typedef std::pair, std::vector > sample_type; typedef std::vector label_type; // ---------------------------------------------------------------------------------------- void make_data ( std::vector& samples, std::vector& labels ) { lhs_element a, b, c, d; a = 1,0,0; b = 0,1,0; c = 0,0,1; d = 0,1,1; std::vector lhs; std::vector rhs; label_type label; lhs.push_back(a); lhs.push_back(b); lhs.push_back(c); rhs.push_back(b); rhs.push_back(a); rhs.push_back(c); label.push_back(1); label.push_back(0); label.push_back(2); samples.push_back(make_pair(lhs,rhs)); labels.push_back(label); lhs.clear(); rhs.clear(); label.clear(); lhs.push_back(a); lhs.push_back(b); lhs.push_back(c); rhs.push_back(c); rhs.push_back(b); rhs.push_back(a); rhs.push_back(d); label.push_back(2); label.push_back(1); label.push_back(0); samples.push_back(make_pair(lhs,rhs)); labels.push_back(label); lhs.clear(); rhs.clear(); label.clear(); lhs.push_back(a); lhs.push_back(b); lhs.push_back(c); rhs.push_back(c); rhs.push_back(a); rhs.push_back(d); label.push_back(1); label.push_back(-1); label.push_back(0); samples.push_back(make_pair(lhs,rhs)); labels.push_back(label); lhs.clear(); rhs.clear(); label.clear(); lhs.push_back(d); lhs.push_back(b); lhs.push_back(c); label.push_back(-1); label.push_back(-1); label.push_back(-1); samples.push_back(make_pair(lhs,rhs)); labels.push_back(label); lhs.clear(); rhs.clear(); label.clear(); samples.push_back(make_pair(lhs,rhs)); labels.push_back(label); } // ---------------------------------------------------------------------------------------- void make_data_force ( std::vector& samples, std::vector& labels ) { lhs_element a, b, c, d; a = 1,0,0; b = 0,1,0; c = 0,0,1; d = 0,1,1; std::vector lhs; std::vector rhs; label_type label; lhs.push_back(a); lhs.push_back(b); lhs.push_back(c); rhs.push_back(b); rhs.push_back(a); rhs.push_back(c); label.push_back(1); label.push_back(0); label.push_back(2); samples.push_back(make_pair(lhs,rhs)); labels.push_back(label); lhs.clear(); rhs.clear(); label.clear(); lhs.push_back(a); lhs.push_back(b); lhs.push_back(c); rhs.push_back(c); rhs.push_back(b); rhs.push_back(a); rhs.push_back(d); label.push_back(2); label.push_back(1); label.push_back(0); samples.push_back(make_pair(lhs,rhs)); labels.push_back(label); lhs.clear(); rhs.clear(); label.clear(); lhs.push_back(a); lhs.push_back(c); rhs.push_back(c); rhs.push_back(a); label.push_back(1); label.push_back(0); samples.push_back(make_pair(lhs,rhs)); labels.push_back(label); lhs.clear(); rhs.clear(); label.clear(); samples.push_back(make_pair(lhs,rhs)); labels.push_back(label); } // ---------------------------------------------------------------------------------------- template void test1(F make_data, bool force_assignment) { print_spinner(); std::vector samples; std::vector labels; make_data(samples, labels); make_data(samples, labels); make_data(samples, labels); randomize_samples(samples, labels); structural_assignment_trainer trainer; DLIB_TEST(trainer.forces_assignment() == false); DLIB_TEST(trainer.get_c() == 100); DLIB_TEST(trainer.get_num_threads() == 2); DLIB_TEST(trainer.get_max_cache_size() == 5); trainer.set_forces_assignment(force_assignment); trainer.set_num_threads(3); trainer.set_c(50); DLIB_TEST(trainer.get_c() == 50); DLIB_TEST(trainer.get_num_threads() == 3); DLIB_TEST(trainer.forces_assignment() == force_assignment); assignment_function ass = trainer.train(samples, labels); for (unsigned long i = 0; i < samples.size(); ++i) { std::vector out = ass(samples[i]); dlog << LINFO << "true labels: " << trans(mat(labels[i])); dlog << LINFO << "pred labels: " << trans(mat(out)); DLIB_TEST(trans(mat(labels[i])) == trans(mat(out))); } double accuracy; dlog << LINFO << "samples.size(): "<< samples.size(); accuracy = test_assignment_function(ass, samples, labels); dlog << LINFO << "accuracy: "<< accuracy; DLIB_TEST(accuracy == 1); accuracy = cross_validate_assignment_trainer(trainer, samples, labels, 3); dlog << LINFO << "cv accuracy: "<< accuracy; DLIB_TEST(accuracy == 1); ostringstream sout; serialize(ass, sout); istringstream sin(sout.str()); assignment_function ass2; deserialize(ass2, sin); DLIB_TEST(ass2.forces_assignment() == ass.forces_assignment()); DLIB_TEST(length(ass2.get_weights() - ass.get_weights()) < 1e-10); for (unsigned long i = 0; i < samples.size(); ++i) { std::vector out = ass2(samples[i]); dlog << LINFO << "true labels: " << trans(mat(labels[i])); dlog << LINFO << "pred labels: " << trans(mat(out)); DLIB_TEST(trans(mat(labels[i])) == trans(mat(out))); } } // ---------------------------------------------------------------------------------------- class test_assignment_learning : public tester { public: test_assignment_learning ( ) : tester ("test_assignment_learning", "Runs tests on the assignment learning code.") {} void perform_test ( ) { test1(make_data, false); test1(make_data, false); test1(make_data_force, false); test1(make_data_force, false); test1(make_data_force, true); test1(make_data_force, true); } } a; // ---------------------------------------------------------------------------------------- }