ext/nmatrix/math/math.h in nmatrix-0.1.0 vs ext/nmatrix/math/math.h in nmatrix-0.2.0

- old
+ new

@@ -66,24 +66,12 @@ /* * Standard Includes */ -extern "C" { // These need to be in an extern "C" block or you'll get all kinds of undefined symbol errors. -#if defined HAVE_CBLAS_H - #include <cblas.h> -#elif defined HAVE_ATLAS_CBLAS_H - #include <atlas/cblas.h> -#endif +#include "cblas_enums.h" -#if defined HAVE_CLAPACK_H - #include <clapack.h> -#elif defined HAVE_ATLAS_CLAPACK_H - #include <atlas/clapack.h> -#endif -} - #include <algorithm> // std::min, std::max #include <limits> // std::numeric_limits /* * Project Includes @@ -101,15 +89,22 @@ extern "C" { /* * C accessors. */ - void nm_math_det_exact(const int M, const void* elements, const int lda, nm::dtype_t dtype, void* result); - void nm_math_inverse(const int M, void* A_elements, nm::dtype_t dtype); - void nm_math_inverse_exact(const int M, const void* A_elements, const int lda, void* B_elements, const int ldb, nm::dtype_t dtype); + void nm_math_transpose_generic(const size_t M, const size_t N, const void* A, const int lda, void* B, const int ldb, size_t element_size); void nm_math_init_blas(void); + + /* + * Pure math implementations. + */ + void nm_math_solve(VALUE lu, VALUE b, VALUE x, VALUE ipiv); + void nm_math_inverse(const int M, void* A_elements, nm::dtype_t dtype); + void nm_math_hessenberg(VALUE a); + void nm_math_det_exact(const int M, const void* elements, const int lda, nm::dtype_t dtype, void* result); + void nm_math_inverse_exact(const int M, const void* A_elements, const int lda, void* B_elements, const int ldb, nm::dtype_t dtype); } namespace nm { namespace math { @@ -121,98 +116,10 @@ /* * Functions */ - -template <typename DType> -inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N, - const int K, const DType* alpha, const DType* A, const int lda, const DType* beta, DType* C, const int ldc) { - rb_raise(rb_eNotImpError, "syrk not yet implemented for non-BLAS dtypes"); -} - -template <typename DType> -inline void herk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N, - const int K, const DType* alpha, const DType* A, const int lda, const DType* beta, DType* C, const int ldc) { - rb_raise(rb_eNotImpError, "herk not yet implemented for non-BLAS dtypes"); -} - -template <> -inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N, - const int K, const float* alpha, const float* A, const int lda, const float* beta, float* C, const int ldc) { - cblas_ssyrk(Order, Uplo, Trans, N, K, *alpha, A, lda, *beta, C, ldc); -} - -template <> -inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N, - const int K, const double* alpha, const double* A, const int lda, const double* beta, double* C, const int ldc) { - cblas_dsyrk(Order, Uplo, Trans, N, K, *alpha, A, lda, *beta, C, ldc); -} - -template <> -inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N, - const int K, const Complex64* alpha, const Complex64* A, const int lda, const Complex64* beta, Complex64* C, const int ldc) { - cblas_csyrk(Order, Uplo, Trans, N, K, alpha, A, lda, beta, C, ldc); -} - -template <> -inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N, - const int K, const Complex128* alpha, const Complex128* A, const int lda, const Complex128* beta, Complex128* C, const int ldc) { - cblas_zsyrk(Order, Uplo, Trans, N, K, alpha, A, lda, beta, C, ldc); -} - - -template <> -inline void herk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N, - const int K, const Complex64* alpha, const Complex64* A, const int lda, const Complex64* beta, Complex64* C, const int ldc) { - cblas_cherk(Order, Uplo, Trans, N, K, alpha->r, A, lda, beta->r, C, ldc); -} - -template <> -inline void herk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N, - const int K, const Complex128* alpha, const Complex128* A, const int lda, const Complex128* beta, Complex128* C, const int ldc) { - cblas_zherk(Order, Uplo, Trans, N, K, alpha->r, A, lda, beta->r, C, ldc); -} - - -template <typename DType> -inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo, - const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const DType* alpha, - const DType* A, const int lda, DType* B, const int ldb) { - rb_raise(rb_eNotImpError, "trmm not yet implemented for non-BLAS dtypes"); -} - -template <> -inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo, - const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const float* alpha, - const float* A, const int lda, float* B, const int ldb) { - cblas_strmm(order, side, uplo, ta, diag, m, n, *alpha, A, lda, B, ldb); -} - -template <> -inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo, - const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const double* alpha, - const double* A, const int lda, double* B, const int ldb) { - cblas_dtrmm(order, side, uplo, ta, diag, m, n, *alpha, A, lda, B, ldb); -} - -template <> -inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo, - const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const Complex64* alpha, - const Complex64* A, const int lda, Complex64* B, const int ldb) { - cblas_ctrmm(order, side, uplo, ta, diag, m, n, alpha, A, lda, B, ldb); -} - -template <> -inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo, - const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const Complex128* alpha, - const Complex128* A, const int lda, Complex128* B, const int ldb) { - cblas_ztrmm(order, side, uplo, ta, diag, m, n, alpha, A, lda, B, ldb); -} - - - // Yale: numeric matrix multiply c=a*b template <typename DType> inline void numbmm(const unsigned int n, const unsigned int m, const unsigned int l, const IType* ia, const IType* ja, const DType* a, const bool diaga, const IType* ib, const IType* jb, const DType* b, const bool diagb, IType* ic, IType* jc, DType* c, const bool diagc) { const unsigned int max_lmn = std::max(std::max(m, n), l); @@ -500,78 +407,10 @@ } } } -/* - * From ATLAS 3.8.0: - * - * Computes one of two LU factorizations based on the setting of the Order - * parameter, as follows: - * ---------------------------------------------------------------------------- - * Order == CblasColMajor - * Column-major factorization of form - * A = P * L * U - * where P is a row-permutation matrix, L is lower triangular with unit - * diagonal elements (lower trapazoidal if M > N), and U is upper triangular - * (upper trapazoidal if M < N). - * - * ---------------------------------------------------------------------------- - * Order == CblasRowMajor - * Row-major factorization of form - * A = P * L * U - * where P is a column-permutation matrix, L is lower triangular (lower - * trapazoidal if M > N), and U is upper triangular with unit diagonals (upper - * trapazoidal if M < N). - * - * ============================================================================ - * Let IERR be the return value of the function: - * If IERR == 0, successful exit. - * If (IERR < 0) the -IERR argument had an illegal value - * If (IERR > 0 && Order == CblasColMajor) - * U(i-1,i-1) is exactly zero. The factorization has been completed, - * but the factor U is exactly singular, and division by zero will - * occur if it is used to solve a system of equations. - * If (IERR > 0 && Order == CblasRowMajor) - * L(i-1,i-1) is exactly zero. The factorization has been completed, - * but the factor L is exactly singular, and division by zero will - * occur if it is used to solve a system of equations. - */ -template <typename DType> -inline int potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, DType* A, const int lda) { -#if defined HAVE_CLAPACK_H || defined HAVE_ATLAS_CLAPACK_H - rb_raise(rb_eNotImpError, "not yet implemented for non-BLAS dtypes"); -#else - rb_raise(rb_eNotImpError, "only CLAPACK version implemented thus far"); -#endif - return 0; -} - -#if defined HAVE_CLAPACK_H || defined HAVE_ATLAS_CLAPACK_H -template <> -inline int potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, float* A, const int lda) { - return clapack_spotrf(order, uplo, N, A, lda); -} - -template <> -inline int potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, double* A, const int lda) { - return clapack_dpotrf(order, uplo, N, A, lda); -} - -template <> -inline int potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, Complex64* A, const int lda) { - return clapack_cpotrf(order, uplo, N, reinterpret_cast<void*>(A), lda); -} - -template <> -inline int potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, Complex128* A, const int lda) { - return clapack_zpotrf(order, uplo, N, reinterpret_cast<void*>(A), lda); -} -#endif - - - // Copies an upper row-major array from U, zeroing U; U is unit, so diagonal is not copied. // // From ATLAS 3.8.0. template <typename DType> static inline void trcpzeroU(const int M, const int N, DType* U, const int ldu, DType* C, const int ldc) { @@ -873,115 +712,11 @@ return iret; } */ - - -template <bool is_complex, typename DType> -inline void lauum(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, DType* A, const int lda) { - - int Nleft, Nright; - const DType ONE = 1; - DType *G, *U0 = A, *U1; - - if (N > 1) { - Nleft = N >> 1; - #ifdef NB - if (Nleft > NB) Nleft = ATL_MulByNB(ATL_DivByNB(Nleft)); - #endif - - Nright = N - Nleft; - - // FIXME: There's a simpler way to write this next block, but I'm way too tired to work it out right now. - if (uplo == CblasUpper) { - if (order == CblasRowMajor) { - G = A + Nleft; - U1 = G + Nleft * lda; - } else { - G = A + Nleft * lda; - U1 = G + Nleft; - } - } else { - if (order == CblasRowMajor) { - G = A + Nleft * lda; - U1 = G + Nleft; - } else { - G = A + Nleft; - U1 = G + Nleft * lda; - } - } - - lauum<is_complex, DType>(order, uplo, Nleft, U0, lda); - - if (is_complex) { - - nm::math::herk<DType>(order, uplo, - uplo == CblasLower ? CblasConjTrans : CblasNoTrans, - Nleft, Nright, &ONE, G, lda, &ONE, U0, lda); - - nm::math::trmm<DType>(order, CblasLeft, uplo, CblasConjTrans, CblasNonUnit, Nright, Nleft, &ONE, U1, lda, G, lda); - } else { - nm::math::syrk<DType>(order, uplo, - uplo == CblasLower ? CblasTrans : CblasNoTrans, - Nleft, Nright, &ONE, G, lda, &ONE, U0, lda); - - nm::math::trmm<DType>(order, CblasLeft, uplo, CblasTrans, CblasNonUnit, Nright, Nleft, &ONE, U1, lda, G, lda); - } - lauum<is_complex, DType>(order, uplo, Nright, U1, lda); - - } else { - *A = *A * *A; - } -} - - -#if defined HAVE_CLAPACK_H || defined HAVE_ATLAS_CLAPACK_H -template <bool is_complex> -inline void lauum(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, float* A, const int lda) { - clapack_slauum(order, uplo, N, A, lda); -} - -template <bool is_complex> -inline void lauum(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, double* A, const int lda) { - clapack_dlauum(order, uplo, N, A, lda); -} - -template <bool is_complex> -inline void lauum(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, Complex64* A, const int lda) { - clapack_clauum(order, uplo, N, A, lda); -} - -template <bool is_complex> -inline void lauum(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, Complex128* A, const int lda) { - clapack_zlauum(order, uplo, N, A, lda); -} -#endif - - /* -* Function signature conversion for calling LAPACK's lauum functions as directly as possible. -* -* For documentation: http://www.netlib.org/lapack/double/dlauum.f -* -* This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp. -*/ -template <bool is_complex, typename DType> -inline int clapack_lauum(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, void* a, const int lda) { - if (n < 0) rb_raise(rb_eArgError, "n cannot be less than zero, is set to %d", n); - if (lda < n || lda < 1) rb_raise(rb_eArgError, "lda must be >= max(n,1); lda=%d, n=%d\n", lda, n); - - if (uplo == CblasUpper) lauum<is_complex, DType>(order, uplo, n, reinterpret_cast<DType*>(a), lda); - else lauum<is_complex, DType>(order, uplo, n, reinterpret_cast<DType*>(a), lda); - - return 0; -} - - - - -/* * Macro for declaring LAPACK specializations of the getrf function. * * type is the DType; call is the specific function to call; cast_as is what the DType* should be * cast to in order to pass it to LAPACK. */ @@ -1000,70 +735,9 @@ /*LAPACK_GETRF(float, clapack_sgetrf, float) LAPACK_GETRF(double, clapack_dgetrf, double) LAPACK_GETRF(Complex64, clapack_cgetrf, void) LAPACK_GETRF(Complex128, clapack_zgetrf, void) */ - - - -/* -* Function signature conversion for calling LAPACK's potrf functions as directly as possible. -* -* For documentation: http://www.netlib.org/lapack/double/dpotrf.f -* -* This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp. -*/ -template <typename DType> -inline int clapack_potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, void* a, const int lda) { - return potrf<DType>(order, uplo, n, reinterpret_cast<DType*>(a), lda); -} - - - -template <typename DType> -inline int potri(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, DType* a, const int lda) { - rb_raise(rb_eNotImpError, "potri not yet implemented for non-BLAS dtypes"); - return 0; -} - - -#if defined HAVE_CLAPACK_H || defined HAVE_ATLAS_CLAPACK_H -template <> -inline int potri(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, float* a, const int lda) { - return clapack_spotri(order, uplo, n, a, lda); -} - -template <> -inline int potri(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, double* a, const int lda) { - return clapack_dpotri(order, uplo, n, a, lda); -} - -template <> -inline int potri(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, Complex64* a, const int lda) { - return clapack_cpotri(order, uplo, n, reinterpret_cast<void*>(a), lda); -} - -template <> -inline int potri(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, Complex128* a, const int lda) { - return clapack_zpotri(order, uplo, n, reinterpret_cast<void*>(a), lda); -} -#endif - - -/* - * Function signature conversion for calling LAPACK's potri functions as directly as possible. - * - * For documentation: http://www.netlib.org/lapack/double/dpotri.f - * - * This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp. - */ -template <typename DType> -inline int clapack_potri(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, void* a, const int lda) { - return potri<DType>(order, uplo, n, reinterpret_cast<DType*>(a), lda); -} - - - }} // end namespace nm::math #endif // MATH_H