Actual source code: tomographyADMM.c

  1: #include <petsctao.h>
  2: /*
  3: Description:   ADMM tomography reconstruction example .
  4:                0.5*||Ax-b||^2 + lambda*g(x)
  5: Reference:     BRGN Tomography Example
  6: */

  8: static char help[] = "Finds the ADMM solution to the under constraint linear model Ax = b, with regularizer. \n\
  9:                       A is a M*N real matrix (M<N), x is sparse. A good regularizer is an L1 regularizer. \n\
 10:                       We first split the operator into 0.5*||Ax-b||^2, f(x), and lambda*||x||_1, g(z), where lambda is user specified weight. \n\
 11:                       g(z) could be either ||z||_1, or ||z||_2^2. Default closed form solution for NORM1 would be soft-threshold, which is \n\
 12:                       natively supported in admm.c with -tao_admm_regularizer_type soft-threshold. Or user can use regular TAO solver for  \n\
 13:                       either NORM1 or NORM2 or TAOSHELL, with -reg {1,2,3} \n\
 14:                       Then, we augment both f and g, and solve it via ADMM. \n\
 15:                       D is the M*N transform matrix so that D*x is sparse. \n";

 17: typedef struct {
 18:   PetscInt  M, N, K, reg;
 19:   PetscReal lambda, eps, mumin;
 20:   Mat       A, ATA, H, Hx, D, Hz, DTD, HF;
 21:   Vec       c, xlb, xub, x, b, workM, workN, workN2, workN3, xGT; /* observation b, ground truth xGT, the lower bound and upper bound of x*/
 22: } AppCtx;

 24: /*------------------------------------------------------------*/

 26: PetscErrorCode NullJacobian(Tao tao, Vec X, Mat J, Mat Jpre, void *ptr)
 27: {
 28:   return 0;
 29: }

 31: /*------------------------------------------------------------*/

 33: static PetscErrorCode TaoShellSolve_SoftThreshold(Tao tao)
 34: {
 35:   PetscReal lambda, mu;
 36:   AppCtx   *user;
 37:   Vec       out, work, y, x;
 38:   Tao       admm_tao, misfit;

 40:   user = NULL;
 41:   mu   = 0;
 42:   TaoGetADMMParentTao(tao, &admm_tao);
 43:   TaoADMMGetMisfitSubsolver(admm_tao, &misfit);
 44:   TaoADMMGetSpectralPenalty(admm_tao, &mu);
 45:   TaoShellGetContext(tao, &user);

 47:   lambda = user->lambda;
 48:   work   = user->workN;
 49:   TaoGetSolution(tao, &out);
 50:   TaoGetSolution(misfit, &x);
 51:   TaoADMMGetDualVector(admm_tao, &y);

 53:   /* Dx + y/mu */
 54:   MatMult(user->D, x, work);
 55:   VecAXPY(work, 1 / mu, y);

 57:   /* soft thresholding */
 58:   TaoSoftThreshold(work, -lambda / mu, lambda / mu, out);
 59:   return 0;
 60: }

 62: /*------------------------------------------------------------*/

 64: PetscErrorCode MisfitObjectiveAndGradient(Tao tao, Vec X, PetscReal *f, Vec g, void *ptr)
 65: {
 66:   AppCtx *user = (AppCtx *)ptr;

 68:   /* Objective  0.5*||Ax-b||_2^2 */
 69:   MatMult(user->A, X, user->workM);
 70:   VecAXPY(user->workM, -1, user->b);
 71:   VecDot(user->workM, user->workM, f);
 72:   *f *= 0.5;
 73:   /* Gradient. ATAx-ATb */
 74:   MatMult(user->ATA, X, user->workN);
 75:   MatMultTranspose(user->A, user->b, user->workN2);
 76:   VecWAXPY(g, -1., user->workN2, user->workN);
 77:   return 0;
 78: }

 80: /*------------------------------------------------------------*/

 82: PetscErrorCode RegularizerObjectiveAndGradient1(Tao tao, Vec X, PetscReal *f_reg, Vec G_reg, void *ptr)
 83: {
 84:   AppCtx *user = (AppCtx *)ptr;

 86:   /* compute regularizer objective
 87:    * f = f + lambda*sum(sqrt(y.^2+epsilon^2) - epsilon), where y = D*x */
 88:   VecCopy(X, user->workN2);
 89:   VecPow(user->workN2, 2.);
 90:   VecShift(user->workN2, user->eps * user->eps);
 91:   VecSqrtAbs(user->workN2);
 92:   VecCopy(user->workN2, user->workN3);
 93:   VecShift(user->workN2, -user->eps);
 94:   VecSum(user->workN2, f_reg);
 95:   *f_reg *= user->lambda;
 96:   /* compute regularizer gradient = lambda*x */
 97:   VecPointwiseDivide(G_reg, X, user->workN3);
 98:   VecScale(G_reg, user->lambda);
 99:   return 0;
100: }

102: /*------------------------------------------------------------*/

104: PetscErrorCode RegularizerObjectiveAndGradient2(Tao tao, Vec X, PetscReal *f_reg, Vec G_reg, void *ptr)
105: {
106:   AppCtx   *user = (AppCtx *)ptr;
107:   PetscReal temp;

109:   /* compute regularizer objective = lambda*|z|_2^2 */
110:   VecDot(X, X, &temp);
111:   *f_reg = 0.5 * user->lambda * temp;
112:   /* compute regularizer gradient = lambda*z */
113:   VecCopy(X, G_reg);
114:   VecScale(G_reg, user->lambda);
115:   return 0;
116: }

118: /*------------------------------------------------------------*/

120: static PetscErrorCode HessianMisfit(Tao tao, Vec x, Mat H, Mat Hpre, void *ptr)
121: {
122:   return 0;
123: }

125: /*------------------------------------------------------------*/

127: static PetscErrorCode HessianReg(Tao tao, Vec x, Mat H, Mat Hpre, void *ptr)
128: {
129:   AppCtx *user = (AppCtx *)ptr;

131:   MatMult(user->D, x, user->workN);
132:   VecPow(user->workN2, 2.);
133:   VecShift(user->workN2, user->eps * user->eps);
134:   VecSqrtAbs(user->workN2);
135:   VecShift(user->workN2, -user->eps);
136:   VecReciprocal(user->workN2);
137:   VecScale(user->workN2, user->eps * user->eps);
138:   MatDiagonalSet(H, user->workN2, INSERT_VALUES);
139:   return 0;
140: }

142: /*------------------------------------------------------------*/

144: PetscErrorCode FullObjGrad(Tao tao, Vec X, PetscReal *f, Vec g, void *ptr)
145: {
146:   AppCtx   *user = (AppCtx *)ptr;
147:   PetscReal f_reg;

149:   /* Objective  0.5*||Ax-b||_2^2 + lambda*||x||_2^2*/
150:   MatMult(user->A, X, user->workM);
151:   VecAXPY(user->workM, -1, user->b);
152:   VecDot(user->workM, user->workM, f);
153:   VecNorm(X, NORM_2, &f_reg);
154:   *f *= 0.5;
155:   *f += user->lambda * f_reg * f_reg;
156:   /* Gradient. ATAx-ATb + 2*lambda*x */
157:   MatMult(user->ATA, X, user->workN);
158:   MatMultTranspose(user->A, user->b, user->workN2);
159:   VecWAXPY(g, -1., user->workN2, user->workN);
160:   VecAXPY(g, 2 * user->lambda, X);
161:   return 0;
162: }
163: /*------------------------------------------------------------*/

165: static PetscErrorCode HessianFull(Tao tao, Vec x, Mat H, Mat Hpre, void *ptr)
166: {
167:   return 0;
168: }
169: /*------------------------------------------------------------*/

171: PetscErrorCode InitializeUserData(AppCtx *user)
172: {
173:   char        dataFile[] = "tomographyData_A_b_xGT"; /* Matrix A and vectors b, xGT(ground truth) binary files generated by Matlab. Debug: change from "tomographyData_A_b_xGT" to "cs1Data_A_b_xGT". */
174:   PetscViewer fd;                                    /* used to load data from file */
175:   PetscInt    k, n;
176:   PetscScalar v;

178:   /* Load the A matrix, b vector, and xGT vector from a binary file. */
179:   PetscViewerBinaryOpen(PETSC_COMM_WORLD, dataFile, FILE_MODE_READ, &fd);
180:   MatCreate(PETSC_COMM_WORLD, &user->A);
181:   MatSetType(user->A, MATAIJ);
182:   MatLoad(user->A, fd);
183:   VecCreate(PETSC_COMM_WORLD, &user->b);
184:   VecLoad(user->b, fd);
185:   VecCreate(PETSC_COMM_WORLD, &user->xGT);
186:   VecLoad(user->xGT, fd);
187:   PetscViewerDestroy(&fd);

189:   MatGetSize(user->A, &user->M, &user->N);

191:   MatCreate(PETSC_COMM_WORLD, &user->D);
192:   MatSetSizes(user->D, PETSC_DECIDE, PETSC_DECIDE, user->N, user->N);
193:   MatSetFromOptions(user->D);
194:   MatSetUp(user->D);
195:   for (k = 0; k < user->N; k++) {
196:     v = 1.0;
197:     n = k + 1;
198:     if (k < user->N - 1) MatSetValues(user->D, 1, &k, 1, &n, &v, INSERT_VALUES);
199:     v = -1.0;
200:     MatSetValues(user->D, 1, &k, 1, &k, &v, INSERT_VALUES);
201:   }
202:   MatAssemblyBegin(user->D, MAT_FINAL_ASSEMBLY);
203:   MatAssemblyEnd(user->D, MAT_FINAL_ASSEMBLY);

205:   MatTransposeMatMult(user->D, user->D, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &user->DTD);

207:   MatCreate(PETSC_COMM_WORLD, &user->Hz);
208:   MatSetSizes(user->Hz, PETSC_DECIDE, PETSC_DECIDE, user->N, user->N);
209:   MatSetFromOptions(user->Hz);
210:   MatSetUp(user->Hz);
211:   MatAssemblyBegin(user->Hz, MAT_FINAL_ASSEMBLY);
212:   MatAssemblyEnd(user->Hz, MAT_FINAL_ASSEMBLY);

214:   VecCreate(PETSC_COMM_WORLD, &(user->x));
215:   VecCreate(PETSC_COMM_WORLD, &(user->workM));
216:   VecCreate(PETSC_COMM_WORLD, &(user->workN));
217:   VecCreate(PETSC_COMM_WORLD, &(user->workN2));
218:   VecSetSizes(user->x, PETSC_DECIDE, user->N);
219:   VecSetSizes(user->workM, PETSC_DECIDE, user->M);
220:   VecSetSizes(user->workN, PETSC_DECIDE, user->N);
221:   VecSetSizes(user->workN2, PETSC_DECIDE, user->N);
222:   VecSetFromOptions(user->x);
223:   VecSetFromOptions(user->workM);
224:   VecSetFromOptions(user->workN);
225:   VecSetFromOptions(user->workN2);

227:   VecDuplicate(user->workN, &(user->workN3));
228:   VecDuplicate(user->x, &(user->xlb));
229:   VecDuplicate(user->x, &(user->xub));
230:   VecDuplicate(user->x, &(user->c));
231:   VecSet(user->xlb, 0.0);
232:   VecSet(user->c, 0.0);
233:   VecSet(user->xub, PETSC_INFINITY);

235:   MatTransposeMatMult(user->A, user->A, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &(user->ATA));
236:   MatTransposeMatMult(user->A, user->A, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &(user->Hx));
237:   MatTransposeMatMult(user->A, user->A, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &(user->HF));

239:   MatAssemblyBegin(user->ATA, MAT_FINAL_ASSEMBLY);
240:   MatAssemblyEnd(user->ATA, MAT_FINAL_ASSEMBLY);
241:   MatAssemblyBegin(user->Hx, MAT_FINAL_ASSEMBLY);
242:   MatAssemblyEnd(user->Hx, MAT_FINAL_ASSEMBLY);
243:   MatAssemblyBegin(user->HF, MAT_FINAL_ASSEMBLY);
244:   MatAssemblyEnd(user->HF, MAT_FINAL_ASSEMBLY);

246:   user->lambda = 1.e-8;
247:   user->eps    = 1.e-3;
248:   user->reg    = 2;
249:   user->mumin  = 5.e-6;

251:   PetscOptionsBegin(PETSC_COMM_WORLD, NULL, "Configure separable objection example", "tomographyADMM.c");
252:   PetscOptionsInt("-reg", "Regularization scheme for z solver (1,2)", "tomographyADMM.c", user->reg, &(user->reg), NULL);
253:   PetscOptionsReal("-lambda", "The regularization multiplier. 1 default", "tomographyADMM.c", user->lambda, &(user->lambda), NULL);
254:   PetscOptionsReal("-eps", "L1 norm epsilon padding", "tomographyADMM.c", user->eps, &(user->eps), NULL);
255:   PetscOptionsReal("-mumin", "Minimum value for ADMM spectral penalty", "tomographyADMM.c", user->mumin, &(user->mumin), NULL);
256:   PetscOptionsEnd();
257:   return 0;
258: }

260: /*------------------------------------------------------------*/

262: PetscErrorCode DestroyContext(AppCtx *user)
263: {
264:   MatDestroy(&user->A);
265:   MatDestroy(&user->ATA);
266:   MatDestroy(&user->Hx);
267:   MatDestroy(&user->Hz);
268:   MatDestroy(&user->HF);
269:   MatDestroy(&user->D);
270:   MatDestroy(&user->DTD);
271:   VecDestroy(&user->xGT);
272:   VecDestroy(&user->xlb);
273:   VecDestroy(&user->xub);
274:   VecDestroy(&user->b);
275:   VecDestroy(&user->x);
276:   VecDestroy(&user->c);
277:   VecDestroy(&user->workN3);
278:   VecDestroy(&user->workN2);
279:   VecDestroy(&user->workN);
280:   VecDestroy(&user->workM);
281:   return 0;
282: }

284: /*------------------------------------------------------------*/

286: int main(int argc, char **argv)
287: {
288:   Tao         tao, misfit, reg;
289:   PetscReal   v1, v2;
290:   AppCtx     *user;
291:   PetscViewer fd;
292:   char        resultFile[] = "tomographyResult_x";

295:   PetscInitialize(&argc, &argv, (char *)0, help);
296:   PetscNew(&user);
297:   InitializeUserData(user);

299:   TaoCreate(PETSC_COMM_WORLD, &tao);
300:   TaoSetType(tao, TAOADMM);
301:   TaoSetSolution(tao, user->x);
302:   /* f(x) + g(x) for parent tao */
303:   TaoADMMSetSpectralPenalty(tao, 1.);
304:   TaoSetObjectiveAndGradient(tao, NULL, FullObjGrad, (void *)user);
305:   MatShift(user->HF, user->lambda);
306:   TaoSetHessian(tao, user->HF, user->HF, HessianFull, (void *)user);

308:   /* f(x) for misfit tao */
309:   TaoADMMSetMisfitObjectiveAndGradientRoutine(tao, MisfitObjectiveAndGradient, (void *)user);
310:   TaoADMMSetMisfitHessianRoutine(tao, user->Hx, user->Hx, HessianMisfit, (void *)user);
311:   TaoADMMSetMisfitHessianChangeStatus(tao, PETSC_FALSE);
312:   TaoADMMSetMisfitConstraintJacobian(tao, user->D, user->D, NullJacobian, (void *)user);

314:   /* g(x) for regularizer tao */
315:   if (user->reg == 1) {
316:     TaoADMMSetRegularizerObjectiveAndGradientRoutine(tao, RegularizerObjectiveAndGradient1, (void *)user);
317:     TaoADMMSetRegularizerHessianRoutine(tao, user->Hz, user->Hz, HessianReg, (void *)user);
318:     TaoADMMSetRegHessianChangeStatus(tao, PETSC_TRUE);
319:   } else if (user->reg == 2) {
320:     TaoADMMSetRegularizerObjectiveAndGradientRoutine(tao, RegularizerObjectiveAndGradient2, (void *)user);
321:     MatShift(user->Hz, 1);
322:     MatScale(user->Hz, user->lambda);
323:     TaoADMMSetRegularizerHessianRoutine(tao, user->Hz, user->Hz, HessianMisfit, (void *)user);
324:     TaoADMMSetRegHessianChangeStatus(tao, PETSC_TRUE);

327:   /* Set type for the misfit solver */
328:   TaoADMMGetMisfitSubsolver(tao, &misfit);
329:   TaoADMMGetRegularizationSubsolver(tao, &reg);
330:   TaoSetType(misfit, TAONLS);
331:   if (user->reg == 3) {
332:     TaoSetType(reg, TAOSHELL);
333:     TaoShellSetContext(reg, (void *)user);
334:     TaoShellSetSolve(reg, TaoShellSolve_SoftThreshold);
335:   } else {
336:     TaoSetType(reg, TAONLS);
337:   }
338:   TaoSetVariableBounds(misfit, user->xlb, user->xub);

340:   /* Soft Thresholding solves the ADMM problem with the L1 regularizer lambda*||z||_1 and the x-z=0 constraint */
341:   TaoADMMSetRegularizerCoefficient(tao, user->lambda);
342:   TaoADMMSetRegularizerConstraintJacobian(tao, NULL, NULL, NullJacobian, (void *)user);
343:   TaoADMMSetMinimumSpectralPenalty(tao, user->mumin);

345:   TaoADMMSetConstraintVectorRHS(tao, user->c);
346:   TaoSetFromOptions(tao);
347:   TaoSolve(tao);

349:   /* Save x (reconstruction of object) vector to a binary file, which maybe read from Matlab and convert to a 2D image for comparison. */
350:   PetscViewerBinaryOpen(PETSC_COMM_WORLD, resultFile, FILE_MODE_WRITE, &fd);
351:   VecView(user->x, fd);
352:   PetscViewerDestroy(&fd);

354:   /* compute the error */
355:   VecAXPY(user->x, -1, user->xGT);
356:   VecNorm(user->x, NORM_2, &v1);
357:   VecNorm(user->xGT, NORM_2, &v2);
358:   PetscPrintf(PETSC_COMM_WORLD, "relative reconstruction error: ||x-xGT||/||xGT|| = %6.4e.\n", (double)(v1 / v2));

360:   /* Free TAO data structures */
361:   TaoDestroy(&tao);
362:   DestroyContext(user);
363:   PetscFree(user);
364:   PetscFinalize();
365:   return 0;
366: }

368: /*TEST

370:    build:
371:       requires: !complex !single !__float128 !defined(PETSC_USE_64BIT_INDICES)

373:    test:
374:       suffix: 1
375:       localrunfiles: tomographyData_A_b_xGT
376:       args:  -lambda 1.e-8 -tao_monitor -tao_type nls -tao_nls_pc_type icc

378:    test:
379:       suffix: 2
380:       localrunfiles: tomographyData_A_b_xGT
381:       args:  -reg 2 -lambda 1.e-8 -tao_admm_dual_update update_basic -tao_admm_regularizer_type regularizer_user -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8  -misfit_tao_nls_pc_type icc -misfit_tao_monitor -reg_tao_monitor

383:    test:
384:       suffix: 3
385:       localrunfiles: tomographyData_A_b_xGT
386:       args:  -lambda 1.e-8 -tao_admm_dual_update update_basic -tao_admm_regularizer_type regularizer_soft_thresh -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8 -misfit_tao_nls_pc_type icc -misfit_tao_monitor

388:    test:
389:       suffix: 4
390:       localrunfiles: tomographyData_A_b_xGT
391:       args:  -lambda 1.e-8 -tao_admm_dual_update update_adaptive -tao_admm_regularizer_type regularizer_soft_thresh -tao_max_it 20 -tao_monitor -misfit_tao_monitor -misfit_tao_nls_pc_type icc

393:    test:
394:       suffix: 5
395:       localrunfiles: tomographyData_A_b_xGT
396:       args:  -reg 2 -lambda 1.e-8 -tao_admm_dual_update update_adaptive -tao_admm_regularizer_type regularizer_user -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8 -misfit_tao_monitor -reg_tao_monitor -misfit_tao_nls_pc_type icc

398:    test:
399:       suffix: 6
400:       localrunfiles: tomographyData_A_b_xGT
401:       args:  -reg 3 -lambda 1.e-8 -tao_admm_dual_update update_adaptive -tao_admm_regularizer_type regularizer_user -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8 -misfit_tao_monitor -reg_tao_monitor -misfit_tao_nls_pc_type icc

403: TEST*/