Actual source code: snespatch.c
1: /*
2: Defines a SNES that can consist of a collection of SNESes on patches of the domain
3: */
4: #include <petsc/private/vecimpl.h>
5: #include <petsc/private/snesimpl.h>
6: #include <petsc/private/pcpatchimpl.h>
7: #include <petscsf.h>
8: #include <petscsection.h>
/* Private data of the SNESPATCH solver. All patch work (patch construction,
   scatters, and the per-patch SNES solvers) is delegated to a PCPATCH object. */
typedef struct {
  PC pc; /* The PCPATCH instance; in nonlinear mode it owns one sub-SNES per patch */
} SNES_Patch;
14: static PetscErrorCode SNESPatchComputeResidual_Private(SNES snes, Vec x, Vec F, void *ctx)
15: {
16: PC pc = (PC)ctx;
17: PC_PATCH *pcpatch = (PC_PATCH *)pc->data;
18: PetscInt pt, size, i;
19: const PetscInt *indices;
20: const PetscScalar *X;
21: PetscScalar *XWithAll;
24: /* scatter from x to patch->patchStateWithAll[pt] */
25: pt = pcpatch->currentPatch;
26: ISGetSize(pcpatch->dofMappingWithoutToWithAll[pt], &size);
28: ISGetIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices);
29: VecGetArrayRead(x, &X);
30: VecGetArray(pcpatch->patchStateWithAll, &XWithAll);
32: for (i = 0; i < size; ++i) XWithAll[indices[i]] = X[i];
34: VecRestoreArray(pcpatch->patchStateWithAll, &XWithAll);
35: VecRestoreArrayRead(x, &X);
36: ISRestoreIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices);
38: PCPatchComputeFunction_Internal(pc, pcpatch->patchStateWithAll, F, pt);
39: return 0;
40: }
42: static PetscErrorCode SNESPatchComputeJacobian_Private(SNES snes, Vec x, Mat J, Mat M, void *ctx)
43: {
44: PC pc = (PC)ctx;
45: PC_PATCH *pcpatch = (PC_PATCH *)pc->data;
46: PetscInt pt, size, i;
47: const PetscInt *indices;
48: const PetscScalar *X;
49: PetscScalar *XWithAll;
51: /* scatter from x to patch->patchStateWithAll[pt] */
52: pt = pcpatch->currentPatch;
53: ISGetSize(pcpatch->dofMappingWithoutToWithAll[pt], &size);
55: ISGetIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices);
56: VecGetArrayRead(x, &X);
57: VecGetArray(pcpatch->patchStateWithAll, &XWithAll);
59: for (i = 0; i < size; ++i) XWithAll[indices[i]] = X[i];
61: VecRestoreArray(pcpatch->patchStateWithAll, &XWithAll);
62: VecRestoreArrayRead(x, &X);
63: ISRestoreIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices);
65: PCPatchComputeOperator_Internal(pc, pcpatch->patchStateWithAll, M, pcpatch->currentPatch, PETSC_FALSE);
66: return 0;
67: }
69: static PetscErrorCode PCSetUp_PATCH_Nonlinear(PC pc)
70: {
71: PC_PATCH *patch = (PC_PATCH *)pc->data;
72: const char *prefix;
73: PetscInt i, pStart, dof, maxDof = -1;
75: if (!pc->setupcalled) {
76: PetscMalloc1(patch->npatch, &patch->solver);
77: PCGetOptionsPrefix(pc, &prefix);
78: PetscSectionGetChart(patch->gtolCounts, &pStart, NULL);
79: for (i = 0; i < patch->npatch; ++i) {
80: SNES snes;
82: SNESCreate(PETSC_COMM_SELF, &snes);
83: SNESSetOptionsPrefix(snes, prefix);
84: SNESAppendOptionsPrefix(snes, "sub_");
85: PetscObjectIncrementTabLevel((PetscObject)snes, (PetscObject)pc, 2);
86: patch->solver[i] = (PetscObject)snes;
88: PetscSectionGetDof(patch->gtolCountsWithAll, i + pStart, &dof);
89: maxDof = PetscMax(maxDof, dof);
90: }
91: VecDuplicate(patch->localUpdate, &patch->localState);
92: VecDuplicate(patch->patchRHS, &patch->patchResidual);
93: VecDuplicate(patch->patchUpdate, &patch->patchState);
95: VecCreateSeq(PETSC_COMM_SELF, maxDof, &patch->patchStateWithAll);
96: VecSetUp(patch->patchStateWithAll);
97: }
98: for (i = 0; i < patch->npatch; ++i) {
99: SNES snes = (SNES)patch->solver[i];
101: SNESSetFunction(snes, patch->patchResidual, SNESPatchComputeResidual_Private, pc);
102: SNESSetJacobian(snes, patch->mat[i], patch->mat[i], SNESPatchComputeJacobian_Private, pc);
103: }
104: if (!pc->setupcalled && patch->optionsSet)
105: for (i = 0; i < patch->npatch; ++i) SNESSetFromOptions((SNES)patch->solver[i]);
106: return 0;
107: }
109: static PetscErrorCode PCApply_PATCH_Nonlinear(PC pc, PetscInt i, Vec patchRHS, Vec patchUpdate)
110: {
111: PC_PATCH *patch = (PC_PATCH *)pc->data;
112: PetscInt pStart, n;
114: patch->currentPatch = i;
115: PetscLogEventBegin(PC_Patch_Solve, pc, 0, 0, 0);
117: /* Scatter the overlapped global state to our patch state vector */
118: PetscSectionGetChart(patch->gtolCounts, &pStart, NULL);
119: PCPatch_ScatterLocal_Private(pc, i + pStart, patch->localState, patch->patchState, INSERT_VALUES, SCATTER_FORWARD, SCATTER_INTERIOR);
120: PCPatch_ScatterLocal_Private(pc, i + pStart, patch->localState, patch->patchStateWithAll, INSERT_VALUES, SCATTER_FORWARD, SCATTER_WITHALL);
122: MatGetLocalSize(patch->mat[i], NULL, &n);
123: patch->patchState->map->n = n;
124: patch->patchState->map->N = n;
125: patchUpdate->map->n = n;
126: patchUpdate->map->N = n;
127: patchRHS->map->n = n;
128: patchRHS->map->N = n;
129: /* Set initial guess to be current state*/
130: VecCopy(patch->patchState, patchUpdate);
131: /* Solve for new state */
132: SNESSolve((SNES)patch->solver[i], patchRHS, patchUpdate);
133: /* To compute update, subtract off previous state */
134: VecAXPY(patchUpdate, -1.0, patch->patchState);
136: PetscLogEventEnd(PC_Patch_Solve, pc, 0, 0, 0);
137: return 0;
138: }
140: static PetscErrorCode PCReset_PATCH_Nonlinear(PC pc)
141: {
142: PC_PATCH *patch = (PC_PATCH *)pc->data;
143: PetscInt i;
145: if (patch->solver) {
146: for (i = 0; i < patch->npatch; ++i) SNESReset((SNES)patch->solver[i]);
147: }
149: VecDestroy(&patch->patchResidual);
150: VecDestroy(&patch->patchState);
151: VecDestroy(&patch->patchStateWithAll);
153: VecDestroy(&patch->localState);
154: return 0;
155: }
157: static PetscErrorCode PCDestroy_PATCH_Nonlinear(PC pc)
158: {
159: PC_PATCH *patch = (PC_PATCH *)pc->data;
160: PetscInt i;
162: if (patch->solver) {
163: for (i = 0; i < patch->npatch; ++i) SNESDestroy((SNES *)&patch->solver[i]);
164: PetscFree(patch->solver);
165: }
166: return 0;
167: }
169: static PetscErrorCode PCUpdateMultiplicative_PATCH_Nonlinear(PC pc, PetscInt i, PetscInt pStart)
170: {
171: PC_PATCH *patch = (PC_PATCH *)pc->data;
173: PCPatch_ScatterLocal_Private(pc, i + pStart, patch->patchUpdate, patch->localState, ADD_VALUES, SCATTER_REVERSE, SCATTER_INTERIOR);
174: return 0;
175: }
177: static PetscErrorCode SNESSetUp_Patch(SNES snes)
178: {
179: SNES_Patch *patch = (SNES_Patch *)snes->data;
180: DM dm;
181: Mat dummy;
182: Vec F;
183: PetscInt n, N;
185: SNESGetDM(snes, &dm);
186: PCSetDM(patch->pc, dm);
187: SNESGetFunction(snes, &F, NULL, NULL);
188: VecGetLocalSize(F, &n);
189: VecGetSize(F, &N);
190: MatCreateShell(PetscObjectComm((PetscObject)snes), n, n, N, N, (void *)snes, &dummy);
191: PCSetOperators(patch->pc, dummy, dummy);
192: MatDestroy(&dummy);
193: PCSetUp(patch->pc);
194: /* allocate workspace */
195: return 0;
196: }
198: static PetscErrorCode SNESReset_Patch(SNES snes)
199: {
200: SNES_Patch *patch = (SNES_Patch *)snes->data;
202: PCReset(patch->pc);
203: return 0;
204: }
206: static PetscErrorCode SNESDestroy_Patch(SNES snes)
207: {
208: SNES_Patch *patch = (SNES_Patch *)snes->data;
210: SNESReset_Patch(snes);
211: PCDestroy(&patch->pc);
212: PetscFree(snes->data);
213: return 0;
214: }
216: static PetscErrorCode SNESSetFromOptions_Patch(SNES snes, PetscOptionItems *PetscOptionsObject)
217: {
218: SNES_Patch *patch = (SNES_Patch *)snes->data;
219: const char *prefix;
221: PetscObjectGetOptionsPrefix((PetscObject)snes, &prefix);
222: PetscObjectSetOptionsPrefix((PetscObject)patch->pc, prefix);
223: PCSetFromOptions(patch->pc);
224: return 0;
225: }
227: static PetscErrorCode SNESView_Patch(SNES snes, PetscViewer viewer)
228: {
229: SNES_Patch *patch = (SNES_Patch *)snes->data;
230: PetscBool iascii;
232: PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &iascii);
233: if (iascii) PetscViewerASCIIPrintf(viewer, "SNESPATCH\n");
234: PetscViewerASCIIPushTab(viewer);
235: PCView(patch->pc, viewer);
236: PetscViewerASCIIPopTab(viewer);
237: return 0;
238: }
240: static PetscErrorCode SNESSolve_Patch(SNES snes)
241: {
242: SNES_Patch *patch = (SNES_Patch *)snes->data;
243: PC_PATCH *pcpatch = (PC_PATCH *)patch->pc->data;
244: SNESLineSearch ls;
245: Vec rhs, update, state, residual;
246: const PetscScalar *globalState = NULL;
247: PetscScalar *localState = NULL;
248: PetscInt its = 0;
249: PetscReal xnorm = 0.0, ynorm = 0.0, fnorm = 0.0;
251: SNESGetSolution(snes, &state);
252: SNESGetSolutionUpdate(snes, &update);
253: SNESGetRhs(snes, &rhs);
255: SNESGetFunction(snes, &residual, NULL, NULL);
256: SNESGetLineSearch(snes, &ls);
258: SNESSetConvergedReason(snes, SNES_CONVERGED_ITERATING);
259: VecSet(update, 0.0);
260: SNESComputeFunction(snes, state, residual);
262: VecNorm(state, NORM_2, &xnorm);
263: VecNorm(residual, NORM_2, &fnorm);
264: snes->ttol = fnorm * snes->rtol;
266: if (snes->ops->converged) {
267: PetscUseTypeMethod(snes, converged, its, xnorm, ynorm, fnorm, &snes->reason, snes->cnvP);
268: } else {
269: SNESConvergedSkip(snes, its, xnorm, ynorm, fnorm, &snes->reason, NULL);
270: }
271: SNESLogConvergenceHistory(snes, fnorm, 0); /* should we count lits from the patches? */
272: SNESMonitor(snes, its, fnorm);
274: /* The main solver loop */
275: for (its = 0; its < snes->max_its; its++) {
276: SNESSetIterationNumber(snes, its);
278: /* Scatter state vector to overlapped vector on all patches.
279: The vector pcpatch->localState is scattered to each patch
280: in PCApply_PATCH_Nonlinear. */
281: VecGetArrayRead(state, &globalState);
282: VecGetArray(pcpatch->localState, &localState);
283: PetscSFBcastBegin(pcpatch->sectionSF, MPIU_SCALAR, globalState, localState, MPI_REPLACE);
284: PetscSFBcastEnd(pcpatch->sectionSF, MPIU_SCALAR, globalState, localState, MPI_REPLACE);
285: VecRestoreArray(pcpatch->localState, &localState);
286: VecRestoreArrayRead(state, &globalState);
288: /* The looping over patches happens here */
289: PCApply(patch->pc, rhs, update);
291: /* Apply a line search. This will often be basic with
292: damping = 1/(max number of patches a dof can be in),
293: but not always */
294: VecScale(update, -1.0);
295: SNESLineSearchApply(ls, state, residual, &fnorm, update);
297: VecNorm(state, NORM_2, &xnorm);
298: VecNorm(update, NORM_2, &ynorm);
300: if (snes->ops->converged) {
301: PetscUseTypeMethod(snes, converged, its, xnorm, ynorm, fnorm, &snes->reason, snes->cnvP);
302: } else {
303: SNESConvergedSkip(snes, its, xnorm, ynorm, fnorm, &snes->reason, NULL);
304: }
305: SNESLogConvergenceHistory(snes, fnorm, 0); /* FIXME: should we count lits? */
306: SNESMonitor(snes, its, fnorm);
307: }
309: if (its == snes->max_its) SNESSetConvergedReason(snes, SNES_DIVERGED_MAX_IT);
310: return 0;
311: }
313: /*MC
314: SNESPATCH - Solve a nonlinear problem or apply a nonlinear smoother by composing together many nonlinear solvers on (often overlapping) patches
316: Level: intermediate
318: References:
319: . * - Peter R. Brune, Matthew G. Knepley, Barry F. Smith, and Xuemin Tu, "Composing Scalable Nonlinear Algebraic Solvers", SIAM Review, 57(4), 2015
321: .seealso: `SNESFAS`, `SNESCreate()`, `SNESSetType()`, `SNESType`, `SNES`, `PCPATCH`
322: M*/
323: PETSC_EXTERN PetscErrorCode SNESCreate_Patch(SNES snes)
324: {
325: SNES_Patch *patch;
326: PC_PATCH *patchpc;
327: SNESLineSearch linesearch;
329: PetscNew(&patch);
331: snes->ops->solve = SNESSolve_Patch;
332: snes->ops->setup = SNESSetUp_Patch;
333: snes->ops->reset = SNESReset_Patch;
334: snes->ops->destroy = SNESDestroy_Patch;
335: snes->ops->setfromoptions = SNESSetFromOptions_Patch;
336: snes->ops->view = SNESView_Patch;
338: SNESGetLineSearch(snes, &linesearch);
339: if (!((PetscObject)linesearch)->type_name) SNESLineSearchSetType(linesearch, SNESLINESEARCHBASIC);
340: snes->usesksp = PETSC_FALSE;
342: snes->alwayscomputesfinalresidual = PETSC_FALSE;
344: snes->data = (void *)patch;
345: PCCreate(PetscObjectComm((PetscObject)snes), &patch->pc);
346: PCSetType(patch->pc, PCPATCH);
348: patchpc = (PC_PATCH *)patch->pc->data;
349: patchpc->classname = "snes";
350: patchpc->isNonlinear = PETSC_TRUE;
352: patchpc->setupsolver = PCSetUp_PATCH_Nonlinear;
353: patchpc->applysolver = PCApply_PATCH_Nonlinear;
354: patchpc->resetsolver = PCReset_PATCH_Nonlinear;
355: patchpc->destroysolver = PCDestroy_PATCH_Nonlinear;
356: patchpc->updatemultiplicative = PCUpdateMultiplicative_PATCH_Nonlinear;
358: return 0;
359: }
361: PetscErrorCode SNESPatchSetDiscretisationInfo(SNES snes, PetscInt nsubspaces, DM *dms, PetscInt *bs, PetscInt *nodesPerCell, const PetscInt **cellNodeMap, const PetscInt *subspaceOffsets, PetscInt numGhostBcs, const PetscInt *ghostBcNodes, PetscInt numGlobalBcs, const PetscInt *globalBcNodes)
362: {
363: SNES_Patch *patch = (SNES_Patch *)snes->data;
364: DM dm;
366: SNESGetDM(snes, &dm);
368: PCSetDM(patch->pc, dm);
369: PCPatchSetDiscretisationInfo(patch->pc, nsubspaces, dms, bs, nodesPerCell, cellNodeMap, subspaceOffsets, numGhostBcs, ghostBcNodes, numGlobalBcs, globalBcNodes);
370: return 0;
371: }
373: PetscErrorCode SNESPatchSetComputeOperator(SNES snes, PetscErrorCode (*func)(PC, PetscInt, Vec, Mat, IS, PetscInt, const PetscInt *, const PetscInt *, void *), void *ctx)
374: {
375: SNES_Patch *patch = (SNES_Patch *)snes->data;
377: PCPatchSetComputeOperator(patch->pc, func, ctx);
378: return 0;
379: }
381: PetscErrorCode SNESPatchSetComputeFunction(SNES snes, PetscErrorCode (*func)(PC, PetscInt, Vec, Vec, IS, PetscInt, const PetscInt *, const PetscInt *, void *), void *ctx)
382: {
383: SNES_Patch *patch = (SNES_Patch *)snes->data;
385: PCPatchSetComputeFunction(patch->pc, func, ctx);
386: return 0;
387: }
389: PetscErrorCode SNESPatchSetConstructType(SNES snes, PCPatchConstructType ctype, PetscErrorCode (*func)(PC, PetscInt *, IS **, IS *, void *), void *ctx)
390: {
391: SNES_Patch *patch = (SNES_Patch *)snes->data;
393: PCPatchSetConstructType(patch->pc, ctype, func, ctx);
394: return 0;
395: }
397: PetscErrorCode SNESPatchSetCellNumbering(SNES snes, PetscSection cellNumbering)
398: {
399: SNES_Patch *patch = (SNES_Patch *)snes->data;
401: PCPatchSetCellNumbering(patch->pc, cellNumbering);
402: return 0;
403: }