Actual source code: tomographyADMM.c
1: #include <petsctao.h>
/*
   Description: ADMM tomography reconstruction example.
   Minimizes 0.5*||Ax-b||^2 + lambda*g(x), where g is the regularizer selected at run time.
   Reference: BRGN Tomography Example
*/
/* Help text printed by -help. Option names and values here must match the
   ones actually accepted by the code and by admm.c. */
static char help[] = "Finds the ADMM solution to the underconstrained linear model Ax = b, with a regularizer. \n\
A is an M*N real matrix (M<N), x is sparse. A good regularizer is an L1 regularizer. \n\
We first split the objective into 0.5*||Ax-b||^2, f(x), and lambda*||x||_1, g(z), where lambda is a user-specified weight. \n\
g(z) could be either ||z||_1, or ||z||_2^2. The default closed-form solution for NORM1 is soft-thresholding, which is \n\
natively supported in admm.c with -tao_admm_regularizer_type regularizer_soft_thresh. Alternatively, a regular TAO solver can be used for \n\
either NORM1 or NORM2, or a TAOSHELL soft-threshold solver, with -reg {1,2,3} \n\
Then, we augment both f and g, and solve it via ADMM. \n\
D is the N*N transform matrix so that D*x is sparse. \n";
17: typedef struct {
18: PetscInt M, N, K, reg;
19: PetscReal lambda, eps, mumin;
20: Mat A, ATA, H, Hx, D, Hz, DTD, HF;
21: Vec c, xlb, xub, x, b, workM, workN, workN2, workN3, xGT; /* observation b, ground truth xGT, the lower bound and upper bound of x*/
22: } AppCtx;
24: /*------------------------------------------------------------*/
26: PetscErrorCode NullJacobian(Tao tao, Vec X, Mat J, Mat Jpre, void *ptr)
27: {
28: return 0;
29: }
31: /*------------------------------------------------------------*/
33: static PetscErrorCode TaoShellSolve_SoftThreshold(Tao tao)
34: {
35: PetscReal lambda, mu;
36: AppCtx *user;
37: Vec out, work, y, x;
38: Tao admm_tao, misfit;
40: user = NULL;
41: mu = 0;
42: TaoGetADMMParentTao(tao, &admm_tao);
43: TaoADMMGetMisfitSubsolver(admm_tao, &misfit);
44: TaoADMMGetSpectralPenalty(admm_tao, &mu);
45: TaoShellGetContext(tao, &user);
47: lambda = user->lambda;
48: work = user->workN;
49: TaoGetSolution(tao, &out);
50: TaoGetSolution(misfit, &x);
51: TaoADMMGetDualVector(admm_tao, &y);
53: /* Dx + y/mu */
54: MatMult(user->D, x, work);
55: VecAXPY(work, 1 / mu, y);
57: /* soft thresholding */
58: TaoSoftThreshold(work, -lambda / mu, lambda / mu, out);
59: return 0;
60: }
62: /*------------------------------------------------------------*/
64: PetscErrorCode MisfitObjectiveAndGradient(Tao tao, Vec X, PetscReal *f, Vec g, void *ptr)
65: {
66: AppCtx *user = (AppCtx *)ptr;
68: /* Objective 0.5*||Ax-b||_2^2 */
69: MatMult(user->A, X, user->workM);
70: VecAXPY(user->workM, -1, user->b);
71: VecDot(user->workM, user->workM, f);
72: *f *= 0.5;
73: /* Gradient. ATAx-ATb */
74: MatMult(user->ATA, X, user->workN);
75: MatMultTranspose(user->A, user->b, user->workN2);
76: VecWAXPY(g, -1., user->workN2, user->workN);
77: return 0;
78: }
80: /*------------------------------------------------------------*/
82: PetscErrorCode RegularizerObjectiveAndGradient1(Tao tao, Vec X, PetscReal *f_reg, Vec G_reg, void *ptr)
83: {
84: AppCtx *user = (AppCtx *)ptr;
86: /* compute regularizer objective
87: * f = f + lambda*sum(sqrt(y.^2+epsilon^2) - epsilon), where y = D*x */
88: VecCopy(X, user->workN2);
89: VecPow(user->workN2, 2.);
90: VecShift(user->workN2, user->eps * user->eps);
91: VecSqrtAbs(user->workN2);
92: VecCopy(user->workN2, user->workN3);
93: VecShift(user->workN2, -user->eps);
94: VecSum(user->workN2, f_reg);
95: *f_reg *= user->lambda;
96: /* compute regularizer gradient = lambda*x */
97: VecPointwiseDivide(G_reg, X, user->workN3);
98: VecScale(G_reg, user->lambda);
99: return 0;
100: }
102: /*------------------------------------------------------------*/
104: PetscErrorCode RegularizerObjectiveAndGradient2(Tao tao, Vec X, PetscReal *f_reg, Vec G_reg, void *ptr)
105: {
106: AppCtx *user = (AppCtx *)ptr;
107: PetscReal temp;
109: /* compute regularizer objective = lambda*|z|_2^2 */
110: VecDot(X, X, &temp);
111: *f_reg = 0.5 * user->lambda * temp;
112: /* compute regularizer gradient = lambda*z */
113: VecCopy(X, G_reg);
114: VecScale(G_reg, user->lambda);
115: return 0;
116: }
118: /*------------------------------------------------------------*/
120: static PetscErrorCode HessianMisfit(Tao tao, Vec x, Mat H, Mat Hpre, void *ptr)
121: {
122: return 0;
123: }
125: /*------------------------------------------------------------*/
127: static PetscErrorCode HessianReg(Tao tao, Vec x, Mat H, Mat Hpre, void *ptr)
128: {
129: AppCtx *user = (AppCtx *)ptr;
131: MatMult(user->D, x, user->workN);
132: VecPow(user->workN2, 2.);
133: VecShift(user->workN2, user->eps * user->eps);
134: VecSqrtAbs(user->workN2);
135: VecShift(user->workN2, -user->eps);
136: VecReciprocal(user->workN2);
137: VecScale(user->workN2, user->eps * user->eps);
138: MatDiagonalSet(H, user->workN2, INSERT_VALUES);
139: return 0;
140: }
142: /*------------------------------------------------------------*/
144: PetscErrorCode FullObjGrad(Tao tao, Vec X, PetscReal *f, Vec g, void *ptr)
145: {
146: AppCtx *user = (AppCtx *)ptr;
147: PetscReal f_reg;
149: /* Objective 0.5*||Ax-b||_2^2 + lambda*||x||_2^2*/
150: MatMult(user->A, X, user->workM);
151: VecAXPY(user->workM, -1, user->b);
152: VecDot(user->workM, user->workM, f);
153: VecNorm(X, NORM_2, &f_reg);
154: *f *= 0.5;
155: *f += user->lambda * f_reg * f_reg;
156: /* Gradient. ATAx-ATb + 2*lambda*x */
157: MatMult(user->ATA, X, user->workN);
158: MatMultTranspose(user->A, user->b, user->workN2);
159: VecWAXPY(g, -1., user->workN2, user->workN);
160: VecAXPY(g, 2 * user->lambda, X);
161: return 0;
162: }
163: /*------------------------------------------------------------*/
165: static PetscErrorCode HessianFull(Tao tao, Vec x, Mat H, Mat Hpre, void *ptr)
166: {
167: return 0;
168: }
169: /*------------------------------------------------------------*/
171: PetscErrorCode InitializeUserData(AppCtx *user)
172: {
173: char dataFile[] = "tomographyData_A_b_xGT"; /* Matrix A and vectors b, xGT(ground truth) binary files generated by Matlab. Debug: change from "tomographyData_A_b_xGT" to "cs1Data_A_b_xGT". */
174: PetscViewer fd; /* used to load data from file */
175: PetscInt k, n;
176: PetscScalar v;
178: /* Load the A matrix, b vector, and xGT vector from a binary file. */
179: PetscViewerBinaryOpen(PETSC_COMM_WORLD, dataFile, FILE_MODE_READ, &fd);
180: MatCreate(PETSC_COMM_WORLD, &user->A);
181: MatSetType(user->A, MATAIJ);
182: MatLoad(user->A, fd);
183: VecCreate(PETSC_COMM_WORLD, &user->b);
184: VecLoad(user->b, fd);
185: VecCreate(PETSC_COMM_WORLD, &user->xGT);
186: VecLoad(user->xGT, fd);
187: PetscViewerDestroy(&fd);
189: MatGetSize(user->A, &user->M, &user->N);
191: MatCreate(PETSC_COMM_WORLD, &user->D);
192: MatSetSizes(user->D, PETSC_DECIDE, PETSC_DECIDE, user->N, user->N);
193: MatSetFromOptions(user->D);
194: MatSetUp(user->D);
195: for (k = 0; k < user->N; k++) {
196: v = 1.0;
197: n = k + 1;
198: if (k < user->N - 1) MatSetValues(user->D, 1, &k, 1, &n, &v, INSERT_VALUES);
199: v = -1.0;
200: MatSetValues(user->D, 1, &k, 1, &k, &v, INSERT_VALUES);
201: }
202: MatAssemblyBegin(user->D, MAT_FINAL_ASSEMBLY);
203: MatAssemblyEnd(user->D, MAT_FINAL_ASSEMBLY);
205: MatTransposeMatMult(user->D, user->D, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &user->DTD);
207: MatCreate(PETSC_COMM_WORLD, &user->Hz);
208: MatSetSizes(user->Hz, PETSC_DECIDE, PETSC_DECIDE, user->N, user->N);
209: MatSetFromOptions(user->Hz);
210: MatSetUp(user->Hz);
211: MatAssemblyBegin(user->Hz, MAT_FINAL_ASSEMBLY);
212: MatAssemblyEnd(user->Hz, MAT_FINAL_ASSEMBLY);
214: VecCreate(PETSC_COMM_WORLD, &(user->x));
215: VecCreate(PETSC_COMM_WORLD, &(user->workM));
216: VecCreate(PETSC_COMM_WORLD, &(user->workN));
217: VecCreate(PETSC_COMM_WORLD, &(user->workN2));
218: VecSetSizes(user->x, PETSC_DECIDE, user->N);
219: VecSetSizes(user->workM, PETSC_DECIDE, user->M);
220: VecSetSizes(user->workN, PETSC_DECIDE, user->N);
221: VecSetSizes(user->workN2, PETSC_DECIDE, user->N);
222: VecSetFromOptions(user->x);
223: VecSetFromOptions(user->workM);
224: VecSetFromOptions(user->workN);
225: VecSetFromOptions(user->workN2);
227: VecDuplicate(user->workN, &(user->workN3));
228: VecDuplicate(user->x, &(user->xlb));
229: VecDuplicate(user->x, &(user->xub));
230: VecDuplicate(user->x, &(user->c));
231: VecSet(user->xlb, 0.0);
232: VecSet(user->c, 0.0);
233: VecSet(user->xub, PETSC_INFINITY);
235: MatTransposeMatMult(user->A, user->A, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &(user->ATA));
236: MatTransposeMatMult(user->A, user->A, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &(user->Hx));
237: MatTransposeMatMult(user->A, user->A, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &(user->HF));
239: MatAssemblyBegin(user->ATA, MAT_FINAL_ASSEMBLY);
240: MatAssemblyEnd(user->ATA, MAT_FINAL_ASSEMBLY);
241: MatAssemblyBegin(user->Hx, MAT_FINAL_ASSEMBLY);
242: MatAssemblyEnd(user->Hx, MAT_FINAL_ASSEMBLY);
243: MatAssemblyBegin(user->HF, MAT_FINAL_ASSEMBLY);
244: MatAssemblyEnd(user->HF, MAT_FINAL_ASSEMBLY);
246: user->lambda = 1.e-8;
247: user->eps = 1.e-3;
248: user->reg = 2;
249: user->mumin = 5.e-6;
251: PetscOptionsBegin(PETSC_COMM_WORLD, NULL, "Configure separable objection example", "tomographyADMM.c");
252: PetscOptionsInt("-reg", "Regularization scheme for z solver (1,2)", "tomographyADMM.c", user->reg, &(user->reg), NULL);
253: PetscOptionsReal("-lambda", "The regularization multiplier. 1 default", "tomographyADMM.c", user->lambda, &(user->lambda), NULL);
254: PetscOptionsReal("-eps", "L1 norm epsilon padding", "tomographyADMM.c", user->eps, &(user->eps), NULL);
255: PetscOptionsReal("-mumin", "Minimum value for ADMM spectral penalty", "tomographyADMM.c", user->mumin, &(user->mumin), NULL);
256: PetscOptionsEnd();
257: return 0;
258: }
260: /*------------------------------------------------------------*/
262: PetscErrorCode DestroyContext(AppCtx *user)
263: {
264: MatDestroy(&user->A);
265: MatDestroy(&user->ATA);
266: MatDestroy(&user->Hx);
267: MatDestroy(&user->Hz);
268: MatDestroy(&user->HF);
269: MatDestroy(&user->D);
270: MatDestroy(&user->DTD);
271: VecDestroy(&user->xGT);
272: VecDestroy(&user->xlb);
273: VecDestroy(&user->xub);
274: VecDestroy(&user->b);
275: VecDestroy(&user->x);
276: VecDestroy(&user->c);
277: VecDestroy(&user->workN3);
278: VecDestroy(&user->workN2);
279: VecDestroy(&user->workN);
280: VecDestroy(&user->workM);
281: return 0;
282: }
284: /*------------------------------------------------------------*/
286: int main(int argc, char **argv)
287: {
288: Tao tao, misfit, reg;
289: PetscReal v1, v2;
290: AppCtx *user;
291: PetscViewer fd;
292: char resultFile[] = "tomographyResult_x";
295: PetscInitialize(&argc, &argv, (char *)0, help);
296: PetscNew(&user);
297: InitializeUserData(user);
299: TaoCreate(PETSC_COMM_WORLD, &tao);
300: TaoSetType(tao, TAOADMM);
301: TaoSetSolution(tao, user->x);
302: /* f(x) + g(x) for parent tao */
303: TaoADMMSetSpectralPenalty(tao, 1.);
304: TaoSetObjectiveAndGradient(tao, NULL, FullObjGrad, (void *)user);
305: MatShift(user->HF, user->lambda);
306: TaoSetHessian(tao, user->HF, user->HF, HessianFull, (void *)user);
308: /* f(x) for misfit tao */
309: TaoADMMSetMisfitObjectiveAndGradientRoutine(tao, MisfitObjectiveAndGradient, (void *)user);
310: TaoADMMSetMisfitHessianRoutine(tao, user->Hx, user->Hx, HessianMisfit, (void *)user);
311: TaoADMMSetMisfitHessianChangeStatus(tao, PETSC_FALSE);
312: TaoADMMSetMisfitConstraintJacobian(tao, user->D, user->D, NullJacobian, (void *)user);
314: /* g(x) for regularizer tao */
315: if (user->reg == 1) {
316: TaoADMMSetRegularizerObjectiveAndGradientRoutine(tao, RegularizerObjectiveAndGradient1, (void *)user);
317: TaoADMMSetRegularizerHessianRoutine(tao, user->Hz, user->Hz, HessianReg, (void *)user);
318: TaoADMMSetRegHessianChangeStatus(tao, PETSC_TRUE);
319: } else if (user->reg == 2) {
320: TaoADMMSetRegularizerObjectiveAndGradientRoutine(tao, RegularizerObjectiveAndGradient2, (void *)user);
321: MatShift(user->Hz, 1);
322: MatScale(user->Hz, user->lambda);
323: TaoADMMSetRegularizerHessianRoutine(tao, user->Hz, user->Hz, HessianMisfit, (void *)user);
324: TaoADMMSetRegHessianChangeStatus(tao, PETSC_TRUE);
327: /* Set type for the misfit solver */
328: TaoADMMGetMisfitSubsolver(tao, &misfit);
329: TaoADMMGetRegularizationSubsolver(tao, ®);
330: TaoSetType(misfit, TAONLS);
331: if (user->reg == 3) {
332: TaoSetType(reg, TAOSHELL);
333: TaoShellSetContext(reg, (void *)user);
334: TaoShellSetSolve(reg, TaoShellSolve_SoftThreshold);
335: } else {
336: TaoSetType(reg, TAONLS);
337: }
338: TaoSetVariableBounds(misfit, user->xlb, user->xub);
340: /* Soft Thresholding solves the ADMM problem with the L1 regularizer lambda*||z||_1 and the x-z=0 constraint */
341: TaoADMMSetRegularizerCoefficient(tao, user->lambda);
342: TaoADMMSetRegularizerConstraintJacobian(tao, NULL, NULL, NullJacobian, (void *)user);
343: TaoADMMSetMinimumSpectralPenalty(tao, user->mumin);
345: TaoADMMSetConstraintVectorRHS(tao, user->c);
346: TaoSetFromOptions(tao);
347: TaoSolve(tao);
349: /* Save x (reconstruction of object) vector to a binary file, which maybe read from Matlab and convert to a 2D image for comparison. */
350: PetscViewerBinaryOpen(PETSC_COMM_WORLD, resultFile, FILE_MODE_WRITE, &fd);
351: VecView(user->x, fd);
352: PetscViewerDestroy(&fd);
354: /* compute the error */
355: VecAXPY(user->x, -1, user->xGT);
356: VecNorm(user->x, NORM_2, &v1);
357: VecNorm(user->xGT, NORM_2, &v2);
358: PetscPrintf(PETSC_COMM_WORLD, "relative reconstruction error: ||x-xGT||/||xGT|| = %6.4e.\n", (double)(v1 / v2));
360: /* Free TAO data structures */
361: TaoDestroy(&tao);
362: DestroyContext(user);
363: PetscFree(user);
364: PetscFinalize();
365: return 0;
366: }
368: /*TEST
370: build:
371: requires: !complex !single !__float128 !defined(PETSC_USE_64BIT_INDICES)
373: test:
374: suffix: 1
375: localrunfiles: tomographyData_A_b_xGT
376: args: -lambda 1.e-8 -tao_monitor -tao_type nls -tao_nls_pc_type icc
378: test:
379: suffix: 2
380: localrunfiles: tomographyData_A_b_xGT
381: args: -reg 2 -lambda 1.e-8 -tao_admm_dual_update update_basic -tao_admm_regularizer_type regularizer_user -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8 -misfit_tao_nls_pc_type icc -misfit_tao_monitor -reg_tao_monitor
383: test:
384: suffix: 3
385: localrunfiles: tomographyData_A_b_xGT
386: args: -lambda 1.e-8 -tao_admm_dual_update update_basic -tao_admm_regularizer_type regularizer_soft_thresh -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8 -misfit_tao_nls_pc_type icc -misfit_tao_monitor
388: test:
389: suffix: 4
390: localrunfiles: tomographyData_A_b_xGT
391: args: -lambda 1.e-8 -tao_admm_dual_update update_adaptive -tao_admm_regularizer_type regularizer_soft_thresh -tao_max_it 20 -tao_monitor -misfit_tao_monitor -misfit_tao_nls_pc_type icc
393: test:
394: suffix: 5
395: localrunfiles: tomographyData_A_b_xGT
396: args: -reg 2 -lambda 1.e-8 -tao_admm_dual_update update_adaptive -tao_admm_regularizer_type regularizer_user -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8 -misfit_tao_monitor -reg_tao_monitor -misfit_tao_nls_pc_type icc
398: test:
399: suffix: 6
400: localrunfiles: tomographyData_A_b_xGT
401: args: -reg 3 -lambda 1.e-8 -tao_admm_dual_update update_adaptive -tao_admm_regularizer_type regularizer_user -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8 -misfit_tao_monitor -reg_tao_monitor -misfit_tao_nls_pc_type icc
403: TEST*/