// Copyright (C) 2011 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_STRUCTURAL_SVM_DISTRIBUTeD_Hh_ #define DLIB_STRUCTURAL_SVM_DISTRIBUTeD_Hh_ #include "structural_svm_distributed_abstract.h" #include "structural_svm_problem.h" #include "../bridge.h" #include "../smart_pointers.h" #include "../misc_api.h" #include "../statistics.h" #include "../threads.h" #include "../pipe.h" #include "../type_safe_union.h" #include #include namespace dlib { // ---------------------------------------------------------------------------------------- namespace impl { template struct oracle_response { typedef typename matrix_type::type scalar_type; matrix_type subgradient; scalar_type loss; long num; friend void swap (oracle_response& a, oracle_response& b) { a.subgradient.swap(b.subgradient); std::swap(a.loss, b.loss); std::swap(a.num, b.num); } friend void serialize (const oracle_response& item, std::ostream& out) { serialize(item.subgradient, out); dlib::serialize(item.loss, out); dlib::serialize(item.num, out); } friend void deserialize (oracle_response& item, std::istream& in) { deserialize(item.subgradient, in); dlib::deserialize(item.loss, in); dlib::deserialize(item.num, in); } }; // ---------------------------------------------------------------------------------------- template struct oracle_request { typedef typename matrix_type::type scalar_type; matrix_type current_solution; scalar_type saved_current_risk_gap; bool skip_cache; bool converged; friend void swap (oracle_request& a, oracle_request& b) { a.current_solution.swap(b.current_solution); std::swap(a.saved_current_risk_gap, b.saved_current_risk_gap); std::swap(a.skip_cache, b.skip_cache); std::swap(a.converged, b.converged); } friend void serialize (const oracle_request& item, std::ostream& out) { serialize(item.current_solution, out); dlib::serialize(item.saved_current_risk_gap, out); dlib::serialize(item.skip_cache, out); dlib::serialize(item.converged, out); } friend void deserialize (oracle_request& item, std::istream& in) { deserialize(item.current_solution, in); dlib::deserialize(item.saved_current_risk_gap, in); dlib::deserialize(item.skip_cache, in); dlib::deserialize(item.converged, in); } }; } // ---------------------------------------------------------------------------------------- class svm_struct_processing_node : noncopyable { public: template < typename T, typename U > svm_struct_processing_node ( const structural_svm_problem& problem, unsigned short port, unsigned short num_threads ) { // make sure requires clause is not broken DLIB_ASSERT(port != 0 && problem.get_num_samples() != 0 && problem.get_num_dimensions() != 0, "\t svm_struct_processing_node()" << "\n\t Invalid arguments were given to this function" << "\n\t port: " << port << "\n\t problem.get_num_samples(): " << problem.get_num_samples() << "\n\t problem.get_num_dimensions(): " << problem.get_num_dimensions() << "\n\t this: " << this ); the_problem.reset(new node_type(problem, port, num_threads)); } private: struct base { virtual ~base(){} }; template < typename matrix_type, typename feature_vector_type > class node_type : public base, threaded_object { public: typedef typename matrix_type::type scalar_type; node_type( const structural_svm_problem& prob, unsigned short port, unsigned long num_threads ) : in(3),out(3), problem(prob), tp(num_threads) { b.reconfigure(listen_on_port(port), receive(in), transmit(out)); start(); } ~node_type() { in.disable(); out.disable(); wait(); } private: void thread() { using namespace impl; tsu_in msg; tsu_out temp; timestamper ts; running_stats with_buffer_time; running_stats without_buffer_time; unsigned long num_iterations_executed = 0; while (in.dequeue(msg)) { // initialize the cache and compute psi_true. if (cache.size() == 0) { cache.resize(problem.get_num_samples()); for (unsigned long i = 0; i < cache.size(); ++i) cache[i].init(&problem,i); psi_true.set_size(problem.get_num_dimensions(),1); psi_true = 0; const unsigned long num = problem.get_num_samples(); feature_vector_type ftemp; for (unsigned long i = 0; i < num; ++i) { cache[i].get_truth_joint_feature_vector_cached(ftemp); subtract_from(psi_true, ftemp); } } if (msg.template contains() && msg.template get().is_connected) { temp = problem.get_num_dimensions(); out.enqueue(temp); } else if (msg.template contains >()) { ++num_iterations_executed; const oracle_request& req = msg.template get >(); oracle_response& data = temp.template get >(); data.subgradient = psi_true; data.loss = 0; data.num = problem.get_num_samples(); const uint64 start_time = ts.get_timestamp(); // pick fastest buffering strategy bool buffer_subgradients_locally = with_buffer_time.mean() < without_buffer_time.mean(); // every 50 iterations we should try to flip the buffering scheme to see if // doing it the other way might be better. if ((num_iterations_executed%50) == 0) { buffer_subgradients_locally = !buffer_subgradients_locally; } binder b(*this, req, data, buffer_subgradients_locally); parallel_for_blocked(tp, 0, data.num, b, &binder::call_oracle); const uint64 stop_time = ts.get_timestamp(); if (buffer_subgradients_locally) with_buffer_time.add(stop_time-start_time); else without_buffer_time.add(stop_time-start_time); out.enqueue(temp); } } } struct binder { binder ( const node_type& self_, const impl::oracle_request& req_, impl::oracle_response& data_, bool buffer_subgradients_locally_ ) : self(self_), req(req_), data(data_), buffer_subgradients_locally(buffer_subgradients_locally_) {} void call_oracle ( long begin, long end ) { // If we are only going to call the separation oracle once then don't // run the slightly more complex for loop version of this code. Or if // we just don't want to run the complex buffering one. The code later // on decides if we should do the buffering based on how long it takes // to execute. We do this because, when the subgradient is really high // dimensional it can take a lot of time to add them together. So we // might want to avoid doing that. if (end-begin <= 1 || !buffer_subgradients_locally) { scalar_type loss; feature_vector_type ftemp; for (long i = begin; i < end; ++i) { self.cache[i].separation_oracle_cached(req.converged, req.skip_cache, req.saved_current_risk_gap, req.current_solution, loss, ftemp); auto_mutex lock(self.accum_mutex); data.loss += loss; add_to(data.subgradient, ftemp); } } else { scalar_type loss = 0; matrix_type faccum(data.subgradient.size(),1); faccum = 0; feature_vector_type ftemp; for (long i = begin; i < end; ++i) { scalar_type loss_temp; self.cache[i].separation_oracle_cached(req.converged, req.skip_cache, req.saved_current_risk_gap, req.current_solution, loss_temp, ftemp); loss += loss_temp; add_to(faccum, ftemp); } auto_mutex lock(self.accum_mutex); data.loss += loss; add_to(data.subgradient, faccum); } } const node_type& self; const impl::oracle_request& req; impl::oracle_response& data; bool buffer_subgradients_locally; }; typedef type_safe_union, bridge_status> tsu_in; typedef type_safe_union , long> tsu_out; pipe in; pipe out; bridge b; mutable matrix_type psi_true; const structural_svm_problem& problem; mutable std::vector > > cache; mutable thread_pool tp; mutex accum_mutex; }; scoped_ptr the_problem; }; // ---------------------------------------------------------------------------------------- class svm_struct_controller_node : noncopyable { public: svm_struct_controller_node ( ) : eps(0.001), cache_based_eps(std::numeric_limits::infinity()), verbose(false), C(1) {} double get_cache_based_epsilon ( ) const { return cache_based_eps; } void set_cache_based_epsilon ( double eps_ ) { // make sure requires clause is not broken DLIB_ASSERT(eps_ > 0, "\t void svm_struct_controller_node::set_cache_based_epsilon()" << "\n\t eps_ must be greater than 0" << "\n\t eps_: " << eps_ << "\n\t this: " << this ); cache_based_eps = eps_; } void set_epsilon ( double eps_ ) { // make sure requires clause is not broken DLIB_ASSERT(eps_ > 0, "\t void svm_struct_controller_node::set_epsilon()" << "\n\t eps_ must be greater than 0" << "\n\t eps_: " << eps_ << "\n\t this: " << this ); eps = eps_; } double get_epsilon ( ) const { return eps; } void be_verbose ( ) { verbose = true; } void be_quiet( ) { verbose = false; } void add_nuclear_norm_regularizer ( long first_dimension, long rows, long cols, double regularization_strength ) { // make sure requires clause is not broken DLIB_ASSERT(0 <= first_dimension && 0 <= rows && 0 <= cols && 0 < regularization_strength, "\t void svm_struct_controller_node::add_nuclear_norm_regularizer()" << "\n\t Invalid arguments were given to this function." << "\n\t first_dimension: " << first_dimension << "\n\t rows: " << rows << "\n\t cols: " << cols << "\n\t regularization_strength: " << regularization_strength << "\n\t this: " << this ); impl::nuclear_norm_regularizer temp; temp.first_dimension = first_dimension; temp.nr = rows; temp.nc = cols; temp.regularization_strength = regularization_strength; nuclear_norm_regularizers.push_back(temp); } unsigned long num_nuclear_norm_regularizers ( ) const { return nuclear_norm_regularizers.size(); } void clear_nuclear_norm_regularizers ( ) { nuclear_norm_regularizers.clear(); } double get_c ( ) const { return C; } void set_c ( double C_ ) { // make sure requires clause is not broken DLIB_ASSERT(C_ > 0, "\t void svm_struct_controller_node::set_c()" << "\n\t C_ must be greater than 0" << "\n\t C_: " << C_ << "\n\t this: " << this ); C = C_; } void add_processing_node ( const network_address& addr ) { // make sure requires clause is not broken DLIB_ASSERT(addr.port != 0, "\t void svm_struct_controller_node::add_processing_node()" << "\n\t Invalid inputs were given to this function" << "\n\t addr.host_address: " << addr.host_address << "\n\t addr.port: " << addr.port << "\n\t this: " << this ); // check if this address is already registered for (unsigned long i = 0; i < nodes.size(); ++i) { if (nodes[i] == addr) { return; } } nodes.push_back(addr); } void add_processing_node ( const std::string& ip_or_hostname, unsigned short port ) { add_processing_node(network_address(ip_or_hostname,port)); } unsigned long get_num_processing_nodes ( ) const { return nodes.size(); } void remove_processing_nodes ( ) { nodes.clear(); } template double operator() ( const oca& solver, matrix_type& w ) const { // make sure requires clause is not broken DLIB_ASSERT(get_num_processing_nodes() != 0, "\t double svm_struct_controller_node::operator()" << "\n\t You must add some processing nodes before calling this function." << "\n\t this: " << this ); problem_type problem(nodes); problem.set_cache_based_epsilon(cache_based_eps); problem.set_epsilon(eps); if (verbose) problem.be_verbose(); problem.set_c(C); for (unsigned long i = 0; i < nuclear_norm_regularizers.size(); ++i) { problem.add_nuclear_norm_regularizer( nuclear_norm_regularizers[i].first_dimension, nuclear_norm_regularizers[i].nr, nuclear_norm_regularizers[i].nc, nuclear_norm_regularizers[i].regularization_strength); } return solver(problem, w); } class invalid_problem : public error { public: invalid_problem( const std::string& a ): error(a) {} }; private: template class problem_type : public structural_svm_problem { public: typedef typename matrix_type_::type scalar_type; typedef matrix_type_ matrix_type; problem_type ( const std::vector& nodes_ ) : nodes(nodes_), in(3), num_dims(0) { // initialize all the transmit pipes out_pipes.resize(nodes.size()); for (unsigned long i = 0; i < out_pipes.size(); ++i) { out_pipes[i].reset(new pipe(3)); } // make bridges that connect to all our remote processing nodes bridges.resize(nodes.size()); for (unsigned long i = 0; i< bridges.size(); ++i) { bridges[i].reset(new bridge(connect_to(nodes[i]), receive(in), transmit(*out_pipes[i]))); } // The remote processing nodes are supposed to all send the problem dimensionality // upon connection. So get that and make sure everyone agrees on what it's supposed to be. tsu_in temp; unsigned long responses = 0; bool seen_dim = false; while (responses < nodes.size()) { in.dequeue(temp); if (temp.template contains()) { ++responses; // if this new dimension doesn't match what we have seen previously if (seen_dim && num_dims != temp.template get()) { throw invalid_problem("remote hosts disagree on the number of dimensions!"); } seen_dim = true; num_dims = temp.template get(); } } } // These functions are just here because the structural_svm_problem requires // them, but since we are overloading get_risk() they are never called so they // don't matter. virtual long get_num_samples () const {return 0;} virtual void get_truth_joint_feature_vector ( long , matrix_type& ) const {} virtual void separation_oracle ( const long , const matrix_type& , scalar_type& , matrix_type& ) const {} virtual long get_num_dimensions ( ) const { return num_dims; } virtual void get_risk ( matrix_type& w, scalar_type& risk, matrix_type& subgradient ) const { using namespace impl; subgradient.set_size(w.size(),1); subgradient = 0; // send out all the oracle requests tsu_out temp_out; for (unsigned long i = 0; i < out_pipes.size(); ++i) { temp_out.template get >().current_solution = w; temp_out.template get >().saved_current_risk_gap = this->saved_current_risk_gap; temp_out.template get >().skip_cache = this->skip_cache; temp_out.template get >().converged = this->converged; out_pipes[i]->enqueue(temp_out); } // collect all the oracle responses long num = 0; scalar_type total_loss = 0; tsu_in temp_in; unsigned long responses = 0; while (responses < out_pipes.size()) { in.dequeue(temp_in); if (temp_in.template contains >()) { ++responses; const oracle_response& data = temp_in.template get >(); subgradient += data.subgradient; total_loss += data.loss; num += data.num; } } subgradient /= num; total_loss /= num; risk = total_loss + dot(subgradient,w); if (this->nuclear_norm_regularizers.size() != 0) { matrix_type grad; double obj; this->compute_nuclear_norm_parts(w, grad, obj); risk += obj; subgradient += grad; } } std::vector nodes; typedef type_safe_union > tsu_out; typedef type_safe_union, long> tsu_in; std::vector > > out_pipes; mutable pipe in; std::vector > bridges; long num_dims; }; std::vector nodes; double eps; double cache_based_eps; bool verbose; double C; std::vector nuclear_norm_regularizers; }; // ---------------------------------------------------------------------------------------- } #endif // DLIB_STRUCTURAL_SVM_DISTRIBUTeD_Hh_