// Copyright (C) 2007 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_RAND_KERNEl_1_ #define DLIB_RAND_KERNEl_1_ #include #include "../algs.h" #include "rand_kernel_abstract.h" #include "mersenne_twister.h" #include "../is_kind.h" #include #include "../serialize.h" namespace dlib { class rand { /*! INITIAL VALUE - seed == "" CONVENTION - the random numbers come from the boost mersenne_twister code - get_seed() == seed !*/ public: // These typedefs are here for backwards compatibility with older versions of dlib. typedef rand kernel_1a; typedef rand float_1a; rand( ) { init(); } rand ( const std::string& seed_value ) { init(); set_seed(seed_value); } virtual ~rand( ) {} void clear( ) { mt.seed(); seed.clear(); has_gaussian = false; next_gaussian = 0; // prime the generator a bit for (int i = 0; i < 10000; ++i) mt(); } const std::string& get_seed ( ) { return seed; } void set_seed ( const std::string& value ) { seed = value; // make sure we do the seeding so that using a seed of "" gives the same // state as calling this->clear() if (value.size() != 0) { uint32 s = 0; for (std::string::size_type i = 0; i < seed.size(); ++i) { s = (s*37) + static_cast(seed[i]); } mt.seed(s); } else { mt.seed(); } // prime the generator a bit for (int i = 0; i < 10000; ++i) mt(); has_gaussian = false; next_gaussian = 0; } unsigned char get_random_8bit_number ( ) { return static_cast(mt()); } uint16 get_random_16bit_number ( ) { return static_cast(mt()); } inline uint32 get_random_32bit_number ( ) { return mt(); } inline uint64 get_random_64bit_number ( ) { const uint64 a = get_random_32bit_number(); const uint64 b = get_random_32bit_number(); return (a<<32)|b; } double get_random_double ( ) { uint32 temp; temp = rand::get_random_32bit_number(); temp &= 0xFFFFFF; double val = static_cast(temp); val *= 0x1000000; temp = rand::get_random_32bit_number(); temp &= 0xFFFFFF; val += temp; val /= max_val; if (val < 1.0) { return val; } else { // return a value slightly less than 1.0 return 1.0 - std::numeric_limits::epsilon(); } } float get_random_float ( ) { uint32 temp; temp = rand::get_random_32bit_number(); temp &= 0xFFFFFF; const float scale = 1.0/0x1000000; const float val = static_cast(temp)*scale; if (val < 1.0f) { return val; } else { // return a value slightly less than 1.0 return 1.0f - std::numeric_limits::epsilon(); } } double get_random_gaussian ( ) { if (has_gaussian) { has_gaussian = false; return next_gaussian; } double x1, x2, w; const double rndmax = std::numeric_limits::max(); // Generate a pair of Gaussian random numbers using the Box-Muller transformation. do { const double rnd1 = get_random_32bit_number()/rndmax; const double rnd2 = get_random_32bit_number()/rndmax; x1 = 2.0 * rnd1 - 1.0; x2 = 2.0 * rnd2 - 1.0; w = x1 * x1 + x2 * x2; } while ( w >= 1.0 ); w = std::sqrt( (-2.0 * std::log( w ) ) / w ); next_gaussian = x2 * w; has_gaussian = true; return x1 * w; } void swap ( rand& item ) { exchange(mt,item.mt); exchange(seed, item.seed); exchange(has_gaussian, item.has_gaussian); exchange(next_gaussian, item.next_gaussian); } friend void serialize( const rand& item, std::ostream& out ); friend void deserialize( rand& item, std::istream& in ); private: void init() { // prime the generator a bit for (int i = 0; i < 10000; ++i) mt(); max_val = 0xFFFFFF; max_val *= 0x1000000; max_val += 0xFFFFFF; max_val += 0.01; has_gaussian = false; next_gaussian = 0; } mt19937 mt; std::string seed; double max_val; bool has_gaussian; double next_gaussian; }; inline void swap ( rand& a, rand& b ) { a.swap(b); } template <> struct is_rand { static const bool value = true; }; inline void serialize( const rand& item, std::ostream& out ) { int version = 1; serialize(version, out); serialize(item.mt, out); serialize(item.seed, out); serialize(item.has_gaussian, out); serialize(item.next_gaussian, out); } inline void deserialize( rand& item, std::istream& in ) { int version; deserialize(version, in); if (version != 1) throw serialization_error("Error deserializing object of type rand: unexpected version."); deserialize(item.mt, in); deserialize(item.seed, in); deserialize(item.has_gaussian, in); deserialize(item.next_gaussian, in); } } #endif // DLIB_RAND_KERNEl_1_