// Copyright (C) 2011 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #include "tester.h" #include #include #include "create_iris_datafile.h" #include #include #include namespace { using namespace test; using namespace dlib; using namespace std; dlib::logger dlog("test.svm_multiclass_trainer"); class test_svm_multiclass_trainer : public tester { /*! WHAT THIS OBJECT REPRESENTS This object represents a unit test. When it is constructed it adds itself into the testing framework. !*/ public: test_svm_multiclass_trainer ( ) : tester ( "test_svm_multiclass_trainer", // the command line argument name for this test "Run tests on the svm_multiclass_linear_trainer stuff.", // the command line argument description 0 // the number of command line arguments for this test ) { } void test_prior () { print_spinner(); typedef matrix sample_type; typedef linear_kernel kernel_type; std::vector samples; std::vector labels; for (int i = 0; i < 4; ++i) { if (i==2) ++i; for (int iter = 0; iter < 5; ++iter) { sample_type samp; samp = 0; samp(i) = 1; samples.push_back(samp); labels.push_back(i); } } svm_multiclass_linear_trainer trainer; multiclass_linear_decision_function df = trainer.train(samples, labels); //cout << "test: \n" << test_multiclass_decision_function(df, samples, labels) << endl; //cout << df.weights << endl; //cout << df.b << endl; std::vector samples2; std::vector labels2; int i = 2; for (int iter = 0; iter < 5; ++iter) { sample_type samp; samp = 0; samp(i) = 1; samples2.push_back(samp); labels2.push_back(i); samples.push_back(samp); labels.push_back(i); } trainer.set_prior(df); trainer.set_c(0.1); df = trainer.train(samples2, labels2); matrix res = test_multiclass_decision_function(df, samples, labels); dlog << LINFO << "test: \n" << res; dlog << LINFO << df.weights; dlog << LINFO << df.b; DLIB_TEST((unsigned int)sum(diag(res))==samples.size()); } void test_prior_sparse () { print_spinner(); typedef std::map sample_type; typedef sparse_linear_kernel kernel_type; std::vector samples; std::vector labels; for (int i = 0; i < 4; ++i) { if (i==2) ++i; for (int iter = 0; iter < 5; ++iter) { sample_type samp; samp[i] = 1; samples.push_back(samp); labels.push_back(i); } } svm_multiclass_linear_trainer trainer; multiclass_linear_decision_function df = trainer.train(samples, labels); //cout << "test: \n" << test_multiclass_decision_function(df, samples, labels) << endl; //cout << df.weights << endl; //cout << df.b << endl; std::vector samples2; std::vector labels2; int i = 2; for (int iter = 0; iter < 5; ++iter) { sample_type samp; samp[i] = 1; samp[i+10] = 1; samples2.push_back(samp); labels2.push_back(i); samples.push_back(samp); labels.push_back(i); } trainer.set_prior(df); trainer.set_c(0.1); df = trainer.train(samples2, labels2); matrix res = test_multiclass_decision_function(df, samples, labels); dlog << LINFO << "test: \n" << res; dlog << LINFO << df.weights; dlog << LINFO << df.b; DLIB_TEST((unsigned int)sum(diag(res))==samples.size()); } template void run_test() { print_spinner(); typedef typename sample_type::value_type::second_type scalar_type; std::vector samples; std::vector labels; load_libsvm_formatted_data("iris.scale",samples, labels); DLIB_TEST(samples.size() == 150); DLIB_TEST(labels.size() == 150); typedef sparse_linear_kernel kernel_type; svm_multiclass_linear_trainer trainer; trainer.set_c(100); trainer.set_epsilon(0.000001); randomize_samples(samples, labels); matrix cv = cross_validate_multiclass_trainer(trainer, samples, labels, 4); dlog << LINFO << "confusion matrix: \n" << cv; const scalar_type cv_accuracy = sum(diag(cv))/sum(cv); dlog << LINFO << "cv accuracy: " << cv_accuracy; DLIB_TEST(cv_accuracy > 0.97); { print_spinner(); typedef matrix dsample_type; std::vector dsamples = sparse_to_dense(samples); DLIB_TEST(dsamples.size() == 150); typedef linear_kernel kernel_type; svm_multiclass_linear_trainer trainer; trainer.set_c(100); cv = cross_validate_multiclass_trainer(trainer, dsamples, labels, 4); dlog << LINFO << "dense confusion matrix: \n" << cv; const scalar_type cv_accuracy = sum(diag(cv))/sum(cv); dlog << LINFO << "dense cv accuracy: " << cv_accuracy; DLIB_TEST(cv_accuracy > 0.97); } } void perform_test ( ) { print_spinner(); create_iris_datafile(); run_test >(); run_test >(); run_test > >(); run_test > >(); test_prior(); test_prior_sparse(); } }; test_svm_multiclass_trainer a; }