ext/nmatrix/util/math.h in nmatrix-0.0.2 vs ext/nmatrix/util/math.h in nmatrix-0.0.3

- old
+ new

@@ -68,11 +68,14 @@ * Standard Includes */ extern "C" { // These need to be in an extern "C" block or you'll get all kinds of undefined symbol errors. #include <cblas.h> - //#include <clapack.h> + + #ifdef HAVE_CLAPACK_H + #include <clapack.h> + #endif } #include <algorithm> // std::min, std::max #include <limits> // std::numeric_limits @@ -83,21 +86,22 @@ #include "lapack.h" /* * Macros */ +#define REAL_RECURSE_LIMIT 4 /* * Data */ extern "C" { /* * C accessors. */ - void nm_math_det_exact(const int M, const void* elements, const int lda, dtype_t dtype, void* result); + void nm_math_det_exact(const int M, const void* elements, const int lda, nm::dtype_t dtype, void* result); void nm_math_transpose_generic(const size_t M, const size_t N, const void* A, const int lda, void* B, const int ldb, size_t element_size); void nm_math_init_blas(void); } @@ -141,17 +145,23 @@ /* * This version of trsm doesn't do any error checks and only works on column-major matrices. * * For row major, call trsm<DType> instead. That will handle necessary changes-of-variables * and parameter checks. + * + * Note that some of the boundary conditions here may be incorrect. Very little has been tested! + * This was converted directly from dtrsm.f using f2c, and then rewritten more cleanly. */ template <typename DType> inline void trsm_nothrow(const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo, const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_DIAG diag, const int m, const int n, const DType alpha, const DType* a, const int lda, DType* b, const int ldb) { + + // (row-major) trsm: left upper trans nonunit m=3 n=1 1/1 a 3 b 3 + if (m == 0 || n == 0) return; /* Quick return if possible. */ if (alpha == 0) { // Handle alpha == 0 for (int j = 0; j < n; ++j) { for (int i = 0; i < m; ++i) { @@ -208,11 +218,11 @@ /* Form B := alpha*inv( A**T )*B. */ if (uplo == CblasUpper) { for (int j = 0; j < n; ++j) { for (int i = 0; i < m; ++i) { DType temp = alpha * b[i + j * ldb]; - for (int k = 0; k < i-1; ++k) { + for (int k = 0; k < i; ++k) { // limit was i-1. Lots of similar bugs in this code, probably. temp -= a[k + i * lda] * b[k + j * ldb]; } if (diag == CblasNonUnit) { temp /= a[i + i * lda]; } @@ -337,20 +347,109 @@ } } } +template <typename DType> +inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N, + const int K, const DType* alpha, const DType* A, const int lda, const DType* beta, DType* C, const int ldc) { + rb_raise(rb_eNotImpError, "syrk not yet implemented for non-BLAS dtypes"); +} + +template <typename DType> +inline void herk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N, + const int K, const DType* alpha, const DType* A, const int lda, const DType* beta, DType* C, const int ldc) { + rb_raise(rb_eNotImpError, "herk not yet implemented for non-BLAS dtypes"); +} + +template <> +inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N, + const int K, const float* alpha, const float* A, const int lda, const float* beta, float* C, const int ldc) { + cblas_ssyrk(Order, Uplo, Trans, N, K, *alpha, A, lda, *beta, C, ldc); +} + +template <> +inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N, + const int K, const double* alpha, const double* A, const int lda, const double* beta, double* C, const int ldc) { + cblas_dsyrk(Order, Uplo, Trans, N, K, *alpha, A, lda, *beta, C, ldc); +} + +template <> +inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N, + const int K, const Complex64* alpha, const Complex64* A, const int lda, const Complex64* beta, Complex64* C, const int ldc) { + cblas_csyrk(Order, Uplo, Trans, N, K, alpha, A, lda, beta, C, ldc); +} + +template <> +inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N, + const int K, const Complex128* alpha, const Complex128* A, const int lda, const Complex128* beta, Complex128* C, const int ldc) { + cblas_zsyrk(Order, Uplo, Trans, N, K, alpha, A, lda, beta, C, ldc); +} + + +template <> +inline void herk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N, + const int K, const Complex64* alpha, const Complex64* A, const int lda, const Complex64* beta, Complex64* C, const int ldc) { + cblas_cherk(Order, Uplo, Trans, N, K, alpha->r, A, lda, beta->r, C, ldc); +} + +template <> +inline void herk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N, + const int K, const Complex128* alpha, const Complex128* A, const int lda, const Complex128* beta, Complex128* C, const int ldc) { + cblas_zherk(Order, Uplo, Trans, N, K, alpha->r, A, lda, beta->r, C, ldc); +} + + +template <typename DType> +inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo, + const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const DType* alpha, + const DType* A, const int lda, DType* B, const int ldb) { + rb_raise(rb_eNotImpError, "trmm not yet implemented for non-BLAS dtypes"); +} + +template <> +inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo, + const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const float* alpha, + const float* A, const int lda, float* B, const int ldb) { + cblas_strmm(order, side, uplo, ta, diag, m, n, *alpha, A, lda, B, ldb); +} + +template <> +inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo, + const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const double* alpha, + const double* A, const int lda, double* B, const int ldb) { + cblas_dtrmm(order, side, uplo, ta, diag, m, n, *alpha, A, lda, B, ldb); +} + +template <> +inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo, + const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const Complex64* alpha, + const Complex64* A, const int lda, Complex64* B, const int ldb) { + cblas_ctrmm(order, side, uplo, ta, diag, m, n, alpha, A, lda, B, ldb); +} + +template <> +inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo, + const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const Complex128* alpha, + const Complex128* A, const int lda, Complex128* B, const int ldb) { + cblas_ztrmm(order, side, uplo, ta, diag, m, n, alpha, A, lda, B, ldb); +} + + /* * BLAS' DTRSM function, generalized. */ template <typename DType, typename = typename std::enable_if<!std::is_integral<DType>::value>::type> inline void trsm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo, const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_DIAG diag, const int m, const int n, const DType alpha, const DType* a, const int lda, DType* b, const int ldb) { + /*using std::cerr; + using std::endl;*/ + int num_rows_a = n; if (side == CblasLeft) num_rows_a = m; if (lda < std::max(1,num_rows_a)) { fprintf(stderr, "TRSM: num_rows_a = %d; got lda=%d\n", num_rows_a, lda); @@ -366,19 +465,32 @@ // For row major, need to switch side and uplo enum CBLAS_SIDE side_ = side == CblasLeft ? CblasRight : CblasLeft; enum CBLAS_UPLO uplo_ = uplo == CblasUpper ? CblasLower : CblasUpper; +/* + cerr << "(row-major) trsm: " << (side_ == CblasLeft ? "left " : "right ") + << (uplo_ == CblasUpper ? "upper " : "lower ") + << (trans_a == CblasTrans ? "trans " : "notrans ") + << (diag == CblasNonUnit ? "nonunit " : "unit ") + << n << " " << m << " " << alpha << " a " << lda << " b " << ldb << endl; +*/ trsm_nothrow<DType>(side_, uplo_, trans_a, diag, n, m, alpha, a, lda, b, ldb); } else { // CblasColMajor if (ldb < std::max(1,m)) { fprintf(stderr, "TRSM: M=%d; got ldb=%d\n", m, ldb); rb_raise(rb_eArgError, "TRSM: Expected ldb >= max(1,M)"); } - +/* + cerr << "(col-major) trsm: " << (side == CblasLeft ? "left " : "right ") + << (uplo == CblasUpper ? "upper " : "lower ") + << (trans_a == CblasTrans ? "trans " : "notrans ") + << (diag == CblasNonUnit ? "nonunit " : "unit ") + << m << " " << n << " " << alpha << " a " << lda << " b " << ldb << endl; +*/ trsm_nothrow<DType>(side, uplo, trans_a, diag, m, n, alpha, a, lda, b, ldb); } } @@ -388,50 +500,58 @@ inline void trsm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo, const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_DIAG diag, const int m, const int n, const float alpha, const float* a, const int lda, float* b, const int ldb) { - cblas_strsm(CblasRowMajor, side, uplo, trans_a, diag, m, n, alpha, a, lda, b, ldb); + cblas_strsm(order, side, uplo, trans_a, diag, m, n, alpha, a, lda, b, ldb); } template <> inline void trsm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo, const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_DIAG diag, const int m, const int n, const double alpha, const double* a, const int lda, double* b, const int ldb) { - cblas_dtrsm(CblasRowMajor, side, uplo, trans_a, diag, m, n, alpha, a, lda, b, ldb); +/* using std::cerr; + using std::endl; + cerr << "(row-major) dtrsm: " << (side == CblasLeft ? "left " : "right ") + << (uplo == CblasUpper ? "upper " : "lower ") + << (trans_a == CblasTrans ? "trans " : "notrans ") + << (diag == CblasNonUnit ? "nonunit " : "unit ") + << m << " " << n << " " << alpha << " a " << lda << " b " << ldb << endl; +*/ + cblas_dtrsm(order, side, uplo, trans_a, diag, m, n, alpha, a, lda, b, ldb); } template <> inline void trsm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo, const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_DIAG diag, const int m, const int n, const Complex64 alpha, const Complex64* a, const int lda, Complex64* b, const int ldb) { - cblas_ctrsm(CblasRowMajor, side, uplo, trans_a, diag, m, n, (const void*)(&alpha), (const void*)(a), lda, (void*)(b), ldb); + cblas_ctrsm(order, side, uplo, trans_a, diag, m, n, (const void*)(&alpha), (const void*)(a), lda, (void*)(b), ldb); } template <> inline void trsm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo, const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_DIAG diag, const int m, const int n, const Complex128 alpha, const Complex128* a, const int lda, Complex128* b, const int ldb) { - cblas_ztrsm(CblasRowMajor, side, uplo, trans_a, diag, m, n, (const void*)(&alpha), (const void*)(a), lda, (void*)(b), ldb); + cblas_ztrsm(order, side, uplo, trans_a, diag, m, n, (const void*)(&alpha), (const void*)(a), lda, (void*)(b), ldb); } /* * ATLAS function which performs row interchanges on a general rectangular matrix. Modeled after the LAPACK LASWP function. * * This version is templated for use by template <> getrf(). */ template <typename DType> inline void laswp(const int N, DType* A, const int lda, const int K1, const int K2, const int *piv, const int inci) { - const int n = K2 - K1; + //const int n = K2 - K1; // not sure why this is declared. commented it out because it's unused. int nb = N >> 5; const int mr = N - (nb<<5); const int incA = lda << 5; @@ -1259,12 +1379,93 @@ } return(ierr); } +/* + * Solves a system of linear equations A*X = B with a general NxN matrix A using the LU factorization computed by GETRF. + * + * From ATLAS 3.8.0. + */ +template <typename DType> +int getrs(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE Trans, const int N, const int NRHS, const DType* A, + const int lda, const int* ipiv, DType* B, const int ldb) +{ + // enum CBLAS_DIAG Lunit, Uunit; // These aren't used. Not sure why they're declared in ATLAS' src. + if (!N || !NRHS) return 0; + + const DType ONE = 1; + + if (Order == CblasColMajor) { + if (Trans == CblasNoTrans) { + nm::math::laswp<DType>(NRHS, B, ldb, 0, N, ipiv, 1); + nm::math::trsm<DType>(Order, CblasLeft, CblasLower, CblasNoTrans, CblasUnit, N, NRHS, ONE, A, lda, B, ldb); + nm::math::trsm<DType>(Order, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb); + } else { + nm::math::trsm<DType>(Order, CblasLeft, CblasUpper, Trans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb); + nm::math::trsm<DType>(Order, CblasLeft, CblasLower, Trans, CblasUnit, N, NRHS, ONE, A, lda, B, ldb); + nm::math::laswp<DType>(NRHS, B, ldb, 0, N, ipiv, -1); + } + } else { + if (Trans == CblasNoTrans) { + nm::math::trsm<DType>(Order, CblasRight, CblasLower, CblasTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb); + nm::math::trsm<DType>(Order, CblasRight, CblasUpper, CblasTrans, CblasUnit, NRHS, N, ONE, A, lda, B, ldb); + nm::math::laswp<DType>(NRHS, B, ldb, 0, N, ipiv, -1); + } else { + nm::math::laswp<DType>(NRHS, B, ldb, 0, N, ipiv, 1); + nm::math::trsm<DType>(Order, CblasRight, CblasUpper, CblasNoTrans, CblasUnit, NRHS, N, ONE, A, lda, B, ldb); + nm::math::trsm<DType>(Order, CblasRight, CblasLower, CblasNoTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb); + } + } + return 0; +} + + /* + * Solves a system of linear equations A*X = B with a symmetric positive definite matrix A using the Cholesky factorization computed by POTRF. + * + * From ATLAS 3.8.0. + */ +template <typename DType, bool is_complex> +int potrs(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N, const int NRHS, const DType* A, + const int lda, DType* B, const int ldb) +{ + // enum CBLAS_DIAG Lunit, Uunit; // These aren't used. Not sure why they're declared in ATLAS' src. + + CBLAS_TRANSPOSE MyTrans = is_complex ? CblasConjTrans : CblasTrans; + + if (!N || !NRHS) return 0; + + const DType ONE = 1; + + if (Order == CblasColMajor) { + if (Uplo == CblasUpper) { + nm::math::trsm<DType>(Order, CblasLeft, CblasUpper, MyTrans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb); + nm::math::trsm<DType>(Order, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb); + } else { + nm::math::trsm<DType>(Order, CblasLeft, CblasLower, CblasNoTrans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb); + nm::math::trsm<DType>(Order, CblasLeft, CblasLower, MyTrans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb); + } + } else { + // There's some kind of scaling operation that normally happens here in ATLAS. Not sure what it does, so we'll only + // worry if something breaks. It probably has to do with their non-templated code and doesn't apply to us. + + if (Uplo == CblasUpper) { + nm::math::trsm<DType>(Order, CblasRight, CblasUpper, CblasNoTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb); + nm::math::trsm<DType>(Order, CblasRight, CblasUpper, MyTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb); + } else { + nm::math::trsm<DType>(Order, CblasRight, CblasLower, MyTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb); + nm::math::trsm<DType>(Order, CblasRight, CblasLower, CblasNoTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb); + } + } + return 0; +} + + + +/* * From ATLAS 3.8.0: * * Computes one of two LU factorizations based on the setting of the Order * parameter, as follows: * ---------------------------------------------------------------------------- @@ -1315,13 +1516,624 @@ //rb_raise(rb_eNotImpError, "column major getrf not implemented"); } } +/* + * From ATLAS 3.8.0: + * + * Computes one of two LU factorizations based on the setting of the Order + * parameter, as follows: + * ---------------------------------------------------------------------------- + * Order == CblasColMajor + * Column-major factorization of form + * A = P * L * U + * where P is a row-permutation matrix, L is lower triangular with unit + * diagonal elements (lower trapazoidal if M > N), and U is upper triangular + * (upper trapazoidal if M < N). + * + * ---------------------------------------------------------------------------- + * Order == CblasRowMajor + * Row-major factorization of form + * A = P * L * U + * where P is a column-permutation matrix, L is lower triangular (lower + * trapazoidal if M > N), and U is upper triangular with unit diagonals (upper + * trapazoidal if M < N). + * + * ============================================================================ + * Let IERR be the return value of the function: + * If IERR == 0, successful exit. + * If (IERR < 0) the -IERR argument had an illegal value + * If (IERR > 0 && Order == CblasColMajor) + * U(i-1,i-1) is exactly zero. The factorization has been completed, + * but the factor U is exactly singular, and division by zero will + * occur if it is used to solve a system of equations. + * If (IERR > 0 && Order == CblasRowMajor) + * L(i-1,i-1) is exactly zero. The factorization has been completed, + * but the factor L is exactly singular, and division by zero will + * occur if it is used to solve a system of equations. + */ +template <typename DType> +inline int potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, DType* A, const int lda) { +#ifdef HAVE_CLAPACK_H + rb_raise(rb_eNotImpError, "not yet implemented for non-BLAS dtypes"); +#else + rb_raise(rb_eNotImpError, "only LAPACK version implemented thus far"); +#endif + return 0; +} +#ifdef HAVE_CLAPACK_H +template <> +inline int potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, float* A, const int lda) { + return clapack_spotrf(order, uplo, N, A, lda); +} +template <> +inline int potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, double* A, const int lda) { + return clapack_dpotrf(order, uplo, N, A, lda); +} + +template <> +inline int potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, Complex64* A, const int lda) { + return clapack_cpotrf(order, uplo, N, reinterpret_cast<void*>(A), lda); +} + +template <> +inline int potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, Complex128* A, const int lda) { + return clapack_zpotrf(order, uplo, N, reinterpret_cast<void*>(A), lda); +} +#endif + + +// This is the old BLAS version of this function. ATLAS has an optimized version, but +// it's going to be tough to translate. +template <typename DType> +static void swap(const int N, DType* X, const int incX, DType* Y, const int incY) { + if (N > 0) { + int ix = 0, iy = 0; + for (int i = 0; i < N; ++i) { + DType temp = X[i]; + X[i] = Y[i]; + Y[i] = temp; + + ix += incX; + iy += incY; + } + } +} + + +// Copies an upper row-major array from U, zeroing U; U is unit, so diagonal is not copied. +// +// From ATLAS 3.8.0. +template <typename DType> +static inline void trcpzeroU(const int M, const int N, DType* U, const int ldu, DType* C, const int ldc) { + + for (int i = 0; i != M; ++i) { + for (int j = i+1; j < N; ++j) { + C[j] = U[j]; + U[j] = 0; + } + + C += ldc; + U += ldu; + } +} + + /* + * Un-comment the following lines when we figure out how to calculate NB for each of the ATLAS-derived + * functions. This is probably really complicated. + * + * Also needed: ATL_MulByNB, ATL_DivByNB (both defined in the build process for ATLAS), and ATL_mmMU. + * + */ + +/* + +template <bool RowMajor, bool Upper, typename DType> +static int trtri_4(const enum CBLAS_DIAG Diag, DType* A, const int lda) { + + if (RowMajor) { + DType *pA0 = A, *pA1 = A+lda, *pA2 = A+2*lda, *pA3 = A+3*lda; + DType tmp; + if (Upper) { + DType A01 = pA0[1], A02 = pA0[2], A03 = pA0[3], + A12 = pA1[2], A13 = pA1[3], + A23 = pA2[3]; + + if (Diag == CblasNonUnit) { + pA0->inverse(); + (pA1+1)->inverse(); + (pA2+2)->inverse(); + (pA3+3)->inverse(); + + pA0[1] = -A01 * pA1[1] * pA0[0]; + pA1[2] = -A12 * pA2[2] * pA1[1]; + pA2[3] = -A23 * pA3[3] * pA2[2]; + + pA0[2] = -(A01 * pA1[2] + A02 * pA2[2]) * pA0[0]; + pA1[3] = -(A12 * pA2[3] + A13 * pA3[3]) * pA1[1]; + + pA0[3] = -(A01 * pA1[3] + A02 * pA2[3] + A03 * pA3[3]) * pA0[0]; + + } else { + + pA0[1] = -A01; + pA1[2] = -A12; + pA2[3] = -A23; + + pA0[2] = -(A01 * pA1[2] + A02); + pA1[3] = -(A12 * pA2[3] + A13); + + pA0[3] = -(A01 * pA1[3] + A02 * pA2[3] + A03); + } + + } else { // Lower + DType A10 = pA1[0], + A20 = pA2[0], A21 = pA2[1], + A30 = PA3[0], A31 = pA3[1], A32 = pA3[2]; + DType *B10 = pA1, + *B20 = pA2, + *B30 = pA3, + *B21 = pA2+1, + *B31 = pA3+1, + *B32 = pA3+2; + + + if (Diag == CblasNonUnit) { + pA0->inverse(); + (pA1+1)->inverse(); + (pA2+2)->inverse(); + (pA3+3)->inverse(); + + *B10 = -A10 * pA0[0] * pA1[1]; + *B21 = -A21 * pA1[1] * pA2[2]; + *B32 = -A32 * pA2[2] * pA3[3]; + *B20 = -(A20 * pA0[0] + A21 * (*B10)) * pA2[2]; + *B31 = -(A31 * pA1[1] + A32 * (*B21)) * pA3[3]; + *B30 = -(A30 * pA0[0] + A31 * (*B10) + A32 * (*B20)) * pA3; + } else { + *B10 = -A10; + *B21 = -A21; + *B32 = -A32; + *B20 = -(A20 + A21 * (*B10)); + *B31 = -(A31 + A32 * (*B21)); + *B30 = -(A30 + A31 * (*B10) + A32 * (*B20)); + } + } + + } else { + rb_raise(rb_eNotImpError, "only row-major implemented at this time"); + } + + return 0; + +} + + +template <bool RowMajor, bool Upper, typename DType> +static int trtri_3(const enum CBLAS_DIAG Diag, DType* A, const int lda) { + + if (RowMajor) { + + DType tmp; + + if (Upper) { + DType A01 = pA0[1], A02 = pA0[2], A03 = pA0[3], + A12 = pA1[2], A13 = pA1[3]; + + DType *B01 = pA0 + 1, + *B02 = pA0 + 2, + *B12 = pA1 + 2; + + if (Diag == CblasNonUnit) { + pA0->inverse(); + (pA1+1)->inverse(); + (pA2+2)->inverse(); + + *B01 = -A01 * pA1[1] * pA0[0]; + *B12 = -A12 * pA2[2] * pA1[1]; + *B02 = -(A01 * (*B12) + A02 * pA2[2]) * pA0[0]; + } else { + *B01 = -A01; + *B12 = -A12; + *B02 = -(A01 * (*B12) + A02); + } + + } else { // Lower + DType *pA0=A, *pA1=A+lda, *pA2=A+2*lda; + DType A10=pA1[0], + A20=pA2[0], A21=pA2[1]; + + DType *B10 = pA1, + *B20 = pA2; + *B21 = pA2+1; + + if (Diag == CblasNonUnit) { + pA0->inverse(); + (pA1+1)->inverse(); + (pA2+2)->inverse(); + *B10 = -A10 * pA0[0] * pA1[1]; + *B21 = -A21 * pA1[1] * pA2[2]; + *B20 = -(A20 * pA0[0] + A21 * (*B10)) * pA2[2]; + } else { + *B10 = -A10; + *B21 = -A21; + *B20 = -(A20 + A21 * (*B10)); + } + } + + + } else { + rb_raise(rb_eNotImpError, "only row-major implemented at this time"); + } + + return 0; + +} + +template <bool RowMajor, bool Upper, bool Real, typename DType> +static void trtri(const enum CBLAS_DIAG Diag, const int N, DType* A, const int lda) { + DType *Age, *Atr; + DType tmp; + int Nleft, Nright; + + int ierr = 0; + + static const DType ONE = 1; + static const DType MONE -1; + static const DType NONE = -1; + + if (RowMajor) { + + // FIXME: Use REAL_RECURSE_LIMIT here for float32 and float64 (instead of 1) + if ((Real && N > REAL_RECURSE_LIMIT) || (N > 1)) { + Nleft = N >> 1; +#ifdef NB + if (Nleft > NB) NLeft = ATL_MulByNB(ATL_DivByNB(Nleft)); +#endif + + Nright = N - Nleft; + + if (Upper) { + Age = A + Nleft; + Atr = A + (Nleft * (lda+1)); + + nm::math::trsm<DType>(CblasRowMajor, CblasRight, CblasUpper, CblasNoTrans, Diag, + Nleft, Nright, ONE, Atr, lda, Age, lda); + + nm::math::trsm<DType>(CblasRowMajor, CblasLeft, CblasUpper, CblasNoTrans, Diag, + Nleft, Nright, MONE, A, lda, Age, lda); + + } else { // Lower + Age = A + ((Nleft*lda)); + Atr = A + (Nleft * (lda+1)); + + nm::math::trsm<DType>(CblasRowMajor, CblasRight, CblasLower, CblasNoTrans, Diag, + Nright, Nleft, ONE, A, lda, Age, lda); + nm::math::trsm<DType>(CblasRowMajor, CblasLeft, CblasLower, CblasNoTrans, Diag, + Nright, Nleft, MONE, Atr, lda, Age, lda); + } + + ierr = trtri<RowMajor,Upper,Real,DType>(Diag, Nleft, A, lda); + if (ierr) return ierr; + + ierr = trtri<RowMajor,Upper,Real,DType>(Diag, Nright, Atr, lda); + if (ierr) return ierr + Nleft; + + } else { + if (Real) { + if (N == 4) { + return trtri_4<RowMajor,Upper,Real,DType>(Diag, A, lda); + } else if (N == 3) { + return trtri_3<RowMajor,Upper,Real,DType>(Diag, A, lda); + } else if (N == 2) { + if (Diag == CblasNonUnit) { + A->inverse(); + (A+(lda+1))->inverse(); + + if (Upper) { + *(A+1) *= *A; // TRI_MUL + *(A+1) *= *(A+lda+1); // TRI_MUL + } else { + *(A+lda) *= *A; // TRI_MUL + *(A+lda) *= *(A+lda+1); // TRI_MUL + } + } + + if (Upper) *(A+1) = -*(A+1); // TRI_NEG + else *(A+lda) = -*(A+lda); // TRI_NEG + } else if (Diag == CblasNonUnit) A->inverse(); + } else { // not real + if (Diag == CblasNonUnit) A->inverse(); + } + } + + } else { + rb_raise(rb_eNotImpError, "only row-major implemented at this time"); + } + + return ierr; +} + + +template <bool RowMajor, bool Real, typename DType> +int getri(const int N, DType* A, const int lda, const int* ipiv, DType* wrk, const int lwrk) { + + if (!RowMajor) rb_raise(rb_eNotImpError, "only row-major implemented at this time"); + + int jb, nb, I, ndown, iret; + + const DType ONE = 1, NONE = -1; + + int iret = trtri<RowMajor,false,Real,DType>(CblasNonUnit, N, A, lda); + if (!iret && N > 1) { + jb = lwrk / N; + if (jb >= NB) nb = ATL_MulByNB(ATL_DivByNB(jb)); + else if (jb >= ATL_mmMU) nb = (jb/ATL_mmMU)*ATL_mmMU; + else nb = jb; + if (!nb) return -6; // need at least 1 row of workspace + + // only first iteration will have partial block, unroll it + + jb = N - (N/nb) * nb; + if (!jb) jb = nb; + I = N - jb; + A += lda * I; + trcpzeroU<DType>(jb, jb, A+I, lda, wrk, jb); + nm::math::trsm<DType>(CblasRowMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasUnit, + jb, N, ONE, wrk, jb, A, lda); + + if (I) { + do { + I -= nb; + A -= nb * lda; + ndown = N-I; + trcpzeroU<DType>(nb, ndown, A+I, lda, wrk, ndown); + nm::math::gemm<DType>(CblasRowMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasUnit, + nb, N, ONE, wrk, ndown, A, lda); + } while (I); + } + + // Apply row interchanges + + for (I = N - 2; I >= 0; --I) { + jb = ipiv[I]; + if (jb != I) nm::math::swap<DType>(N, A+I*lda, 1, A+jb*lda, 1); + } + } + + return iret; +} +*/ + + +// TODO: Test this to see if it works properly on complex. ATLAS has a separate algorithm for complex, which looks like +// TODO: it may actually be the same one. +// +// This function is called ATL_rot in ATLAS 3.8.4. +template <typename DType> +inline void rot_helper(const int N, DType* X, const int incX, DType* Y, const int incY, const DType c, const DType s) { + if (c != 1 || s != 0) { + if (incX == 1 && incY == 1) { + for (int i = 0; i != N; ++i) { + DType tmp = X[i] * c + Y[i] * s; + Y[i] = Y[i] * c - X[i] * s; + X[i] = tmp; + } + } else { + for (int i = N; i > 0; --i, Y += incY, X += incX) { + DType tmp = *X * c + *Y * s; + *Y = *Y * c - *X * s; + *X = tmp; + } + } + } +} + + +/* Givens plane rotation. From ATLAS 3.8.4. */ +// FIXME: Need a specialized algorithm for Rationals. BLAS' algorithm simply will not work for most values due to the +// FIXME: sqrt. +template <typename DType> +inline void rotg(DType* a, DType* b, DType* c, DType* s) { + DType aa = std::abs(*a), ab = std::abs(*b); + DType roe = aa > ab ? *a : *b; + DType scal = aa + ab; + + if (scal == 0) { + *c = 1; + *s = *a = *b = 0; + } else { + DType t0 = aa / scal, t1 = ab / scal; + DType r = scal * std::sqrt(t0 * t0 + t1 * t1); + if (roe < 0) r = -r; + *c = *a / r; + *s = *b / r; + DType z = (*c != 0) ? (1 / *c) : DType(1); + *a = r; + *b = z; + } +} + +template <> +inline void rotg(float* a, float* b, float* c, float* s) { + cblas_srotg(a, b, c, s); +} + +template <> +inline void rotg(double* a, double* b, double* c, double* s) { + cblas_drotg(a, b, c, s); +} + +template <> +inline void rotg(Complex64* a, Complex64* b, Complex64* c, Complex64* s) { + cblas_crotg(reinterpret_cast<void*>(a), reinterpret_cast<void*>(b), reinterpret_cast<void*>(c), reinterpret_cast<void*>(s)); +} + +template <> +inline void rotg(Complex128* a, Complex128* b, Complex128* c, Complex128* s) { + cblas_zrotg(reinterpret_cast<void*>(a), reinterpret_cast<void*>(b), reinterpret_cast<void*>(c), reinterpret_cast<void*>(s)); +} + +template <typename DType> +inline void cblas_rotg(void* a, void* b, void* c, void* s) { + rotg<DType>(reinterpret_cast<DType*>(a), reinterpret_cast<DType*>(b), reinterpret_cast<DType*>(c), reinterpret_cast<DType*>(s)); +} + + +/* Applies a plane rotation. From ATLAS 3.8.4. */ +template <typename DType, typename CSDType> +inline void rot(const int N, DType* X, const int incX, DType* Y, const int incY, const CSDType c, const CSDType s) { + DType *x = X, *y = Y; + int incx = incX, incy = incY; + + if (N > 0) { + if (incX < 0) { + if (incY < 0) { incx = -incx; incy = -incy; } + else x += -incX * (N-1); + } else if (incY < 0) { + incy = -incy; + incx = -incx; + x += (N-1) * incX; + } + rot_helper<DType>(N, x, incx, y, incy, c, s); + } +} + +template <> +inline void rot(const int N, float* X, const int incX, float* Y, const int incY, const float c, const float s) { + cblas_srot(N, X, incX, Y, incY, (float)c, (float)s); +} + +template <> +inline void rot(const int N, double* X, const int incX, double* Y, const int incY, const double c, const double s) { + cblas_drot(N, X, incX, Y, incY, c, s); +} + +template <> +inline void rot(const int N, Complex64* X, const int incX, Complex64* Y, const int incY, const float c, const float s) { + cblas_csrot(N, X, incX, Y, incY, c, s); +} + +template <> +inline void rot(const int N, Complex128* X, const int incX, Complex128* Y, const int incY, const double c, const double s) { + cblas_zdrot(N, X, incX, Y, incY, c, s); +} + + +template <typename DType, typename CSDType> +inline void cblas_rot(const int N, void* X, const int incX, void* Y, const int incY, const void* c, const void* s) { + rot<DType,CSDType>(N, reinterpret_cast<DType*>(X), incX, reinterpret_cast<DType*>(Y), incY, *reinterpret_cast<const CSDType*>(c), *reinterpret_cast<const CSDType*>(s)); +} + + +template <bool is_complex, typename DType> +inline void lauum(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, DType* A, const int lda) { + + int Nleft, Nright; + const DType ONE = 1; + DType *G, *U0 = A, *U1; + + if (N > 1) { + Nleft = N >> 1; + #ifdef NB + if (Nleft > NB) Nleft = ATL_MulByNB(ATL_DivByNB(Nleft)); + #endif + + Nright = N - Nleft; + + // FIXME: There's a simpler way to write this next block, but I'm way too tired to work it out right now. + if (uplo == CblasUpper) { + if (order == CblasRowMajor) { + G = A + Nleft; + U1 = G + Nleft * lda; + } else { + G = A + Nleft * lda; + U1 = G + Nleft; + } + } else { + if (order == CblasRowMajor) { + G = A + Nleft * lda; + U1 = G + Nleft; + } else { + G = A + Nleft; + U1 = G + Nleft * lda; + } + } + + lauum<is_complex, DType>(order, uplo, Nleft, U0, lda); + + if (is_complex) { + + nm::math::herk<DType>(order, uplo, + uplo == CblasLower ? CblasConjTrans : CblasNoTrans, + Nleft, Nright, &ONE, G, lda, &ONE, U0, lda); + + nm::math::trmm<DType>(order, CblasLeft, uplo, CblasConjTrans, CblasNonUnit, Nright, Nleft, &ONE, U1, lda, G, lda); + } else { + nm::math::syrk<DType>(order, uplo, + uplo == CblasLower ? CblasTrans : CblasNoTrans, + Nleft, Nright, &ONE, G, lda, &ONE, U0, lda); + + nm::math::trmm<DType>(order, CblasLeft, uplo, CblasTrans, CblasNonUnit, Nright, Nleft, &ONE, U1, lda, G, lda); + } + lauum<is_complex, DType>(order, uplo, Nright, U1, lda); + + } else { + *A = *A * *A; + } +} + + +#ifdef HAVE_CLAPACK_H +template <bool is_complex> +inline void lauum(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, float* A, const int lda) { + clapack_slauum(order, uplo, N, A, lda); +} + +template <bool is_complex> +inline void lauum(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, double* A, const int lda) { + clapack_dlauum(order, uplo, N, A, lda); +} + +template <bool is_complex> +inline void lauum(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, Complex64* A, const int lda) { + clapack_clauum(order, uplo, N, A, lda); +} + +template <bool is_complex> +inline void lauum(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, Complex128* A, const int lda) { + clapack_zlauum(order, uplo, N, A, lda); +} +#endif + + +/* +* Function signature conversion for calling LAPACK's lauum functions as directly as possible. +* +* For documentation: http://www.netlib.org/lapack/double/dlauum.f +* +* This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp. +*/ +template <bool is_complex, typename DType> +inline int clapack_lauum(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, void* a, const int lda) { + if (n < 0) rb_raise(rb_eArgError, "n cannot be less than zero, is set to %d", n); + if (lda < n || lda < 1) rb_raise(rb_eArgError, "lda must be >= max(n,1); lda=%d, n=%d\n", lda, n); + + if (uplo == CblasUpper) lauum<is_complex, DType>(order, uplo, n, reinterpret_cast<DType*>(a), lda); + else lauum<is_complex, DType>(order, uplo, n, reinterpret_cast<DType*>(a), lda); + + return 0; +} + + + + +/* * Macro for declaring LAPACK specializations of the getrf function. * * type is the DType; call is the specific function to call; cast_as is what the DType* should be * cast to in order to pass it to LAPACK. */ @@ -1353,9 +2165,146 @@ */ template <typename DType> inline int clapack_getrf(const enum CBLAS_ORDER order, const int m, const int n, void* a, const int lda, int* ipiv) { return getrf<DType>(order, m, n, reinterpret_cast<DType*>(a), lda, ipiv); } + + +/* +* Function signature conversion for calling LAPACK's potrf functions as directly as possible. +* +* For documentation: http://www.netlib.org/lapack/double/dpotrf.f +* +* This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp. +*/ +template <typename DType> +inline int clapack_potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, void* a, const int lda) { + return potrf<DType>(order, uplo, n, reinterpret_cast<DType*>(a), lda); +} + + +/* +* Function signature conversion for calling LAPACK's getrs functions as directly as possible. +* +* For documentation: http://www.netlib.org/lapack/double/dgetrs.f +* +* This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp. +*/ +template <typename DType> +inline int clapack_getrs(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE trans, const int n, const int nrhs, + const void* a, const int lda, const int* ipiv, void* b, const int ldb) { + return getrs<DType>(order, trans, n, nrhs, reinterpret_cast<const DType*>(a), lda, ipiv, reinterpret_cast<DType*>(b), ldb); +} + +/* +* Function signature conversion for calling LAPACK's potrs functions as directly as possible. +* +* For documentation: http://www.netlib.org/lapack/double/dpotrs.f +* +* This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp. +*/ +template <typename DType, bool is_complex> +inline int clapack_potrs(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, const int nrhs, + const void* a, const int lda, void* b, const int ldb) { + return potrs<DType,is_complex>(order, uplo, n, nrhs, reinterpret_cast<const DType*>(a), lda, reinterpret_cast<DType*>(b), ldb); +} + +template <typename DType> +inline int getri(const enum CBLAS_ORDER order, const int n, DType* a, const int lda, const int* ipiv) { + rb_raise(rb_eNotImpError, "getri not yet implemented for non-BLAS dtypes"); + return 0; +} + +#ifdef HAVE_CLAPACK_H +template <> +inline int getri(const enum CBLAS_ORDER order, const int n, float* a, const int lda, const int* ipiv) { + return clapack_sgetri(order, n, a, lda, ipiv); +} + +template <> +inline int getri(const enum CBLAS_ORDER order, const int n, double* a, const int lda, const int* ipiv) { + return clapack_dgetri(order, n, a, lda, ipiv); +} + +template <> +inline int getri(const enum CBLAS_ORDER order, const int n, Complex64* a, const int lda, const int* ipiv) { + return clapack_cgetri(order, n, reinterpret_cast<void*>(a), lda, ipiv); +} + +template <> +inline int getri(const enum CBLAS_ORDER order, const int n, Complex128* a, const int lda, const int* ipiv) { + return clapack_zgetri(order, n, reinterpret_cast<void*>(a), lda, ipiv); +} +#endif + + +template <typename DType> +inline int potri(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, DType* a, const int lda) { + rb_raise(rb_eNotImpError, "potri not yet implemented for non-BLAS dtypes"); + return 0; +} + + +#ifdef HAVE_CLAPACK_H +template <> +inline int potri(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, float* a, const int lda) { + return clapack_spotri(order, uplo, n, a, lda); +} + +template <> +inline int potri(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, double* a, const int lda) { + return clapack_dpotri(order, uplo, n, a, lda); +} + +template <> +inline int potri(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, Complex64* a, const int lda) { + return clapack_cpotri(order, uplo, n, reinterpret_cast<void*>(a), lda); +} + +template <> +inline int potri(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, Complex128* a, const int lda) { + return clapack_zpotri(order, uplo, n, reinterpret_cast<void*>(a), lda); +} +#endif + +/* + * Function signature conversion for calling LAPACK's getri functions as directly as possible. + * + * For documentation: http://www.netlib.org/lapack/double/dgetri.f + * + * This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp. + */ +template <typename DType> +inline int clapack_getri(const enum CBLAS_ORDER order, const int n, void* a, const int lda, const int* ipiv) { + return getri<DType>(order, n, reinterpret_cast<DType*>(a), lda, ipiv); +} + + +/* + * Function signature conversion for calling LAPACK's potri functions as directly as possible. + * + * For documentation: http://www.netlib.org/lapack/double/dpotri.f + * + * This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp. + */ +template <typename DType> +inline int clapack_potri(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, void* a, const int lda) { + return potri<DType>(order, uplo, n, reinterpret_cast<DType*>(a), lda); +} + + +/* +* Function signature conversion for calling LAPACK's laswp functions as directly as possible. +* +* For documentation: http://www.netlib.org/lapack/double/dlaswp.f +* +* This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp. +*/ +template <typename DType> +inline void clapack_laswp(const int n, void* a, const int lda, const int k1, const int k2, const int* ipiv, const int incx) { + laswp<DType>(n, reinterpret_cast<DType*>(a), lda, k1, k2, ipiv, incx); +} + }} // end namespace nm::math