// Copyright (C) 2015 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #include "tester.h" #include <dlib/control.h> #include <vector> #include <sstream> #include <ctime> namespace { using namespace test; using namespace dlib; using namespace std; dlib::logger dlog("test.lspi"); template <bool have_prior> struct chain_model { typedef int state_type; typedef int action_type; // 0 is move left, 1 is move right const static bool force_last_weight_to_1 = have_prior; const static int num_states = 4; // not required in the model interface matrix<double,8,1> offset; chain_model() { offset = 2.048 , 2.56 , 2.048 , 3.2 , 2.56 , 4 , 3.2, 5 ; if (!have_prior) offset = 0; } unsigned long num_features( ) const { if (have_prior) return num_states*2 + 1; else return num_states*2; } action_type find_best_action ( const state_type& state, const matrix<double,0,1>& w ) const { if (w(state*2)+offset(state*2) >= w(state*2+1)+offset(state*2+1)) //if (w(state*2) >= w(state*2+1)) return 0; else return 1; } void get_features ( const state_type& state, const action_type& action, matrix<double,0,1>& feats ) const { feats.set_size(num_features()); feats = 0; feats(state*2 + action) = 1; if (have_prior) feats(num_features()-1) = offset(state*2+action); } }; void test_lspi_prior1() { print_spinner(); typedef process_sample<chain_model<true> > sample_type; std::vector<sample_type> samples; samples.push_back(sample_type(0,0,0,0)); samples.push_back(sample_type(0,1,1,0)); samples.push_back(sample_type(1,0,0,0)); samples.push_back(sample_type(1,1,2,0)); samples.push_back(sample_type(2,0,1,0)); samples.push_back(sample_type(2,1,3,0)); samples.push_back(sample_type(3,0,2,0)); samples.push_back(sample_type(3,1,3,1)); lspi<chain_model<true> > trainer; //trainer.be_verbose(); trainer.set_lambda(0); policy<chain_model<true> > pol = trainer.train(samples); dlog << LINFO << pol.get_weights(); matrix<double,0,1> w = pol.get_weights(); DLIB_TEST(pol.get_weights().size() == 9); DLIB_TEST(w(w.size()-1) == 1); w(w.size()-1) = 0; DLIB_TEST_MSG(length(w) < 1e-12, length(w)); dlog << LINFO << "action: " << pol(0); dlog << LINFO << "action: " << pol(1); dlog << LINFO << "action: " << pol(2); dlog << LINFO << "action: " << pol(3); DLIB_TEST(pol(0) == 1); DLIB_TEST(pol(1) == 1); DLIB_TEST(pol(2) == 1); DLIB_TEST(pol(3) == 1); } void test_lspi_prior2() { print_spinner(); typedef process_sample<chain_model<true> > sample_type; std::vector<sample_type> samples; samples.push_back(sample_type(0,0,0,0)); samples.push_back(sample_type(0,1,1,0)); samples.push_back(sample_type(1,0,0,0)); samples.push_back(sample_type(1,1,2,0)); samples.push_back(sample_type(2,0,1,0)); samples.push_back(sample_type(2,1,3,1)); samples.push_back(sample_type(3,0,2,0)); samples.push_back(sample_type(3,1,3,0)); lspi<chain_model<true> > trainer; //trainer.be_verbose(); trainer.set_lambda(0); policy<chain_model<true> > pol = trainer.train(samples); dlog << LINFO << "action: " << pol(0); dlog << LINFO << "action: " << pol(1); dlog << LINFO << "action: " << pol(2); dlog << LINFO << "action: " << pol(3); DLIB_TEST(pol(0) == 1); DLIB_TEST(pol(1) == 1); DLIB_TEST(pol(2) == 1); DLIB_TEST(pol(3) == 0); } void test_lspi_noprior1() { print_spinner(); typedef process_sample<chain_model<false> > sample_type; std::vector<sample_type> samples; samples.push_back(sample_type(0,0,0,0)); samples.push_back(sample_type(0,1,1,0)); samples.push_back(sample_type(1,0,0,0)); samples.push_back(sample_type(1,1,2,0)); samples.push_back(sample_type(2,0,1,0)); samples.push_back(sample_type(2,1,3,0)); samples.push_back(sample_type(3,0,2,0)); samples.push_back(sample_type(3,1,3,1)); lspi<chain_model<false> > trainer; //trainer.be_verbose(); trainer.set_lambda(0.01); policy<chain_model<false> > pol = trainer.train(samples); dlog << LINFO << pol.get_weights(); DLIB_TEST(pol.get_weights().size() == 8); dlog << LINFO << "action: " << pol(0); dlog << LINFO << "action: " << pol(1); dlog << LINFO << "action: " << pol(2); dlog << LINFO << "action: " << pol(3); DLIB_TEST(pol(0) == 1); DLIB_TEST(pol(1) == 1); DLIB_TEST(pol(2) == 1); DLIB_TEST(pol(3) == 1); } void test_lspi_noprior2() { print_spinner(); typedef process_sample<chain_model<false> > sample_type; std::vector<sample_type> samples; samples.push_back(sample_type(0,0,0,0)); samples.push_back(sample_type(0,1,1,0)); samples.push_back(sample_type(1,0,0,0)); samples.push_back(sample_type(1,1,2,1)); samples.push_back(sample_type(2,0,1,0)); samples.push_back(sample_type(2,1,3,0)); samples.push_back(sample_type(3,0,2,0)); samples.push_back(sample_type(3,1,3,0)); lspi<chain_model<false> > trainer; //trainer.be_verbose(); trainer.set_lambda(0.01); policy<chain_model<false> > pol = trainer.train(samples); dlog << LINFO << pol.get_weights(); DLIB_TEST(pol.get_weights().size() == 8); dlog << LINFO << "action: " << pol(0); dlog << LINFO << "action: " << pol(1); dlog << LINFO << "action: " << pol(2); dlog << LINFO << "action: " << pol(3); DLIB_TEST(pol(0) == 1); DLIB_TEST(pol(1) == 1); DLIB_TEST(pol(2) == 0); DLIB_TEST(pol(3) == 0); } class lspi_tester : public tester { public: lspi_tester ( ) : tester ( "test_lspi", // the command line argument name for this test "Run tests on the lspi object.", // the command line argument description 0 // the number of command line arguments for this test ) { } void perform_test ( ) { test_lspi_prior1(); test_lspi_prior2(); test_lspi_noprior1(); test_lspi_noprior2(); } }; lspi_tester a; }