// Copyright (C) 2010 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_AnY_TRAINER_H_ #define DLIB_AnY_TRAINER_H_ #include "any.h" #include "../smart_pointers.h" #include "any_decision_function.h" #include "any_trainer_abstract.h" #include namespace dlib { // ---------------------------------------------------------------------------------------- template < typename sample_type_, typename scalar_type_ = double > class any_trainer { public: typedef sample_type_ sample_type; typedef scalar_type_ scalar_type; typedef default_memory_manager mem_manager_type; typedef any_decision_function trained_function_type; any_trainer() { } any_trainer ( const any_trainer& item ) { if (item.data) { item.data->copy_to(data); } } template any_trainer ( const T& item ) { typedef typename basic_type::type U; data.reset(new derived(item)); } void clear ( ) { data.reset(); } template bool contains ( ) const { typedef typename basic_type::type U; return dynamic_cast*>(data.get()) != 0; } bool is_empty( ) const { return data.get() == 0; } trained_function_type train ( const std::vector& samples, const std::vector& labels ) const { // make sure requires clause is not broken DLIB_ASSERT(is_empty() == false, "\t trained_function_type any_trainer::train()" << "\n\t You can't call train() on an empty any_trainer" << "\n\t this: " << this ); return data->train(samples, labels); } template T& cast_to( ) { typedef typename basic_type::type U; derived* d = dynamic_cast*>(data.get()); if (d == 0) { throw bad_any_cast(); } return d->item; } template const T& cast_to( ) const { typedef typename basic_type::type U; derived* d = dynamic_cast*>(data.get()); if (d == 0) { throw bad_any_cast(); } return d->item; } template T& get( ) { typedef typename basic_type::type U; derived* d = dynamic_cast*>(data.get()); if (d == 0) { d = new derived(); data.reset(d); } return d->item; } any_trainer& operator= ( const any_trainer& item ) { any_trainer(item).swap(*this); return *this; } void swap ( any_trainer& item ) { data.swap(item.data); } private: struct base { virtual ~base() {} virtual trained_function_type train ( const std::vector& samples, const std::vector& labels ) const = 0; virtual void copy_to ( scoped_ptr& dest ) const = 0; }; template struct derived : public base { T item; derived() {} derived(const T& val) : item(val) {} virtual void copy_to ( scoped_ptr& dest ) const { dest.reset(new derived(item)); } virtual trained_function_type train ( const std::vector& samples, const std::vector& labels ) const { return item.train(samples, labels); } }; scoped_ptr data; }; // ---------------------------------------------------------------------------------------- template < typename sample_type, typename scalar_type > inline void swap ( any_trainer& a, any_trainer& b ) { a.swap(b); } // ---------------------------------------------------------------------------------------- template T& any_cast(any_trainer& a) { return a.template cast_to(); } template const T& any_cast(const any_trainer& a) { return a.template cast_to(); } // ---------------------------------------------------------------------------------------- } #endif // DLIB_AnY_TRAINER_H_