Actual source code: bqnk.c
1: #include <../src/tao/bound/impls/bqnk/bqnk.h>
2: #include <petscksp.h>
4: static PetscErrorCode TaoBQNKComputeHessian(Tao tao)
5: {
6: TAO_BNK *bnk = (TAO_BNK *)tao->data;
7: TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
8: PetscReal gnorm2, delta;
10: /* Alias the LMVM matrix into the TAO hessian */
11: if (tao->hessian) MatDestroy(&tao->hessian);
12: if (tao->hessian_pre) MatDestroy(&tao->hessian_pre);
13: PetscObjectReference((PetscObject)bqnk->B);
14: tao->hessian = bqnk->B;
15: PetscObjectReference((PetscObject)bqnk->B);
16: tao->hessian_pre = bqnk->B;
17: /* Update the Hessian with the latest solution */
18: if (bqnk->is_spd) {
19: gnorm2 = bnk->gnorm * bnk->gnorm;
20: if (gnorm2 == 0.0) gnorm2 = PETSC_MACHINE_EPSILON;
21: if (bnk->f == 0.0) {
22: delta = 2.0 / gnorm2;
23: } else {
24: delta = 2.0 * PetscAbsScalar(bnk->f) / gnorm2;
25: }
26: MatLMVMSymBroydenSetDelta(bqnk->B, delta);
27: }
28: MatLMVMUpdate(tao->hessian, tao->solution, bnk->unprojected_gradient);
29: MatLMVMResetShift(tao->hessian);
30: /* Prepare the reduced sub-matrices for the inactive set */
31: MatDestroy(&bnk->H_inactive);
32: if (bnk->active_idx) {
33: MatCreateSubMatrixVirtual(tao->hessian, bnk->inactive_idx, bnk->inactive_idx, &bnk->H_inactive);
34: PCLMVMSetIS(bqnk->pc, bnk->inactive_idx);
35: } else {
36: PetscObjectReference((PetscObject)tao->hessian);
37: bnk->H_inactive = tao->hessian;
38: PCLMVMClearIS(bqnk->pc);
39: }
40: MatDestroy(&bnk->Hpre_inactive);
41: PetscObjectReference((PetscObject)bnk->H_inactive);
42: bnk->Hpre_inactive = bnk->H_inactive;
43: return 0;
44: }
46: static PetscErrorCode TaoBQNKComputeStep(Tao tao, PetscBool shift, KSPConvergedReason *ksp_reason, PetscInt *step_type)
47: {
48: TAO_BNK *bnk = (TAO_BNK *)tao->data;
49: TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
51: TaoBNKComputeStep(tao, shift, ksp_reason, step_type);
52: if (*ksp_reason < 0) {
53: /* Krylov solver failed to converge so reset the LMVM matrix */
54: MatLMVMReset(bqnk->B, PETSC_FALSE);
55: MatLMVMUpdate(bqnk->B, tao->solution, bnk->unprojected_gradient);
56: }
57: return 0;
58: }
60: PetscErrorCode TaoSolve_BQNK(Tao tao)
61: {
62: TAO_BNK *bnk = (TAO_BNK *)tao->data;
63: TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
64: Mat_LMVM *lmvm = (Mat_LMVM *)bqnk->B->data;
65: Mat_LMVM *J0;
66: Mat_SymBrdn *diag_ctx;
67: PetscBool flg = PETSC_FALSE;
69: if (!tao->recycle) {
70: MatLMVMReset(bqnk->B, PETSC_FALSE);
71: lmvm->nresets = 0;
72: if (lmvm->J0) {
73: PetscObjectBaseTypeCompare((PetscObject)lmvm->J0, MATLMVM, &flg);
74: if (flg) {
75: J0 = (Mat_LMVM *)lmvm->J0->data;
76: J0->nresets = 0;
77: }
78: }
79: flg = PETSC_FALSE;
80: PetscObjectTypeCompareAny((PetscObject)bqnk->B, &flg, MATLMVMSYMBROYDEN, MATLMVMSYMBADBROYDEN, MATLMVMBFGS, MATLMVMDFP, "");
81: if (flg) {
82: diag_ctx = (Mat_SymBrdn *)lmvm->ctx;
83: J0 = (Mat_LMVM *)diag_ctx->D->data;
84: J0->nresets = 0;
85: }
86: }
87: (*bqnk->solve)(tao);
88: return 0;
89: }
91: PetscErrorCode TaoSetUp_BQNK(Tao tao)
92: {
93: TAO_BNK *bnk = (TAO_BNK *)tao->data;
94: TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
95: PetscInt n, N;
96: PetscBool is_lmvm, is_set, is_sym;
98: TaoSetUp_BNK(tao);
99: VecGetLocalSize(tao->solution, &n);
100: VecGetSize(tao->solution, &N);
101: MatSetSizes(bqnk->B, n, n, N, N);
102: MatLMVMAllocate(bqnk->B, tao->solution, bnk->unprojected_gradient);
103: PetscObjectBaseTypeCompare((PetscObject)bqnk->B, MATLMVM, &is_lmvm);
105: MatIsSymmetricKnown(bqnk->B, &is_set, &is_sym);
107: KSPGetPC(tao->ksp, &bqnk->pc);
108: PCSetType(bqnk->pc, PCLMVM);
109: PCLMVMSetMatLMVM(bqnk->pc, bqnk->B);
110: return 0;
111: }
113: static PetscErrorCode TaoSetFromOptions_BQNK(Tao tao, PetscOptionItems *PetscOptionsObject)
114: {
115: TAO_BNK *bnk = (TAO_BNK *)tao->data;
116: TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
117: PetscBool is_set;
119: TaoSetFromOptions_BNK(tao, PetscOptionsObject);
120: if (bnk->init_type == BNK_INIT_INTERPOLATION) bnk->init_type = BNK_INIT_DIRECTION;
121: MatSetOptionsPrefix(bqnk->B, ((PetscObject)tao)->prefix);
122: MatAppendOptionsPrefix(bqnk->B, "tao_bqnk_");
123: MatSetFromOptions(bqnk->B);
124: MatIsSPDKnown(bqnk->B, &is_set, &bqnk->is_spd);
125: if (!is_set) bqnk->is_spd = PETSC_FALSE;
126: return 0;
127: }
129: static PetscErrorCode TaoView_BQNK(Tao tao, PetscViewer viewer)
130: {
131: TAO_BNK *bnk = (TAO_BNK *)tao->data;
132: TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
133: PetscBool isascii;
135: TaoView_BNK(tao, viewer);
136: PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &isascii);
137: if (isascii) {
138: PetscViewerPushFormat(viewer, PETSC_VIEWER_ASCII_INFO);
139: MatView(bqnk->B, viewer);
140: PetscViewerPopFormat(viewer);
141: }
142: return 0;
143: }
145: static PetscErrorCode TaoDestroy_BQNK(Tao tao)
146: {
147: TAO_BNK *bnk = (TAO_BNK *)tao->data;
148: TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
150: MatDestroy(&bnk->Hpre_inactive);
151: MatDestroy(&bnk->H_inactive);
152: MatDestroy(&bqnk->B);
153: PetscFree(bnk->ctx);
154: TaoDestroy_BNK(tao);
155: return 0;
156: }
158: PETSC_INTERN PetscErrorCode TaoCreate_BQNK(Tao tao)
159: {
160: TAO_BNK *bnk;
161: TAO_BQNK *bqnk;
163: TaoCreate_BNK(tao);
164: tao->ops->solve = TaoSolve_BQNK;
165: tao->ops->setfromoptions = TaoSetFromOptions_BQNK;
166: tao->ops->destroy = TaoDestroy_BQNK;
167: tao->ops->view = TaoView_BQNK;
168: tao->ops->setup = TaoSetUp_BQNK;
170: bnk = (TAO_BNK *)tao->data;
171: bnk->computehessian = TaoBQNKComputeHessian;
172: bnk->computestep = TaoBQNKComputeStep;
173: bnk->init_type = BNK_INIT_DIRECTION;
175: PetscNew(&bqnk);
176: bnk->ctx = (void *)bqnk;
177: bqnk->is_spd = PETSC_TRUE;
179: MatCreate(PetscObjectComm((PetscObject)tao), &bqnk->B);
180: PetscObjectIncrementTabLevel((PetscObject)bqnk->B, (PetscObject)tao, 1);
181: MatSetType(bqnk->B, MATLMVMSR1);
182: return 0;
183: }
185: /*@
186: TaoGetLMVMMatrix - Returns a pointer to the internal LMVM matrix. Valid
187: only for quasi-Newton family of methods.
189: Input Parameters:
190: . tao - Tao solver context
192: Output Parameters:
193: . B - LMVM matrix
195: Level: advanced
197: .seealso: `TAOBQNLS`, `TAOBQNKLS`, `TAOBQNKTL`, `TAOBQNKTR`, `MATLMVM`, `TaoSetLMVMMatrix()`
198: @*/
199: PetscErrorCode TaoGetLMVMMatrix(Tao tao, Mat *B)
200: {
201: TAO_BNK *bnk = (TAO_BNK *)tao->data;
202: TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
203: PetscBool flg = PETSC_FALSE;
205: PetscObjectTypeCompareAny((PetscObject)tao, &flg, TAOBQNLS, TAOBQNKLS, TAOBQNKTR, TAOBQNKTL, "");
207: *B = bqnk->B;
208: return 0;
209: }
211: /*@
212: TaoSetLMVMMatrix - Sets an external LMVM matrix into the Tao solver. Valid
213: only for quasi-Newton family of methods.
215: QN family of methods create their own LMVM matrices and users who wish to
216: manipulate this matrix should use TaoGetLMVMMatrix() instead.
218: Input Parameters:
219: + tao - Tao solver context
220: - B - LMVM matrix
222: Level: advanced
224: .seealso: `TAOBQNLS`, `TAOBQNKLS`, `TAOBQNKTL`, `TAOBQNKTR`, `MATLMVM`, `TaoGetLMVMMatrix()`
225: @*/
226: PetscErrorCode TaoSetLMVMMatrix(Tao tao, Mat B)
227: {
228: TAO_BNK *bnk = (TAO_BNK *)tao->data;
229: TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
230: PetscBool flg = PETSC_FALSE;
232: PetscObjectTypeCompareAny((PetscObject)tao, &flg, TAOBQNLS, TAOBQNKLS, TAOBQNKTR, TAOBQNKTL, "");
234: PetscObjectBaseTypeCompare((PetscObject)B, MATLMVM, &flg);
236: if (bqnk->B) MatDestroy(&bqnk->B);
237: PetscObjectReference((PetscObject)B);
238: bqnk->B = B;
239: return 0;
240: }