// Copyright (C) 2008 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_MATRIx_DEFAULT_MULTIPLY_ #define DLIB_MATRIx_DEFAULT_MULTIPLY_ #include "../geometry/rectangle.h" #include "matrix.h" #include "matrix_utilities.h" #include "../enable_if.h" namespace dlib { // ------------------------------------------------------------------------------------ namespace ma { template < typename EXP, typename enable = void > struct matrix_is_vector { static const bool value = false; }; template < typename EXP > struct matrix_is_vector<EXP, typename enable_if_c<EXP::NR==1 || EXP::NC==1>::type > { static const bool value = true; }; } // ------------------------------------------------------------------------------------ /*! This file defines the default_matrix_multiply() function. It is a function that conforms to the following definition: template < typename matrix_dest_type, typename EXP1, typename EXP2 > void default_matrix_multiply ( matrix_dest_type& dest, const EXP1& lhs, const EXP2& rhs ); requires - (lhs*rhs).destructively_aliases(dest) == false - dest.nr() == (lhs*rhs).nr() - dest.nc() == (lhs*rhs).nc() ensures - #dest == dest + lhs*rhs !*/ // ------------------------------------------------------------------------------------ template < typename matrix_dest_type, typename EXP1, typename EXP2 > typename enable_if_c<ma::matrix_is_vector<EXP1>::value == true || ma::matrix_is_vector<EXP2>::value == true>::type default_matrix_multiply ( matrix_dest_type& dest, const EXP1& lhs, const EXP2& rhs ) { matrix_assign_default(dest, lhs*rhs, 1, true); } // ------------------------------------------------------------------------------------ template < typename matrix_dest_type, typename EXP1, typename EXP2 > typename enable_if_c<ma::matrix_is_vector<EXP1>::value == false && ma::matrix_is_vector<EXP2>::value == false>::type default_matrix_multiply ( matrix_dest_type& dest, const EXP1& lhs, const EXP2& rhs ) { const long bs = 90; // if the matrices are small enough then just use the simple multiply algorithm if (lhs.nc() <= 2 || rhs.nc() <= 2 || lhs.nr() <= 2 || rhs.nr() <= 2 || (lhs.size() <= bs*10 && rhs.size() <= bs*10) ) { matrix_assign_default(dest, lhs*rhs, 1, true); } else { // if the lhs and rhs matrices are big enough we should use a cache friendly // algorithm that computes the matrix multiply in blocks. // Loop over all the blocks in the lhs matrix for (long r = 0; r < lhs.nr(); r+=bs) { for (long c = 0; c < lhs.nc(); c+=bs) { // make a rect for the block from lhs rectangle lhs_block(c, r, std::min(c+bs-1,lhs.nc()-1), std::min(r+bs-1,lhs.nr()-1)); // now loop over all the rhs blocks we have to multiply with the current lhs block for (long i = 0; i < rhs.nc(); i += bs) { // make a rect for the block from rhs rectangle rhs_block(i, c, std::min(i+bs-1,rhs.nc()-1), std::min(c+bs-1,rhs.nr()-1)); // make a target rect in res rectangle res_block(rhs_block.left(),lhs_block.top(), rhs_block.right(), lhs_block.bottom()); // This loop is optimized assuming that the data is laid out in // row major order in memory. for (long r = lhs_block.top(); r <= lhs_block.bottom(); ++r) { for (long c = lhs_block.left(); c<= lhs_block.right(); ++c) { const typename EXP2::type temp = lhs(r,c); for (long i = rhs_block.left(); i <= rhs_block.right(); ++i) { dest(r,i) += rhs(c,i)*temp; } } } } } } } } // ------------------------------------------------------------------------------------ } #endif // DLIB_MATRIx_DEFAULT_MULTIPLY_