vendor/scs/linsys/cpu/direct/private.c in scs-0.4.0 vs vendor/scs/linsys/cpu/direct/private.c in scs-0.4.1

- old
+ new

@@ -1,22 +1,12 @@ #include "private.h" -#include "linsys.h" -const char *SCS(get_lin_sys_method)() { - return "sparse-direct"; +const char *scs_get_lin_sys_method() { + return "sparse-direct-amd-qdldl"; } -/* -char *SCS(get_lin_sys_summary)(ScsLinSysWork *p, const ScsInfo *info) { - char *str = (char *)scs_malloc(sizeof(char) * 128); - scs_int n = p->L->n; - sprintf(str, "lin-sys: nnz(L): %li\n", (long)(p->L->p[n] + n)); - return str; -} -*/ - -void SCS(free_lin_sys_work)(ScsLinSysWork *p) { +void scs_free_lin_sys_work(ScsLinSysWork *p) { if (p) { SCS(cs_spfree)(p->L); SCS(cs_spfree)(p->kkt); scs_free(p->diag_p); scs_free(p->perm); @@ -31,159 +21,47 @@ scs_free(p->fwork); scs_free(p); } } -static csc *form_kkt(const ScsMatrix *A, const ScsMatrix *P, scs_float *diag_p, - const scs_float *diag_r, scs_int *diag_r_idxs) { - /* ONLY UPPER TRIANGULAR PART IS STUFFED - * forms column compressed kkt matrix - * assumes column compressed form A matrix - * - * forms upper triangular part of [(I + P) A'; A -I] - * P : n x n, A: m x n. - */ - scs_int h, i, j, count; - csc *Kcsc, *K; - scs_int n = A->n; - scs_int m = A->m; - scs_int Anz = A->p[n]; - scs_int Knzmax; - scs_int *idx_mapping; - if (P) { - /* Upper bound P + I upper triangular component as Pnz + n */ - Knzmax = n + m + Anz + P->p[n]; - } else { - Knzmax = n + m + Anz; - } - K = SCS(cs_spalloc)(m + n, m + n, Knzmax, 1, 1); - -#if VERBOSITY > 0 - scs_printf("forming kkt\n"); -#endif - /* Here we generate a triplet matrix and then compress to CSC */ - if (!K) { - return SCS_NULL; - } - - count = 0; /* element counter */ - if (P) { - /* R_x + P in top left */ - for (j = 0; j < n; j++) { /* cols */ - diag_p[j] = 0.; - /* empty column, add diagonal */ - if (P->p[j] == P->p[j + 1]) { - K->i[count] = j; - K->p[count] = j; - K->x[count] = diag_r[j]; - diag_r_idxs[j] = count; /* store the indices where diag_r occurs */ - count++; - } - for (h = P->p[j]; h < P->p[j + 1]; h++) { - i = P->i[h]; /* row */ - if (i > j) { /* only upper triangular needed */ - break; - } - K->i[count] = i; - K->p[count] = j; - K->x[count] = P->x[h]; - if (i == j) { - /* P has diagonal element */ - diag_p[j] = P->x[h]; - K->x[count] += diag_r[j]; - diag_r_idxs[j] = count; /* store the indices where diag_r occurs */ - } - count++; - /* reached the end without adding diagonal, do it now */ - if ((i < j) && (h + 1 == P->p[j + 1] || P->i[h + 1] > j)) { - K->i[count] = j; - K->p[count] = j; - K->x[count] = diag_r[j]; - diag_r_idxs[j] = count; /* store the indices where diag_r occurs */ - count++; - } - } - } - } else { - /* R_x in top left */ - for (j = 0; j < n; j++) { - diag_p[j] = 0.; - K->i[count] = j; - K->p[count] = j; - K->x[count] = diag_r[j]; - diag_r_idxs[j] = count; /* store the indices where diag_r occurs */ - count++; - } - } - - /* A^T at top right */ - for (j = 0; j < n; j++) { - for (h = A->p[j]; h < A->p[j + 1]; h++) { - K->p[count] = A->i[h] + n; - K->i[count] = j; - K->x[count] = A->x[h]; - count++; - } - } - - /* -R_y at bottom right */ - for (j = 0; j < m; j++) { - K->i[count] = j + n; - K->p[count] = j + n; - K->x[count] = -diag_r[j + n]; - diag_r_idxs[j + n] = count; /* store the indices where diag_r occurs */ - count++; - } - - K->nz = count; - idx_mapping = (scs_int *)scs_calloc(K->nz, sizeof(scs_int)); - Kcsc = SCS(cs_compress)(K, idx_mapping); - for (i = 0; i < m + n; i++) { - diag_r_idxs[i] = idx_mapping[diag_r_idxs[i]]; - } - SCS(cs_spfree)(K); - scs_free(idx_mapping); - return Kcsc; -} - -static scs_int _ldl_init(csc *A, scs_int *P, scs_float **info) { +static scs_int _ldl_init(ScsMatrix *A, scs_int *P, scs_float **info) { *info = (scs_float *)scs_calloc(AMD_INFO, sizeof(scs_float)); return amd_order(A->n, A->p, A->i, P, (scs_float *)SCS_NULL, *info); } /* call only once */ static scs_int ldl_prepare(ScsLinSysWork *p) { - csc *kkt = p->kkt, *L = p->L; - scs_int n = kkt->n; + ScsMatrix *kkt = p->kkt, *L = p->L; + scs_int nzmax, n = kkt->n; p->etree = (scs_int *)scs_calloc(n, sizeof(scs_int)); p->Lnz = (scs_int *)scs_calloc(n, sizeof(scs_int)); p->iwork = (scs_int *)scs_calloc(3 * n, sizeof(scs_int)); L->p = (scs_int *)scs_calloc((1 + n), sizeof(scs_int)); - L->nzmax = QDLDL_etree(n, kkt->p, kkt->i, p->iwork, p->Lnz, p->etree); - if (L->nzmax < 0) { + nzmax = QDLDL_etree(n, kkt->p, kkt->i, p->iwork, p->Lnz, p->etree); + if (nzmax < 0) { scs_printf("Error in elimination tree calculation.\n"); - if (L->nzmax == -1) { + if (nzmax == -1) { scs_printf("Matrix is not perfectly upper triangular.\n"); - } else if (L->nzmax == -2) { + } else if (nzmax == -2) { scs_printf("Integer overflow in L nonzero count.\n"); } - return L->nzmax; + return nzmax; } - L->x = (scs_float *)scs_calloc(L->nzmax, sizeof(scs_float)); - L->i = (scs_int *)scs_calloc(L->nzmax, sizeof(scs_int)); + L->x = (scs_float *)scs_calloc(nzmax, sizeof(scs_float)); + L->i = (scs_int *)scs_calloc(nzmax, sizeof(scs_int)); p->Dinv = (scs_float *)scs_calloc(n, sizeof(scs_float)); p->D = (scs_float *)scs_calloc(n, sizeof(scs_float)); p->bwork = (scs_int *)scs_calloc(n, sizeof(scs_int)); p->fwork = (scs_float *)scs_calloc(n, sizeof(scs_float)); - return L->nzmax; + return nzmax; } /* can call many times */ static scs_int ldl_factor(ScsLinSysWork *p, scs_int num_vars) { scs_int factor_status; - csc *kkt = p->kkt, *L = p->L; + ScsMatrix *kkt = p->kkt, *L = p->L; #if VERBOSITY > 0 scs_printf("numeric factorization\n"); #endif factor_status = QDLDL_factor(kkt->n, kkt->p, kkt->i, kkt->x, L->p, L->i, L->x, p->D, @@ -215,11 +93,11 @@ scs_int j; for (j = 0; j < n; j++) x[P[j]] = b[j]; } -static void _ldl_solve(scs_float *b, csc *L, scs_float *Dinv, scs_int *P, +static void _ldl_solve(scs_float *b, ScsMatrix *L, scs_float *Dinv, scs_int *P, scs_float *bp) { /* solves PLDL'P' x = b for x */ scs_int n = L->n; _ldl_perm(n, bp, b, P); QDLDL_solve(n, L->p, L->i, L->x, Dinv, bp); @@ -238,15 +116,15 @@ for (k = 0; k < n; k++) pinv[p[k]] = k; /* invert the permutation */ return pinv; /* return result */ } -static csc *cs_symperm(const csc *A, const scs_int *pinv, scs_int *idx_mapping, - scs_int values) { +static ScsMatrix *cs_symperm(const ScsMatrix *A, const scs_int *pinv, + scs_int *idx_mapping, scs_int values) { scs_int i, j, p, q, i2, j2, n, *Ap, *Ai, *Cp, *Ci, *w; scs_float *Cx, *Ax; - csc *C; + ScsMatrix *C; n = A->n; Ap = A->p; Ai = A->i; Ax = A->x; C = SCS(cs_spalloc)(n, n, Ap[n], values && (Ax != SCS_NULL), @@ -288,29 +166,31 @@ } return SCS(cs_done)(C, w, SCS_NULL, 1); /* success; free workspace, return C */ } -static csc *permute_kkt(const ScsMatrix *A, const ScsMatrix *P, - ScsLinSysWork *p, const scs_float *diag_r) { +static ScsMatrix *permute_kkt(const ScsMatrix *A, const ScsMatrix *P, + ScsLinSysWork *p, const scs_float *diag_r) { scs_float *info; - scs_int *Pinv, amd_status, *idx_mapping, i; - csc *kkt_perm, *kkt = form_kkt(A, P, p->diag_p, diag_r, p->diag_r_idxs); + scs_int *Pinv, amd_status, *idx_mapping, i, kkt_nnz; + ScsMatrix *kkt_perm; + ScsMatrix *kkt = SCS(form_kkt)(A, P, p->diag_p, diag_r, p->diag_r_idxs, 1); if (!kkt) { return SCS_NULL; } + kkt_nnz = kkt->p[kkt->n]; amd_status = _ldl_init(kkt, p->perm, &info); if (amd_status < 0) { scs_printf("AMD permutatation error.\n"); return SCS_NULL; } #if VERBOSITY > 0 scs_printf("Matrix factorization info:\n"); amd_info(info); #endif Pinv = cs_pinv(p->perm, A->n + A->m); - idx_mapping = (scs_int *)scs_calloc(kkt->nzmax, sizeof(scs_int)); + idx_mapping = (scs_int *)scs_calloc(kkt_nnz, sizeof(scs_int)); kkt_perm = cs_symperm(kkt, Pinv, idx_mapping, 1); for (i = 0; i < A->n + A->m; i++) { p->diag_r_idxs[i] = idx_mapping[p->diag_r_idxs[i]]; } SCS(cs_spfree)(kkt); @@ -318,11 +198,11 @@ scs_free(info); scs_free(idx_mapping); return kkt_perm; } -void SCS(update_lin_sys_diag_r)(ScsLinSysWork *p, const scs_float *diag_r) { +void scs_update_lin_sys_diag_r(ScsLinSysWork *p, const scs_float *diag_r) { scs_int i, ldl_status; for (i = 0; i < p->n; ++i) { /* top left is R_x + P, bottom right is -R_y */ p->kkt->x[p->diag_r_idxs[i]] = p->diag_p[i] + diag_r[i]; } @@ -337,25 +217,24 @@ /* SCS(free_lin_sys_work)(p); */ return; } } -ScsLinSysWork *SCS(init_lin_sys_work)(const ScsMatrix *A, const ScsMatrix *P, - const scs_float *diag_r) { +ScsLinSysWork *scs_init_lin_sys_work(const ScsMatrix *A, const ScsMatrix *P, + const scs_float *diag_r) { ScsLinSysWork *p = (ScsLinSysWork *)scs_calloc(1, sizeof(ScsLinSysWork)); scs_int n_plus_m = A->n + A->m, ldl_status, ldl_prepare_status; p->m = A->m; p->n = A->n; p->diag_p = (scs_float *)scs_calloc(A->n, sizeof(scs_float)); p->perm = (scs_int *)scs_calloc(sizeof(scs_int), n_plus_m); - p->L = (csc *)scs_calloc(1, sizeof(csc)); + p->L = (ScsMatrix *)scs_calloc(1, sizeof(ScsMatrix)); p->bp = (scs_float *)scs_calloc(n_plus_m, sizeof(scs_float)); p->diag_r_idxs = (scs_int *)scs_calloc(n_plus_m, sizeof(scs_int)); p->factorizations = 0; p->L->m = n_plus_m; p->L->n = n_plus_m; - p->L->nz = -1; p->kkt = permute_kkt(A, P, p, diag_r); ldl_prepare_status = ldl_prepare(p); ldl_status = ldl_factor(p, A->n); if (ldl_prepare_status < 0 || ldl_status < 0) { scs_printf("Error in LDL initial factorization.\n"); @@ -364,11 +243,11 @@ return SCS_NULL; } return p; } -scs_int SCS(solve_lin_sys)(ScsLinSysWork *p, scs_float *b, const scs_float *s, - scs_float tol) { +scs_int scs_solve_lin_sys(ScsLinSysWork *p, scs_float *b, const scs_float *s, + scs_float tol) { /* returns solution to linear system */ /* Ax = b with solution stored in b */ _ldl_solve(b, p->L, p->Dinv, p->perm, p->bp); return 0; }