// Copyright (C) 2011 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #include <sstream> #include <string> #include <cstdlib> #include <ctime> #include <dlib/svm.h> #include <dlib/matrix.h> #include "tester.h" namespace { using namespace test; using namespace dlib; using namespace std; logger dlog("test.kmeans"); dlib::rand rnd; template <typename sample_type> void run_test( const std::vector<sample_type>& seed_centers ) { print_spinner(); sample_type samp; std::vector<sample_type> samples; for (unsigned long j = 0; j < seed_centers.size(); ++j) { for (int i = 0; i < 250; ++i) { samp = randm(seed_centers[0].size(),1,rnd) - 0.5; samples.push_back(samp + seed_centers[j]); } } randomize_samples(samples); { std::vector<sample_type> centers; pick_initial_centers(seed_centers.size(), centers, samples, linear_kernel<sample_type>()); find_clusters_using_kmeans(samples, centers); DLIB_TEST(centers.size() == seed_centers.size()); std::vector<int> hits(centers.size(),0); for (unsigned long i = 0; i < samples.size(); ++i) { unsigned long best_idx = 0; double best_dist = 1e100; for (unsigned long j = 0; j < centers.size(); ++j) { if (length(samples[i] - centers[j]) < best_dist) { best_dist = length(samples[i] - centers[j]); best_idx = j; } } hits[best_idx]++; } for (unsigned long i = 0; i < hits.size(); ++i) { DLIB_TEST(hits[i] == 250); } } { std::vector<sample_type> centers; pick_initial_centers(seed_centers.size(), centers, samples, linear_kernel<sample_type>()); find_clusters_using_angular_kmeans(samples, centers); DLIB_TEST(centers.size() == seed_centers.size()); std::vector<int> hits(centers.size(),0); for (unsigned long i = 0; i < samples.size(); ++i) { unsigned long best_idx = 0; double best_dist = 1e100; for (unsigned long j = 0; j < centers.size(); ++j) { if (length(samples[i] - centers[j]) < best_dist) { best_dist = length(samples[i] - centers[j]); best_idx = j; } } hits[best_idx]++; } for (unsigned long i = 0; i < hits.size(); ++i) { DLIB_TEST(hits[i] == 250); } } } class test_kmeans : public tester { public: test_kmeans ( ) : tester ("test_kmeans", "Runs tests on the find_clusters_using_kmeans() function.") {} void perform_test ( ) { { dlog << LINFO << "test dlib::vector<double,2>"; typedef dlib::vector<double,2> sample_type; std::vector<sample_type> seed_centers; seed_centers.push_back(sample_type(10,10)); seed_centers.push_back(sample_type(10,-10)); seed_centers.push_back(sample_type(-10,10)); seed_centers.push_back(sample_type(-10,-10)); run_test(seed_centers); } { dlog << LINFO << "test dlib::vector<double,2>"; typedef dlib::vector<float,2> sample_type; std::vector<sample_type> seed_centers; seed_centers.push_back(sample_type(10,10)); seed_centers.push_back(sample_type(10,-10)); seed_centers.push_back(sample_type(-10,10)); seed_centers.push_back(sample_type(-10,-10)); run_test(seed_centers); } { dlog << LINFO << "test dlib::matrix<double,3,1>"; typedef dlib::matrix<double,3,1> sample_type; std::vector<sample_type> seed_centers; sample_type samp; samp = 10,10,0; seed_centers.push_back(samp); samp = -10,10,1; seed_centers.push_back(samp); samp = -10,-10,2; seed_centers.push_back(samp); run_test(seed_centers); } } } a; }