// Copyright (C) 2007 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_MLp_KERNEL_1_ #define DLIB_MLp_KERNEL_1_ #include "../algs.h" #include "../serialize.h" #include "../matrix.h" #include "../rand.h" #include "mlp_kernel_abstract.h" #include <ctime> #include <sstream> namespace dlib { class mlp_kernel_1 : noncopyable { /*! INITIAL VALUE The network is initially initialized with random weights CONVENTION - input_layer_nodes() == input_nodes - first_hidden_layer_nodes() == first_hidden_nodes - second_hidden_layer_nodes() == second_hidden_nodes - output_layer_nodes() == output_nodes - get_alpha == alpha - get_momentum() == momentum - if (second_hidden_nodes == 0) then - for all i and j: - w1(i,j) == the weight on the link from node i in the first hidden layer to input node j - w3(i,j) == the weight on the link from node i in the output layer to first hidden layer node j - for all i and j: - w1m == the momentum terms for w1 from the previous update - w3m == the momentum terms for w3 from the previous update - else - for all i and j: - w1(i,j) == the weight on the link from node i in the first hidden layer to input node j - w2(i,j) == the weight on the link from node i in the second hidden layer to first hidden layer node j - w3(i,j) == the weight on the link from node i in the output layer to second hidden layer node j - for all i and j: - w1m == the momentum terms for w1 from the previous update - w2m == the momentum terms for w2 from the previous update - w3m == the momentum terms for w3 from the previous update !*/ public: mlp_kernel_1 ( long nodes_in_input_layer, long nodes_in_first_hidden_layer, long nodes_in_second_hidden_layer = 0, long nodes_in_output_layer = 1, double alpha_ = 0.1, double momentum_ = 0.8 ) : input_nodes(nodes_in_input_layer), first_hidden_nodes(nodes_in_first_hidden_layer), second_hidden_nodes(nodes_in_second_hidden_layer), output_nodes(nodes_in_output_layer), alpha(alpha_), momentum(momentum_) { // seed the random number generator std::ostringstream sout; sout << time(0); rand_nums.set_seed(sout.str()); w1.set_size(first_hidden_nodes+1, input_nodes+1); w1m.set_size(first_hidden_nodes+1, input_nodes+1); z.set_size(input_nodes+1,1); if (second_hidden_nodes != 0) { w2.set_size(second_hidden_nodes+1, first_hidden_nodes+1); w3.set_size(output_nodes, second_hidden_nodes+1); w2m.set_size(second_hidden_nodes+1, first_hidden_nodes+1); w3m.set_size(output_nodes, second_hidden_nodes+1); } else { w3.set_size(output_nodes, first_hidden_nodes+1); w3m.set_size(output_nodes, first_hidden_nodes+1); } reset(); } virtual ~mlp_kernel_1 ( ) {} void reset ( ) { // randomize the weights for the first layer for (long r = 0; r < w1.nr(); ++r) for (long c = 0; c < w1.nc(); ++c) w1(r,c) = rand_nums.get_random_double(); // randomize the weights for the second layer for (long r = 0; r < w2.nr(); ++r) for (long c = 0; c < w2.nc(); ++c) w2(r,c) = rand_nums.get_random_double(); // randomize the weights for the third layer for (long r = 0; r < w3.nr(); ++r) for (long c = 0; c < w3.nc(); ++c) w3(r,c) = rand_nums.get_random_double(); // zero all the momentum terms set_all_elements(w1m,0); set_all_elements(w2m,0); set_all_elements(w3m,0); } long input_layer_nodes ( ) const { return input_nodes; } long first_hidden_layer_nodes ( ) const { return first_hidden_nodes; } long second_hidden_layer_nodes ( ) const { return second_hidden_nodes; } long output_layer_nodes ( ) const { return output_nodes; } double get_alpha ( ) const { return alpha; } double get_momentum ( ) const { return momentum; } template <typename EXP> const matrix<double> operator() ( const matrix_exp<EXP>& in ) const { for (long i = 0; i < in.nr(); ++i) z(i) = in(i); // insert the bias z(z.nr()-1) = -1; tmp1 = sigmoid(w1*z); // insert the bias tmp1(tmp1.nr()-1) = -1; if (second_hidden_nodes == 0) { return sigmoid(w3*tmp1); } else { tmp2 = sigmoid(w2*tmp1); // insert the bias tmp2(tmp2.nr()-1) = -1; return sigmoid(w3*tmp2); } } template <typename EXP1, typename EXP2> void train ( const matrix_exp<EXP1>& example_in, const matrix_exp<EXP2>& example_out ) { for (long i = 0; i < example_in.nr(); ++i) z(i) = example_in(i); // insert the bias z(z.nr()-1) = -1; tmp1 = sigmoid(w1*z); // insert the bias tmp1(tmp1.nr()-1) = -1; if (second_hidden_nodes == 0) { o = sigmoid(w3*tmp1); // now compute the errors and propagate them backwards though the network e3 = pointwise_multiply(example_out-o, uniform_matrix<double>(output_nodes,1,1.0)-o, o); e1 = pointwise_multiply(tmp1, uniform_matrix<double>(first_hidden_nodes+1,1,1.0) - tmp1, trans(w3)*e3 ); // compute the new weight updates w3m = alpha * e3*trans(tmp1) + w3m*momentum; w1m = alpha * e1*trans(z) + w1m*momentum; // now update the weights w1 += w1m; w3 += w3m; } else { tmp2 = sigmoid(w2*tmp1); // insert the bias tmp2(tmp2.nr()-1) = -1; o = sigmoid(w3*tmp2); // now compute the errors and propagate them backwards though the network e3 = pointwise_multiply(example_out-o, uniform_matrix<double>(output_nodes,1,1.0)-o, o); e2 = pointwise_multiply(tmp2, uniform_matrix<double>(second_hidden_nodes+1,1,1.0) - tmp2, trans(w3)*e3 ); e1 = pointwise_multiply(tmp1, uniform_matrix<double>(first_hidden_nodes+1,1,1.0) - tmp1, trans(w2)*e2 ); // compute the new weight updates w3m = alpha * e3*trans(tmp2) + w3m*momentum; w2m = alpha * e2*trans(tmp1) + w2m*momentum; w1m = alpha * e1*trans(z) + w1m*momentum; // now update the weights w1 += w1m; w2 += w2m; w3 += w3m; } } template <typename EXP> void train ( const matrix_exp<EXP>& example_in, double example_out ) { matrix<double,1,1> e_out; e_out(0) = example_out; train(example_in,e_out); } double get_average_change ( ) const { // sum up all the weight changes double delta = sum(abs(w1m)) + sum(abs(w2m)) + sum(abs(w3m)); // divide by the number of weights delta /= w1m.nr()*w1m.nc() + w2m.nr()*w2m.nc() + w3m.nr()*w3m.nc(); return delta; } void swap ( mlp_kernel_1& item ) { exchange(input_nodes, item.input_nodes); exchange(first_hidden_nodes, item.first_hidden_nodes); exchange(second_hidden_nodes, item.second_hidden_nodes); exchange(output_nodes, item.output_nodes); exchange(alpha, item.alpha); exchange(momentum, item.momentum); w1.swap(item.w1); w2.swap(item.w2); w3.swap(item.w3); w1m.swap(item.w1m); w2m.swap(item.w2m); w3m.swap(item.w3m); // even swap the temporary matrices because this may ultimately result in // fewer calls to new and delete. e1.swap(item.e1); e2.swap(item.e2); e3.swap(item.e3); z.swap(item.z); tmp1.swap(item.tmp1); tmp2.swap(item.tmp2); o.swap(item.o); } friend void serialize ( const mlp_kernel_1& item, std::ostream& out ); friend void deserialize ( mlp_kernel_1& item, std::istream& in ); private: long input_nodes; long first_hidden_nodes; long second_hidden_nodes; long output_nodes; double alpha; double momentum; matrix<double> w1; matrix<double> w2; matrix<double> w3; matrix<double> w1m; matrix<double> w2m; matrix<double> w3m; rand rand_nums; // temporary storage mutable matrix<double> e1, e2, e3; mutable matrix<double> z, tmp1, tmp2, o; }; inline void swap ( mlp_kernel_1& a, mlp_kernel_1& b ) { a.swap(b); } // ---------------------------------------------------------------------------------------- inline void serialize ( const mlp_kernel_1& item, std::ostream& out ) { try { serialize(item.input_nodes, out); serialize(item.first_hidden_nodes, out); serialize(item.second_hidden_nodes, out); serialize(item.output_nodes, out); serialize(item.alpha, out); serialize(item.momentum, out); serialize(item.w1, out); serialize(item.w2, out); serialize(item.w3, out); serialize(item.w1m, out); serialize(item.w2m, out); serialize(item.w3m, out); } catch (serialization_error& e) { throw serialization_error(e.info + "\n while serializing object of type mlp_kernel_1"); } } inline void deserialize ( mlp_kernel_1& item, std::istream& in ) { try { deserialize(item.input_nodes, in); deserialize(item.first_hidden_nodes, in); deserialize(item.second_hidden_nodes, in); deserialize(item.output_nodes, in); deserialize(item.alpha, in); deserialize(item.momentum, in); deserialize(item.w1, in); deserialize(item.w2, in); deserialize(item.w3, in); deserialize(item.w1m, in); deserialize(item.w2m, in); deserialize(item.w3m, in); item.z.set_size(item.input_nodes+1,1); } catch (serialization_error& e) { // give item a reasonable value since the deserialization failed mlp_kernel_1(1,1).swap(item); throw serialization_error(e.info + "\n while deserializing object of type mlp_kernel_1"); } } // ---------------------------------------------------------------------------------------- } #endif // DLIB_MLp_KERNEL_1_