vendor/scs/linsys/cpu/direct/private.c in scs-0.3.1 vs vendor/scs/linsys/cpu/direct/private.c in scs-0.3.2
- old
+ new
@@ -16,35 +16,35 @@
void SCS(free_lin_sys_work)(ScsLinSysWork *p) {
if (p) {
SCS(cs_spfree)(p->L);
SCS(cs_spfree)(p->kkt);
+ scs_free(p->diag_p);
scs_free(p->perm);
scs_free(p->Dinv);
scs_free(p->bp);
- scs_free(p->rho_y_vec_idxs);
+ scs_free(p->diag_r_idxs);
scs_free(p->Lnz);
scs_free(p->iwork);
scs_free(p->etree);
scs_free(p->D);
scs_free(p->bwork);
scs_free(p->fwork);
scs_free(p);
}
}
-static csc *form_kkt(const ScsMatrix *A, const ScsMatrix *P,
- scs_float *rho_y_vec, scs_int *rho_y_vec_idxs,
- scs_float rho_x) {
+static csc *form_kkt(const ScsMatrix *A, const ScsMatrix *P, scs_float *diag_p,
+ const scs_float *diag_r, scs_int *diag_r_idxs) {
/* ONLY UPPER TRIANGULAR PART IS STUFFED
* forms column compressed kkt matrix
* assumes column compressed form A matrix
*
* forms upper triangular part of [(I + P) A'; A -I]
* P : n x n, A: m x n.
*/
- scs_int i, j, k, kk;
+ scs_int h, i, j, count;
csc *Kcsc, *K;
scs_int n = A->n;
scs_int m = A->m;
scs_int Anz = A->p[n];
scs_int Knzmax;
@@ -63,95 +63,103 @@
/* Here we generate a triplet matrix and then compress to CSC */
if (!K) {
return SCS_NULL;
}
- kk = 0; /* element counter */
+ count = 0; /* element counter */
if (P) {
- /* I + P in top left */
- for (j = 0; j < P->n; j++) { /* cols */
+ /* R_x + P in top left */
+ for (j = 0; j < n; j++) { /* cols */
+ diag_p[j] = 0.;
/* empty column, add diagonal */
if (P->p[j] == P->p[j + 1]) {
- K->i[kk] = j;
- K->p[kk] = j;
- K->x[kk] = rho_x;
- kk++;
+ K->i[count] = j;
+ K->p[count] = j;
+ K->x[count] = diag_r[j];
+ diag_r_idxs[j] = count; /* store the indices where diag_r occurs */
+ count++;
}
- for (k = P->p[j]; k < P->p[j + 1]; k++) {
- i = P->i[k]; /* row */
+ for (h = P->p[j]; h < P->p[j + 1]; h++) {
+ i = P->i[h]; /* row */
if (i > j) { /* only upper triangular needed */
break;
}
- K->i[kk] = i;
- K->p[kk] = j;
- K->x[kk] = P->x[k];
+ K->i[count] = i;
+ K->p[count] = j;
+ K->x[count] = P->x[h];
if (i == j) {
/* P has diagonal element */
- K->x[kk] += rho_x;
+ diag_p[j] = P->x[h];
+ K->x[count] += diag_r[j];
+ diag_r_idxs[j] = count; /* store the indices where diag_r occurs */
}
- kk++;
+ count++;
/* reached the end without adding diagonal, do it now */
- if ((i < j) && (k + 1 == P->p[j + 1] || P->i[k + 1] > j)) {
- K->i[kk] = j;
- K->p[kk] = j;
- K->x[kk] = rho_x;
- kk++;
+ if ((i < j) && (h + 1 == P->p[j + 1] || P->i[h + 1] > j)) {
+ K->i[count] = j;
+ K->p[count] = j;
+ K->x[count] = diag_r[j];
+ diag_r_idxs[j] = count; /* store the indices where diag_r occurs */
+ count++;
}
}
}
} else {
- /* rho_x * I in top left */
- for (k = 0; k < A->n; k++) {
- K->i[kk] = k;
- K->p[kk] = k;
- K->x[kk] = rho_x;
- kk++;
+ /* R_x in top left */
+ for (j = 0; j < n; j++) {
+ diag_p[j] = 0.;
+ K->i[count] = j;
+ K->p[count] = j;
+ K->x[count] = diag_r[j];
+ diag_r_idxs[j] = count; /* store the indices where diag_r occurs */
+ count++;
}
}
/* A^T at top right */
for (j = 0; j < n; j++) {
- for (k = A->p[j]; k < A->p[j + 1]; k++) {
- K->p[kk] = A->i[k] + n;
- K->i[kk] = j;
- K->x[kk] = A->x[k];
- kk++;
+ for (h = A->p[j]; h < A->p[j + 1]; h++) {
+ K->p[count] = A->i[h] + n;
+ K->i[count] = j;
+ K->x[count] = A->x[h];
+ count++;
}
}
- /* -rho_y_vec * I at bottom right */
- for (k = 0; k < m; k++) {
- K->i[kk] = k + n;
- K->p[kk] = k + n;
- K->x[kk] = -rho_y_vec[k];
- rho_y_vec_idxs[k] = kk; /* store the indices where rho_y_vec occurs */
- kk++;
+ /* -R_y at bottom right */
+ for (j = 0; j < m; j++) {
+ K->i[count] = j + n;
+ K->p[count] = j + n;
+ K->x[count] = -diag_r[j + n];
+ diag_r_idxs[j + n] = count; /* store the indices where diag_r occurs */
+ count++;
}
- K->nz = kk;
- idx_mapping = (scs_int *)scs_malloc(K->nz * sizeof(scs_int));
+
+ K->nz = count;
+ idx_mapping = (scs_int *)scs_calloc(K->nz, sizeof(scs_int));
Kcsc = SCS(cs_compress)(K, idx_mapping);
- for (i = 0; i < A->m; i++) {
- rho_y_vec_idxs[i] = idx_mapping[rho_y_vec_idxs[i]];
+ for (i = 0; i < m + n; i++) {
+ diag_r_idxs[i] = idx_mapping[diag_r_idxs[i]];
}
SCS(cs_spfree)(K);
scs_free(idx_mapping);
return Kcsc;
}
static scs_int _ldl_init(csc *A, scs_int *P, scs_float **info) {
- *info = (scs_float *)scs_malloc(AMD_INFO * sizeof(scs_float));
+ *info = (scs_float *)scs_calloc(AMD_INFO, sizeof(scs_float));
return amd_order(A->n, A->p, A->i, P, (scs_float *)SCS_NULL, *info);
}
/* call only once */
static scs_int ldl_prepare(ScsLinSysWork *p) {
csc *kkt = p->kkt, *L = p->L;
scs_int n = kkt->n;
- p->etree = (scs_int *)scs_malloc(n * sizeof(scs_int));
- p->Lnz = (scs_int *)scs_malloc(n * sizeof(scs_int));
- p->iwork = (scs_int *)scs_malloc(3 * n * sizeof(scs_int));
- L->p = (scs_int *)scs_malloc((1 + n) * sizeof(scs_int));
+ p->etree = (scs_int *)scs_calloc(n, sizeof(scs_int));
+ p->Lnz = (scs_int *)scs_calloc(n, sizeof(scs_int));
+ p->iwork = (scs_int *)scs_calloc(3 * n, sizeof(scs_int));
+ L->p = (scs_int *)scs_calloc((1 + n), sizeof(scs_int));
L->nzmax = QDLDL_etree(n, kkt->p, kkt->i, p->iwork, p->Lnz, p->etree);
if (L->nzmax < 0) {
scs_printf("Error in elimination tree calculation.\n");
if (L->nzmax == -1) {
scs_printf("Matrix is not perfectly upper triangular.\n");
@@ -159,16 +167,16 @@
scs_printf("Integer overflow in L nonzero count.\n");
}
return L->nzmax;
}
- L->x = (scs_float *)scs_malloc(L->nzmax * sizeof(scs_float));
- L->i = (scs_int *)scs_malloc(L->nzmax * sizeof(scs_int));
- p->Dinv = (scs_float *)scs_malloc(n * sizeof(scs_float));
- p->D = (scs_float *)scs_malloc(n * sizeof(scs_float));
- p->bwork = (scs_int *)scs_malloc(n * sizeof(scs_int));
- p->fwork = (scs_float *)scs_malloc(n * sizeof(scs_float));
+ L->x = (scs_float *)scs_calloc(L->nzmax, sizeof(scs_float));
+ L->i = (scs_int *)scs_calloc(L->nzmax, sizeof(scs_int));
+ p->Dinv = (scs_float *)scs_calloc(n, sizeof(scs_float));
+ p->D = (scs_float *)scs_calloc(n, sizeof(scs_float));
+ p->bwork = (scs_int *)scs_calloc(n, sizeof(scs_int));
+ p->fwork = (scs_float *)scs_calloc(n, sizeof(scs_float));
return L->nzmax;
}
/* can call many times */
static scs_int ldl_factor(ScsLinSysWork *p, scs_int num_vars) {
@@ -221,11 +229,11 @@
static scs_int *cs_pinv(scs_int const *p, scs_int n) {
scs_int k, *pinv;
if (!p) {
return SCS_NULL;
} /* p = SCS_NULL denotes identity */
- pinv = (scs_int *)scs_malloc(n * sizeof(scs_int)); /* allocate result */
+ pinv = (scs_int *)scs_calloc(n, sizeof(scs_int)); /* allocate result */
if (!pinv) {
return SCS_NULL;
} /* out of memory */
for (k = 0; k < n; k++)
pinv[p[k]] = k; /* invert the permutation */
@@ -281,14 +289,14 @@
return SCS(cs_done)(C, w, SCS_NULL,
1); /* success; free workspace, return C */
}
static csc *permute_kkt(const ScsMatrix *A, const ScsMatrix *P,
- ScsLinSysWork *p, scs_float *rho_y_vec) {
+ ScsLinSysWork *p, const scs_float *diag_r) {
scs_float *info;
scs_int *Pinv, amd_status, *idx_mapping, i;
- csc *kkt_perm, *kkt = form_kkt(A, P, rho_y_vec, p->rho_y_vec_idxs, p->rho_x);
+ csc *kkt_perm, *kkt = form_kkt(A, P, p->diag_p, diag_r, p->diag_r_idxs);
if (!kkt) {
return SCS_NULL;
}
amd_status = _ldl_init(kkt, p->perm, &info);
if (amd_status < 0) {
@@ -298,51 +306,56 @@
#if VERBOSITY > 0
scs_printf("Matrix factorization info:\n");
amd_info(info);
#endif
Pinv = cs_pinv(p->perm, A->n + A->m);
- idx_mapping = (scs_int *)scs_malloc(kkt->nzmax * sizeof(scs_int));
+ idx_mapping = (scs_int *)scs_calloc(kkt->nzmax, sizeof(scs_int));
kkt_perm = cs_symperm(kkt, Pinv, idx_mapping, 1);
- for (i = 0; i < A->m; i++) {
- p->rho_y_vec_idxs[i] = idx_mapping[p->rho_y_vec_idxs[i]];
+ for (i = 0; i < A->n + A->m; i++) {
+ p->diag_r_idxs[i] = idx_mapping[p->diag_r_idxs[i]];
}
SCS(cs_spfree)(kkt);
scs_free(Pinv);
scs_free(info);
scs_free(idx_mapping);
return kkt_perm;
}
-void SCS(update_lin_sys_rho_y_vec)(ScsLinSysWork *p, scs_float *rho_y_vec) {
+void SCS(update_lin_sys_diag_r)(ScsLinSysWork *p, const scs_float *diag_r) {
scs_int i, ldl_status;
- for (i = 0; i < p->m; ++i) {
- p->kkt->x[p->rho_y_vec_idxs[i]] = -rho_y_vec[i];
+ for (i = 0; i < p->n; ++i) {
+ /* top left is R_x + P, bottom right is -R_y */
+ p->kkt->x[p->diag_r_idxs[i]] = p->diag_p[i] + diag_r[i];
}
+ for (i = p->n; i < p->n + p->m; ++i) {
+ /* top left is R_x + P, bottom right is -R_y */
+ p->kkt->x[p->diag_r_idxs[i]] = -diag_r[i];
+ }
ldl_status = ldl_factor(p, p->n);
if (ldl_status < 0) {
scs_printf("Error in LDL factorization when updating.\n");
/* TODO: this is broken somehow */
/* SCS(free_lin_sys_work)(p); */
return;
}
}
ScsLinSysWork *SCS(init_lin_sys_work)(const ScsMatrix *A, const ScsMatrix *P,
- scs_float *rho_y_vec, scs_float rho_x) {
+ const scs_float *diag_r) {
ScsLinSysWork *p = (ScsLinSysWork *)scs_calloc(1, sizeof(ScsLinSysWork));
scs_int n_plus_m = A->n + A->m, ldl_status, ldl_prepare_status;
p->m = A->m;
p->n = A->n;
- p->rho_x = rho_x;
- p->perm = (scs_int *)scs_malloc(sizeof(scs_int) * n_plus_m);
- p->L = (csc *)scs_malloc(sizeof(csc));
- p->bp = (scs_float *)scs_malloc(n_plus_m * sizeof(scs_float));
- p->rho_y_vec_idxs = (scs_int *)scs_malloc(A->m * sizeof(scs_int));
+ p->diag_p = (scs_float *)scs_calloc(A->n, sizeof(scs_float));
+ p->perm = (scs_int *)scs_calloc(sizeof(scs_int), n_plus_m);
+ p->L = (csc *)scs_calloc(1, sizeof(csc));
+ p->bp = (scs_float *)scs_calloc(n_plus_m, sizeof(scs_float));
+ p->diag_r_idxs = (scs_int *)scs_calloc(n_plus_m, sizeof(scs_int));
p->factorizations = 0;
p->L->m = n_plus_m;
p->L->n = n_plus_m;
p->L->nz = -1;
- p->kkt = permute_kkt(A, P, p, rho_y_vec);
+ p->kkt = permute_kkt(A, P, p, diag_r);
ldl_prepare_status = ldl_prepare(p);
ldl_status = ldl_factor(p, A->n);
if (ldl_prepare_status < 0 || ldl_status < 0) {
scs_printf("Error in LDL initial factorization.\n");
/* TODO: this is broken somehow */