vendor/scs/linsys/cpu/direct/private.c in scs-0.3.1 vs vendor/scs/linsys/cpu/direct/private.c in scs-0.3.2

- old
+ new

@@ -16,35 +16,35 @@ void SCS(free_lin_sys_work)(ScsLinSysWork *p) { if (p) { SCS(cs_spfree)(p->L); SCS(cs_spfree)(p->kkt); + scs_free(p->diag_p); scs_free(p->perm); scs_free(p->Dinv); scs_free(p->bp); - scs_free(p->rho_y_vec_idxs); + scs_free(p->diag_r_idxs); scs_free(p->Lnz); scs_free(p->iwork); scs_free(p->etree); scs_free(p->D); scs_free(p->bwork); scs_free(p->fwork); scs_free(p); } } -static csc *form_kkt(const ScsMatrix *A, const ScsMatrix *P, - scs_float *rho_y_vec, scs_int *rho_y_vec_idxs, - scs_float rho_x) { +static csc *form_kkt(const ScsMatrix *A, const ScsMatrix *P, scs_float *diag_p, + const scs_float *diag_r, scs_int *diag_r_idxs) { /* ONLY UPPER TRIANGULAR PART IS STUFFED * forms column compressed kkt matrix * assumes column compressed form A matrix * * forms upper triangular part of [(I + P) A'; A -I] * P : n x n, A: m x n. */ - scs_int i, j, k, kk; + scs_int h, i, j, count; csc *Kcsc, *K; scs_int n = A->n; scs_int m = A->m; scs_int Anz = A->p[n]; scs_int Knzmax; @@ -63,95 +63,103 @@ /* Here we generate a triplet matrix and then compress to CSC */ if (!K) { return SCS_NULL; } - kk = 0; /* element counter */ + count = 0; /* element counter */ if (P) { - /* I + P in top left */ - for (j = 0; j < P->n; j++) { /* cols */ + /* R_x + P in top left */ + for (j = 0; j < n; j++) { /* cols */ + diag_p[j] = 0.; /* empty column, add diagonal */ if (P->p[j] == P->p[j + 1]) { - K->i[kk] = j; - K->p[kk] = j; - K->x[kk] = rho_x; - kk++; + K->i[count] = j; + K->p[count] = j; + K->x[count] = diag_r[j]; + diag_r_idxs[j] = count; /* store the indices where diag_r occurs */ + count++; } - for (k = P->p[j]; k < P->p[j + 1]; k++) { - i = P->i[k]; /* row */ + for (h = P->p[j]; h < P->p[j + 1]; h++) { + i = P->i[h]; /* row */ if (i > j) { /* only upper triangular needed */ break; } - K->i[kk] = i; - K->p[kk] = j; - K->x[kk] = P->x[k]; + K->i[count] = i; + K->p[count] = j; + K->x[count] = P->x[h]; if (i == j) { /* P has diagonal element */ - K->x[kk] += rho_x; + diag_p[j] = P->x[h]; + K->x[count] += diag_r[j]; + diag_r_idxs[j] = count; /* store the indices where diag_r occurs */ } - kk++; + count++; /* reached the end without adding diagonal, do it now */ - if ((i < j) && (k + 1 == P->p[j + 1] || P->i[k + 1] > j)) { - K->i[kk] = j; - K->p[kk] = j; - K->x[kk] = rho_x; - kk++; + if ((i < j) && (h + 1 == P->p[j + 1] || P->i[h + 1] > j)) { + K->i[count] = j; + K->p[count] = j; + K->x[count] = diag_r[j]; + diag_r_idxs[j] = count; /* store the indices where diag_r occurs */ + count++; } } } } else { - /* rho_x * I in top left */ - for (k = 0; k < A->n; k++) { - K->i[kk] = k; - K->p[kk] = k; - K->x[kk] = rho_x; - kk++; + /* R_x in top left */ + for (j = 0; j < n; j++) { + diag_p[j] = 0.; + K->i[count] = j; + K->p[count] = j; + K->x[count] = diag_r[j]; + diag_r_idxs[j] = count; /* store the indices where diag_r occurs */ + count++; } } /* A^T at top right */ for (j = 0; j < n; j++) { - for (k = A->p[j]; k < A->p[j + 1]; k++) { - K->p[kk] = A->i[k] + n; - K->i[kk] = j; - K->x[kk] = A->x[k]; - kk++; + for (h = A->p[j]; h < A->p[j + 1]; h++) { + K->p[count] = A->i[h] + n; + K->i[count] = j; + K->x[count] = A->x[h]; + count++; } } - /* -rho_y_vec * I at bottom right */ - for (k = 0; k < m; k++) { - K->i[kk] = k + n; - K->p[kk] = k + n; - K->x[kk] = -rho_y_vec[k]; - rho_y_vec_idxs[k] = kk; /* store the indices where rho_y_vec occurs */ - kk++; + /* -R_y at bottom right */ + for (j = 0; j < m; j++) { + K->i[count] = j + n; + K->p[count] = j + n; + K->x[count] = -diag_r[j + n]; + diag_r_idxs[j + n] = count; /* store the indices where diag_r occurs */ + count++; } - K->nz = kk; - idx_mapping = (scs_int *)scs_malloc(K->nz * sizeof(scs_int)); + + K->nz = count; + idx_mapping = (scs_int *)scs_calloc(K->nz, sizeof(scs_int)); Kcsc = SCS(cs_compress)(K, idx_mapping); - for (i = 0; i < A->m; i++) { - rho_y_vec_idxs[i] = idx_mapping[rho_y_vec_idxs[i]]; + for (i = 0; i < m + n; i++) { + diag_r_idxs[i] = idx_mapping[diag_r_idxs[i]]; } SCS(cs_spfree)(K); scs_free(idx_mapping); return Kcsc; } static scs_int _ldl_init(csc *A, scs_int *P, scs_float **info) { - *info = (scs_float *)scs_malloc(AMD_INFO * sizeof(scs_float)); + *info = (scs_float *)scs_calloc(AMD_INFO, sizeof(scs_float)); return amd_order(A->n, A->p, A->i, P, (scs_float *)SCS_NULL, *info); } /* call only once */ static scs_int ldl_prepare(ScsLinSysWork *p) { csc *kkt = p->kkt, *L = p->L; scs_int n = kkt->n; - p->etree = (scs_int *)scs_malloc(n * sizeof(scs_int)); - p->Lnz = (scs_int *)scs_malloc(n * sizeof(scs_int)); - p->iwork = (scs_int *)scs_malloc(3 * n * sizeof(scs_int)); - L->p = (scs_int *)scs_malloc((1 + n) * sizeof(scs_int)); + p->etree = (scs_int *)scs_calloc(n, sizeof(scs_int)); + p->Lnz = (scs_int *)scs_calloc(n, sizeof(scs_int)); + p->iwork = (scs_int *)scs_calloc(3 * n, sizeof(scs_int)); + L->p = (scs_int *)scs_calloc((1 + n), sizeof(scs_int)); L->nzmax = QDLDL_etree(n, kkt->p, kkt->i, p->iwork, p->Lnz, p->etree); if (L->nzmax < 0) { scs_printf("Error in elimination tree calculation.\n"); if (L->nzmax == -1) { scs_printf("Matrix is not perfectly upper triangular.\n"); @@ -159,16 +167,16 @@ scs_printf("Integer overflow in L nonzero count.\n"); } return L->nzmax; } - L->x = (scs_float *)scs_malloc(L->nzmax * sizeof(scs_float)); - L->i = (scs_int *)scs_malloc(L->nzmax * sizeof(scs_int)); - p->Dinv = (scs_float *)scs_malloc(n * sizeof(scs_float)); - p->D = (scs_float *)scs_malloc(n * sizeof(scs_float)); - p->bwork = (scs_int *)scs_malloc(n * sizeof(scs_int)); - p->fwork = (scs_float *)scs_malloc(n * sizeof(scs_float)); + L->x = (scs_float *)scs_calloc(L->nzmax, sizeof(scs_float)); + L->i = (scs_int *)scs_calloc(L->nzmax, sizeof(scs_int)); + p->Dinv = (scs_float *)scs_calloc(n, sizeof(scs_float)); + p->D = (scs_float *)scs_calloc(n, sizeof(scs_float)); + p->bwork = (scs_int *)scs_calloc(n, sizeof(scs_int)); + p->fwork = (scs_float *)scs_calloc(n, sizeof(scs_float)); return L->nzmax; } /* can call many times */ static scs_int ldl_factor(ScsLinSysWork *p, scs_int num_vars) { @@ -221,11 +229,11 @@ static scs_int *cs_pinv(scs_int const *p, scs_int n) { scs_int k, *pinv; if (!p) { return SCS_NULL; } /* p = SCS_NULL denotes identity */ - pinv = (scs_int *)scs_malloc(n * sizeof(scs_int)); /* allocate result */ + pinv = (scs_int *)scs_calloc(n, sizeof(scs_int)); /* allocate result */ if (!pinv) { return SCS_NULL; } /* out of memory */ for (k = 0; k < n; k++) pinv[p[k]] = k; /* invert the permutation */ @@ -281,14 +289,14 @@ return SCS(cs_done)(C, w, SCS_NULL, 1); /* success; free workspace, return C */ } static csc *permute_kkt(const ScsMatrix *A, const ScsMatrix *P, - ScsLinSysWork *p, scs_float *rho_y_vec) { + ScsLinSysWork *p, const scs_float *diag_r) { scs_float *info; scs_int *Pinv, amd_status, *idx_mapping, i; - csc *kkt_perm, *kkt = form_kkt(A, P, rho_y_vec, p->rho_y_vec_idxs, p->rho_x); + csc *kkt_perm, *kkt = form_kkt(A, P, p->diag_p, diag_r, p->diag_r_idxs); if (!kkt) { return SCS_NULL; } amd_status = _ldl_init(kkt, p->perm, &info); if (amd_status < 0) { @@ -298,51 +306,56 @@ #if VERBOSITY > 0 scs_printf("Matrix factorization info:\n"); amd_info(info); #endif Pinv = cs_pinv(p->perm, A->n + A->m); - idx_mapping = (scs_int *)scs_malloc(kkt->nzmax * sizeof(scs_int)); + idx_mapping = (scs_int *)scs_calloc(kkt->nzmax, sizeof(scs_int)); kkt_perm = cs_symperm(kkt, Pinv, idx_mapping, 1); - for (i = 0; i < A->m; i++) { - p->rho_y_vec_idxs[i] = idx_mapping[p->rho_y_vec_idxs[i]]; + for (i = 0; i < A->n + A->m; i++) { + p->diag_r_idxs[i] = idx_mapping[p->diag_r_idxs[i]]; } SCS(cs_spfree)(kkt); scs_free(Pinv); scs_free(info); scs_free(idx_mapping); return kkt_perm; } -void SCS(update_lin_sys_rho_y_vec)(ScsLinSysWork *p, scs_float *rho_y_vec) { +void SCS(update_lin_sys_diag_r)(ScsLinSysWork *p, const scs_float *diag_r) { scs_int i, ldl_status; - for (i = 0; i < p->m; ++i) { - p->kkt->x[p->rho_y_vec_idxs[i]] = -rho_y_vec[i]; + for (i = 0; i < p->n; ++i) { + /* top left is R_x + P, bottom right is -R_y */ + p->kkt->x[p->diag_r_idxs[i]] = p->diag_p[i] + diag_r[i]; } + for (i = p->n; i < p->n + p->m; ++i) { + /* top left is R_x + P, bottom right is -R_y */ + p->kkt->x[p->diag_r_idxs[i]] = -diag_r[i]; + } ldl_status = ldl_factor(p, p->n); if (ldl_status < 0) { scs_printf("Error in LDL factorization when updating.\n"); /* TODO: this is broken somehow */ /* SCS(free_lin_sys_work)(p); */ return; } } ScsLinSysWork *SCS(init_lin_sys_work)(const ScsMatrix *A, const ScsMatrix *P, - scs_float *rho_y_vec, scs_float rho_x) { + const scs_float *diag_r) { ScsLinSysWork *p = (ScsLinSysWork *)scs_calloc(1, sizeof(ScsLinSysWork)); scs_int n_plus_m = A->n + A->m, ldl_status, ldl_prepare_status; p->m = A->m; p->n = A->n; - p->rho_x = rho_x; - p->perm = (scs_int *)scs_malloc(sizeof(scs_int) * n_plus_m); - p->L = (csc *)scs_malloc(sizeof(csc)); - p->bp = (scs_float *)scs_malloc(n_plus_m * sizeof(scs_float)); - p->rho_y_vec_idxs = (scs_int *)scs_malloc(A->m * sizeof(scs_int)); + p->diag_p = (scs_float *)scs_calloc(A->n, sizeof(scs_float)); + p->perm = (scs_int *)scs_calloc(sizeof(scs_int), n_plus_m); + p->L = (csc *)scs_calloc(1, sizeof(csc)); + p->bp = (scs_float *)scs_calloc(n_plus_m, sizeof(scs_float)); + p->diag_r_idxs = (scs_int *)scs_calloc(n_plus_m, sizeof(scs_int)); p->factorizations = 0; p->L->m = n_plus_m; p->L->n = n_plus_m; p->L->nz = -1; - p->kkt = permute_kkt(A, P, p, rho_y_vec); + p->kkt = permute_kkt(A, P, p, diag_r); ldl_prepare_status = ldl_prepare(p); ldl_status = ldl_factor(p, A->n); if (ldl_prepare_status < 0 || ldl_status < 0) { scs_printf("Error in LDL initial factorization.\n"); /* TODO: this is broken somehow */