// Copyright (C) 2008 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #include <sstream> #include <string> #include <cstdlib> #include <ctime> #include <dlib/misc_api.h> #include <dlib/threads.h> #include <dlib/any.h> #include "tester.h" namespace { using namespace test; using namespace dlib; using namespace std; logger dlog("test.thread_pool"); struct some_struct : noncopyable { float val; }; int global_var = 0; struct add_functor { add_functor() { var = 1;} add_functor(int v):var(v) {} template <typename T, typename U, typename V> void operator()(T a, U b, V& res) { dlib::sleep(20); res = a + b; } void set_global_var() { global_var = 9; } void set_global_var_const() const { global_var = 9; } void set_global_var_arg1(int val) { global_var = val; } void set_global_var_const_arg1(int val) const { global_var = val; } void set_global_var_arg2(int val, int val2) { global_var = val+val2; } void set_global_var_const_arg2(int val, int val2) const { global_var = val+val2; } void operator()() { global_var = 9; } // use an any just so that if this object goes out of scope // then var will get all messed up. any var; void operator()(int& a) { dlib::sleep(100); a = var.get<int>(); } void operator()(int& a, int& b) { dlib::sleep(100); a = var.get<int>(); b = 2; } void operator()(int& a, int& b, int& c) { dlib::sleep(100); a = var.get<int>(); b = 2; c = 3; } void operator()(int& a, int& b, int& c, int& d) { dlib::sleep(100); a = var.get<int>(); b = 2; c = 3; d = 4; } }; void set_global_var() { global_var = 9; } void gset_struct_to_zero (some_struct& a) { a.val = 0; } void gset_to_zero (int& a) { a = 0; } void gincrement (int& a) { ++a; } void gadd (int a, const int& b, int& res) { dlib::sleep(20); res = a + b; } void gadd1(int& a, int& res) { res += a; } void gadd2 (int c, int a, const int& b, int& res) { dlib::sleep(20); res = a + b + c; } class thread_pool_tester : public tester { public: thread_pool_tester ( ) : tester ("test_thread_pool", "Runs tests on the thread_pool component.") {} void perform_test ( ) { add_functor f; for (int num_threads= 0; num_threads < 4; ++num_threads) { dlib::future<int> a, b, c, res, d; thread_pool tp(num_threads); print_spinner(); dlib::future<some_struct> obj; for (int i = 0; i < 4; ++i) { a = 1; b = 2; c = 3; res = 4; DLIB_TEST(a==a); DLIB_TEST(a!=b); DLIB_TEST(a==1); tp.add_task(gset_to_zero, a); tp.add_task(gset_to_zero, b); tp.add_task(*this, &thread_pool_tester::set_to_zero, c); tp.add_task(gset_to_zero, res); DLIB_TEST(a == 0); DLIB_TEST(b == 0); DLIB_TEST(c == 0); DLIB_TEST(res == 0); tp.add_task(gincrement, a); tp.add_task(*this, &thread_pool_tester::increment, b); tp.add_task(*this, &thread_pool_tester::increment, c); tp.add_task(gincrement, res); DLIB_TEST(a == 1); DLIB_TEST(b == 1); DLIB_TEST(c == 1); DLIB_TEST(res == 1); tp.add_task(&gincrement, a); tp.add_task(*this, &thread_pool_tester::increment, b); tp.add_task(*this, &thread_pool_tester::increment, c); tp.add_task(&gincrement, res); tp.add_task(gincrement, a); tp.add_task(*this, &thread_pool_tester::increment, b); tp.add_task(*this, &thread_pool_tester::increment, c); tp.add_task(gincrement, res); DLIB_TEST(a == 3); DLIB_TEST(b == 3); DLIB_TEST(c == 3); DLIB_TEST(res == 3); tp.add_task(*this, &thread_pool_tester::increment, c); tp.add_task(gincrement, res); DLIB_TEST(c == 4); DLIB_TEST(res == 4); tp.add_task(gadd, a, b, res); DLIB_TEST(res == a+b); DLIB_TEST(res == 6); a = 3; b = 4; res = 99; DLIB_TEST(res == 99); tp.add_task(*this, &thread_pool_tester::add, a, b, res); DLIB_TEST(res == a+b); DLIB_TEST(res == 7); a = 1; b = 2; c = 3; res = 88; DLIB_TEST(res == 88); DLIB_TEST(a == 1); DLIB_TEST(b == 2); DLIB_TEST(c == 3); tp.add_task(gadd2, a, b, c, res); DLIB_TEST(res == 6); DLIB_TEST(a == 1); DLIB_TEST(b == 2); DLIB_TEST(c == 3); a = 1; b = 2; c = 3; res = 88; DLIB_TEST(res == 88); DLIB_TEST(a == 1); DLIB_TEST(b == 2); DLIB_TEST(c == 3); tp.add_task(*this, &thread_pool_tester::add2, a, b, c, res); DLIB_TEST(res == 6); DLIB_TEST(a == 1); DLIB_TEST(b == 2); DLIB_TEST(c == 3); a = 1; b = 2; c = 3; res = 88; tp.add_task(gadd1, a, b); DLIB_TEST(a == 1); DLIB_TEST(b == 3); a = 2; tp.add_task(*this, &thread_pool_tester::add1, a, b); DLIB_TEST(a == 2); DLIB_TEST(b == 5); val = 4; uint64 id = tp.add_task(*this, &thread_pool_tester::zero_val); tp.wait_for_task(id); DLIB_TEST(val == 0); id = tp.add_task(*this, &thread_pool_tester::accum2, 1,2); tp.wait_for_all_tasks(); DLIB_TEST(val == 3); id = tp.add_task(*this, &thread_pool_tester::accum1, 3); tp.wait_for_task(id); DLIB_TEST(val == 6); obj.get().val = 8; DLIB_TEST(obj.get().val == 8); tp.add_task(gset_struct_to_zero, obj); DLIB_TEST(obj.get().val == 0); obj.get().val = 8; DLIB_TEST(obj.get().val == 8); tp.add_task(*this,&thread_pool_tester::set_struct_to_zero, obj); DLIB_TEST(obj.get().val == 0); a = 1; b = 2; res = 0; tp.add_task(f, a, b, res); DLIB_TEST(a == 1); DLIB_TEST(b == 2); DLIB_TEST(res == 3); global_var = 0; DLIB_TEST(global_var == 0); id = tp.add_task(&set_global_var); tp.wait_for_task(id); DLIB_TEST(global_var == 9); global_var = 0; DLIB_TEST(global_var == 0); id = tp.add_task(f); tp.wait_for_task(id); DLIB_TEST(global_var == 9); global_var = 0; DLIB_TEST(global_var == 0); id = tp.add_task(f, &add_functor::set_global_var); tp.wait_for_task(id); DLIB_TEST(global_var == 9); global_var = 0; a = 4; DLIB_TEST(global_var == 0); id = tp.add_task(f, &add_functor::set_global_var_arg1, a); tp.wait_for_task(id); DLIB_TEST(global_var == 4); global_var = 0; a = 4; DLIB_TEST(global_var == 0); id = tp.add_task_by_value(f, &add_functor::set_global_var_arg1, a); tp.wait_for_task(id); DLIB_TEST(global_var == 4); global_var = 0; a = 4; b = 3; DLIB_TEST(global_var == 0); id = tp.add_task(f, &add_functor::set_global_var_arg2, a, b); tp.wait_for_task(id); DLIB_TEST(global_var == 7); global_var = 0; a = 4; b = 3; DLIB_TEST(global_var == 0); id = tp.add_task_by_value(f, &add_functor::set_global_var_arg2, a, b); tp.wait_for_task(id); DLIB_TEST(global_var == 7); global_var = 0; a = 4; b = 3; DLIB_TEST(global_var == 0); id = tp.add_task(f, &add_functor::set_global_var_const_arg2, a, b); tp.wait_for_task(id); DLIB_TEST(global_var == 7); global_var = 0; a = 4; b = 3; DLIB_TEST(global_var == 0); id = tp.add_task_by_value(f, &add_functor::set_global_var_const_arg2, a, b); tp.wait_for_task(id); DLIB_TEST(global_var == 7); global_var = 0; a = 4; DLIB_TEST(global_var == 0); id = tp.add_task(f, &add_functor::set_global_var_const_arg1, a); tp.wait_for_task(id); DLIB_TEST(global_var == 4); global_var = 0; a = 4; DLIB_TEST(global_var == 0); id = tp.add_task_by_value(f, &add_functor::set_global_var_const_arg1, a); tp.wait_for_task(id); DLIB_TEST(global_var == 4); global_var = 0; DLIB_TEST(global_var == 0); id = tp.add_task_by_value(f, &add_functor::set_global_var); tp.wait_for_task(id); DLIB_TEST(global_var == 9); global_var = 0; DLIB_TEST(global_var == 0); id = tp.add_task(f, &add_functor::set_global_var_const); tp.wait_for_task(id); DLIB_TEST(global_var == 9); global_var = 0; DLIB_TEST(global_var == 0); id = tp.add_task_by_value(f, &add_functor::set_global_var_const); tp.wait_for_task(id); DLIB_TEST(global_var == 9); } // add this task just to to perterb the thread pool before it goes out of scope tp.add_task(f, a, b, res); for (int k = 0; k < 3; ++k) { print_spinner(); global_var = 0; tp.add_task_by_value(add_functor()); tp.wait_for_all_tasks(); DLIB_TEST(global_var == 9); a = 0; b = 0; c = 0; d = 0; tp.add_task_by_value(add_functor(), a); DLIB_TEST(a == 1); a = 0; b = 0; c = 0; d = 0; tp.add_task_by_value(add_functor(8), a, b); DLIB_TEST(a == 8); DLIB_TEST(b == 2); a = 0; b = 0; c = 0; d = 0; tp.add_task_by_value(add_functor(), a, b, c); DLIB_TEST(a == 1); DLIB_TEST(b == 2); DLIB_TEST(c == 3); a = 0; b = 0; c = 0; d = 0; tp.add_task_by_value(add_functor(5), a, b, c, d); DLIB_TEST(a == 5); DLIB_TEST(b == 2); DLIB_TEST(c == 3); DLIB_TEST(d == 4); } tp.wait_for_all_tasks(); // make sure exception propagation from tasks works correctly. auto f_throws = []() { throw dlib::error("test exception");}; bool got_exception = false; try { tp.add_task_by_value(f_throws); tp.wait_for_all_tasks(); } catch(dlib::error& e) { DLIB_TEST(e.info == "test exception"); got_exception = true; } DLIB_TEST(got_exception); dlib::future<int> aa; auto f_throws2 = [](int& a) { a = 1; throw dlib::error("test exception");}; got_exception = false; try { tp.add_task(f_throws2, aa); aa.get(); } catch(dlib::error& e) { DLIB_TEST(e.info == "test exception"); got_exception = true; } DLIB_TEST(got_exception); } } long val; void accum1(long a) { val += a; } void accum2(long a, long b) { val += a + b; } void zero_val() { dlib::sleep(20); val = 0; } void set_struct_to_zero (some_struct& a) { a.val = 0; } void set_to_zero (int& a) { dlib::sleep(20); a = 0; } void increment (int& a) const { dlib::sleep(20); ++a; } void add (int a, const int& b, int& res) { dlib::sleep(20); res = a + b; } void add1(int& a, int& res) const { res += a; } void add2 (int c, int a, const int& b, int& res) { res = a + b + c; } } a; }