// Copyright (C) 2009 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. // This code was adapted from code from the JAMA part of NIST's TNT library. // See: http://math.nist.gov/tnt/ #ifndef DLIB_MATRIX_LU_DECOMPOSITION_H #define DLIB_MATRIX_LU_DECOMPOSITION_H #include "matrix.h" #include "matrix_utilities.h" #include "matrix_subexp.h" #include "matrix_trsm.h" #include <algorithm> #ifdef DLIB_USE_LAPACK #include "lapack/getrf.h" #endif namespace dlib { template < typename matrix_exp_type > class lu_decomposition { public: const static long NR = matrix_exp_type::NR; const static long NC = matrix_exp_type::NC; typedef typename matrix_exp_type::type type; typedef typename matrix_exp_type::mem_manager_type mem_manager_type; typedef typename matrix_exp_type::layout_type layout_type; typedef matrix<type,0,0,mem_manager_type,layout_type> matrix_type; typedef matrix<type,NR,1,mem_manager_type,layout_type> column_vector_type; typedef matrix<long,NR,1,mem_manager_type,layout_type> pivot_column_vector_type; // You have supplied an invalid type of matrix_exp_type. You have // to use this object with matrices that contain float or double type data. COMPILE_TIME_ASSERT((is_same_type<float, type>::value || is_same_type<double, type>::value )); template <typename EXP> lu_decomposition ( const matrix_exp<EXP> &A ); bool is_square ( ) const; bool is_singular ( ) const; long nr( ) const; long nc( ) const; const matrix_type get_l ( ) const; const matrix_type get_u ( ) const; const pivot_column_vector_type& get_pivot ( ) const; type det ( ) const; template <typename EXP> const matrix_type solve ( const matrix_exp<EXP> &B ) const; private: /* Array for internal storage of decomposition. */ matrix<type,0,0,mem_manager_type,column_major_layout> LU; long m, n, pivsign; pivot_column_vector_type piv; }; // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // Public member functions // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template <typename matrix_exp_type> template <typename EXP> lu_decomposition<matrix_exp_type>:: lu_decomposition ( const matrix_exp<EXP>& A ) : LU(A), m(A.nr()), n(A.nc()) { using namespace std; using std::abs; COMPILE_TIME_ASSERT((is_same_type<type, typename EXP::type>::value)); // make sure requires clause is not broken DLIB_ASSERT(A.size() > 0, "\tlu_decomposition::lu_decomposition(A)" << "\n\tInvalid inputs were given to this function" << "\n\tA.size(): " << A.size() << "\n\tthis: " << this ); #ifdef DLIB_USE_LAPACK matrix<lapack::integer,0,1,mem_manager_type,layout_type> piv_temp; lapack::getrf(LU, piv_temp); pivsign = 1; // Turn the piv_temp vector into a more useful form. This way we will have the identity // rowm(A,piv) == L*U. The permutation vector that comes out of LAPACK is somewhat // different. piv = trans(range(0,m-1)); for (long i = 0; i < piv_temp.size(); ++i) { // -1 because FORTRAN is indexed starting with 1 instead of 0 if (piv(piv_temp(i)-1) != piv(i)) { std::swap(piv(i), piv(piv_temp(i)-1)); pivsign = -pivsign; } } #else // Use a "left-looking", dot-product, Crout/Doolittle algorithm. piv = trans(range(0,m-1)); pivsign = 1; column_vector_type LUcolj(m); // Outer loop. for (long j = 0; j < n; j++) { // Make a copy of the j-th column to localize references. LUcolj = colm(LU,j); // Apply previous transformations. for (long i = 0; i < m; i++) { // Most of the time is spent in the following dot product. const long kmax = std::min(i,j); type s; if (kmax > 0) s = rowm(LU,i, kmax)*colm(LUcolj,0,kmax); else s = 0; LU(i,j) = LUcolj(i) -= s; } // Find pivot and exchange if necessary. long p = j; for (long i = j+1; i < m; i++) { if (abs(LUcolj(i)) > abs(LUcolj(p))) { p = i; } } if (p != j) { long k=0; for (k = 0; k < n; k++) { type t = LU(p,k); LU(p,k) = LU(j,k); LU(j,k) = t; } k = piv(p); piv(p) = piv(j); piv(j) = k; pivsign = -pivsign; } // Compute multipliers. if ((j < m) && (LU(j,j) != 0.0)) { for (long i = j+1; i < m; i++) { LU(i,j) /= LU(j,j); } } } #endif } // ---------------------------------------------------------------------------------------- template <typename matrix_exp_type> bool lu_decomposition<matrix_exp_type>:: is_square ( ) const { return m == n; } // ---------------------------------------------------------------------------------------- template <typename matrix_exp_type> long lu_decomposition<matrix_exp_type>:: nr ( ) const { return m; } // ---------------------------------------------------------------------------------------- template <typename matrix_exp_type> long lu_decomposition<matrix_exp_type>:: nc ( ) const { return n; } // ---------------------------------------------------------------------------------------- template <typename matrix_exp_type> bool lu_decomposition<matrix_exp_type>:: is_singular ( ) const { /* Is the matrix singular? if upper triangular factor U (and hence A) is singular, false otherwise. */ // make sure requires clause is not broken DLIB_ASSERT(is_square() == true, "\tbool lu_decomposition::is_singular()" << "\n\tYou can only use this on square matrices" << "\n\tthis: " << this ); type max_val, min_val; find_min_and_max (abs(diag(LU)), min_val, max_val); type eps = max_val; if (eps != 0) eps *= std::sqrt(std::numeric_limits<type>::epsilon())/10; else eps = 1; // there is no max so just use 1 return min_val < eps; } // ---------------------------------------------------------------------------------------- template <typename matrix_exp_type> const typename lu_decomposition<matrix_exp_type>::matrix_type lu_decomposition<matrix_exp_type>:: get_l ( ) const { if (LU.nr() >= LU.nc()) return lowerm(LU,1.0); else return lowerm(subm(LU,0,0,m,m), 1.0); } // ---------------------------------------------------------------------------------------- template <typename matrix_exp_type> const typename lu_decomposition<matrix_exp_type>::matrix_type lu_decomposition<matrix_exp_type>:: get_u ( ) const { if (LU.nr() >= LU.nc()) return upperm(subm(LU,0,0,n,n)); else return upperm(LU); } // ---------------------------------------------------------------------------------------- template <typename matrix_exp_type> const typename lu_decomposition<matrix_exp_type>::pivot_column_vector_type& lu_decomposition<matrix_exp_type>:: get_pivot ( ) const { return piv; } // ---------------------------------------------------------------------------------------- template <typename matrix_exp_type> typename lu_decomposition<matrix_exp_type>::type lu_decomposition<matrix_exp_type>:: det ( ) const { // make sure requires clause is not broken DLIB_ASSERT(is_square() == true, "\ttype lu_decomposition::det()" << "\n\tYou can only use this on square matrices" << "\n\tthis: " << this ); // Check if it is singular and if it is just return 0. // We want to do this because a prod() operation can easily // overcome a single diagonal element that is effectively 0 when // LU is a big enough matrix. if (is_singular()) return 0; return prod(diag(LU))*static_cast<type>(pivsign); } // ---------------------------------------------------------------------------------------- template <typename matrix_exp_type> template <typename EXP> const typename lu_decomposition<matrix_exp_type>::matrix_type lu_decomposition<matrix_exp_type>:: solve ( const matrix_exp<EXP> &B ) const { COMPILE_TIME_ASSERT((is_same_type<type, typename EXP::type>::value)); // make sure requires clause is not broken DLIB_ASSERT(is_square() == true && B.nr() == nr(), "\ttype lu_decomposition::solve()" << "\n\tInvalid arguments to this function" << "\n\tis_square(): " << (is_square()? "true":"false" ) << "\n\tB.nr(): " << B.nr() << "\n\tnr(): " << nr() << "\n\tthis: " << this ); // Copy right hand side with pivoting matrix<type,0,0,mem_manager_type,column_major_layout> X(rowm(B, piv)); using namespace blas_bindings; // Solve L*Y = B(piv,:) triangular_solver(CblasLeft, CblasLower, CblasNoTrans, CblasUnit, LU, X); // Solve U*X = Y; triangular_solver(CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, LU, X); return X; } // ---------------------------------------------------------------------------------------- } #endif // DLIB_MATRIX_LU_DECOMPOSITION_H