Actual source code: snespatch.c

/*
   Defines SNESPATCH, a SNES that composes a collection of SNESes, one on each patch of the domain
*/
  4: #include <petsc/private/vecimpl.h>
  5: #include <petsc/private/snesimpl.h>
  6: #include <petsc/private/pcpatchimpl.h>
  7: #include <petscsf.h>
  8: #include <petscsection.h>

 10: typedef struct {
 11:   PC pc; /* The linear patch preconditioner */
 12: } SNES_Patch;

 14: static PetscErrorCode SNESPatchComputeResidual_Private(SNES snes, Vec x, Vec F, void *ctx)
 15: {
 16:   PC                 pc      = (PC)ctx;
 17:   PC_PATCH          *pcpatch = (PC_PATCH *)pc->data;
 18:   PetscInt           pt, size, i;
 19:   const PetscInt    *indices;
 20:   const PetscScalar *X;
 21:   PetscScalar       *XWithAll;


 24:   /* scatter from x to patch->patchStateWithAll[pt] */
 25:   pt = pcpatch->currentPatch;
 26:   ISGetSize(pcpatch->dofMappingWithoutToWithAll[pt], &size);

 28:   ISGetIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices);
 29:   VecGetArrayRead(x, &X);
 30:   VecGetArray(pcpatch->patchStateWithAll, &XWithAll);

 32:   for (i = 0; i < size; ++i) XWithAll[indices[i]] = X[i];

 34:   VecRestoreArray(pcpatch->patchStateWithAll, &XWithAll);
 35:   VecRestoreArrayRead(x, &X);
 36:   ISRestoreIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices);

 38:   PCPatchComputeFunction_Internal(pc, pcpatch->patchStateWithAll, F, pt);
 39:   return 0;
 40: }

 42: static PetscErrorCode SNESPatchComputeJacobian_Private(SNES snes, Vec x, Mat J, Mat M, void *ctx)
 43: {
 44:   PC                 pc      = (PC)ctx;
 45:   PC_PATCH          *pcpatch = (PC_PATCH *)pc->data;
 46:   PetscInt           pt, size, i;
 47:   const PetscInt    *indices;
 48:   const PetscScalar *X;
 49:   PetscScalar       *XWithAll;

 51:   /* scatter from x to patch->patchStateWithAll[pt] */
 52:   pt = pcpatch->currentPatch;
 53:   ISGetSize(pcpatch->dofMappingWithoutToWithAll[pt], &size);

 55:   ISGetIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices);
 56:   VecGetArrayRead(x, &X);
 57:   VecGetArray(pcpatch->patchStateWithAll, &XWithAll);

 59:   for (i = 0; i < size; ++i) XWithAll[indices[i]] = X[i];

 61:   VecRestoreArray(pcpatch->patchStateWithAll, &XWithAll);
 62:   VecRestoreArrayRead(x, &X);
 63:   ISRestoreIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices);

 65:   PCPatchComputeOperator_Internal(pc, pcpatch->patchStateWithAll, M, pcpatch->currentPatch, PETSC_FALSE);
 66:   return 0;
 67: }

 69: static PetscErrorCode PCSetUp_PATCH_Nonlinear(PC pc)
 70: {
 71:   PC_PATCH   *patch = (PC_PATCH *)pc->data;
 72:   const char *prefix;
 73:   PetscInt    i, pStart, dof, maxDof = -1;

 75:   if (!pc->setupcalled) {
 76:     PetscMalloc1(patch->npatch, &patch->solver);
 77:     PCGetOptionsPrefix(pc, &prefix);
 78:     PetscSectionGetChart(patch->gtolCounts, &pStart, NULL);
 79:     for (i = 0; i < patch->npatch; ++i) {
 80:       SNES snes;

 82:       SNESCreate(PETSC_COMM_SELF, &snes);
 83:       SNESSetOptionsPrefix(snes, prefix);
 84:       SNESAppendOptionsPrefix(snes, "sub_");
 85:       PetscObjectIncrementTabLevel((PetscObject)snes, (PetscObject)pc, 2);
 86:       patch->solver[i] = (PetscObject)snes;

 88:       PetscSectionGetDof(patch->gtolCountsWithAll, i + pStart, &dof);
 89:       maxDof = PetscMax(maxDof, dof);
 90:     }
 91:     VecDuplicate(patch->localUpdate, &patch->localState);
 92:     VecDuplicate(patch->patchRHS, &patch->patchResidual);
 93:     VecDuplicate(patch->patchUpdate, &patch->patchState);

 95:     VecCreateSeq(PETSC_COMM_SELF, maxDof, &patch->patchStateWithAll);
 96:     VecSetUp(patch->patchStateWithAll);
 97:   }
 98:   for (i = 0; i < patch->npatch; ++i) {
 99:     SNES snes = (SNES)patch->solver[i];

101:     SNESSetFunction(snes, patch->patchResidual, SNESPatchComputeResidual_Private, pc);
102:     SNESSetJacobian(snes, patch->mat[i], patch->mat[i], SNESPatchComputeJacobian_Private, pc);
103:   }
104:   if (!pc->setupcalled && patch->optionsSet)
105:     for (i = 0; i < patch->npatch; ++i) SNESSetFromOptions((SNES)patch->solver[i]);
106:   return 0;
107: }

109: static PetscErrorCode PCApply_PATCH_Nonlinear(PC pc, PetscInt i, Vec patchRHS, Vec patchUpdate)
110: {
111:   PC_PATCH *patch = (PC_PATCH *)pc->data;
112:   PetscInt  pStart, n;

114:   patch->currentPatch = i;
115:   PetscLogEventBegin(PC_Patch_Solve, pc, 0, 0, 0);

117:   /* Scatter the overlapped global state to our patch state vector */
118:   PetscSectionGetChart(patch->gtolCounts, &pStart, NULL);
119:   PCPatch_ScatterLocal_Private(pc, i + pStart, patch->localState, patch->patchState, INSERT_VALUES, SCATTER_FORWARD, SCATTER_INTERIOR);
120:   PCPatch_ScatterLocal_Private(pc, i + pStart, patch->localState, patch->patchStateWithAll, INSERT_VALUES, SCATTER_FORWARD, SCATTER_WITHALL);

122:   MatGetLocalSize(patch->mat[i], NULL, &n);
123:   patch->patchState->map->n = n;
124:   patch->patchState->map->N = n;
125:   patchUpdate->map->n       = n;
126:   patchUpdate->map->N       = n;
127:   patchRHS->map->n          = n;
128:   patchRHS->map->N          = n;
129:   /* Set initial guess to be current state*/
130:   VecCopy(patch->patchState, patchUpdate);
131:   /* Solve for new state */
132:   SNESSolve((SNES)patch->solver[i], patchRHS, patchUpdate);
133:   /* To compute update, subtract off previous state */
134:   VecAXPY(patchUpdate, -1.0, patch->patchState);

136:   PetscLogEventEnd(PC_Patch_Solve, pc, 0, 0, 0);
137:   return 0;
138: }

140: static PetscErrorCode PCReset_PATCH_Nonlinear(PC pc)
141: {
142:   PC_PATCH *patch = (PC_PATCH *)pc->data;
143:   PetscInt  i;

145:   if (patch->solver) {
146:     for (i = 0; i < patch->npatch; ++i) SNESReset((SNES)patch->solver[i]);
147:   }

149:   VecDestroy(&patch->patchResidual);
150:   VecDestroy(&patch->patchState);
151:   VecDestroy(&patch->patchStateWithAll);

153:   VecDestroy(&patch->localState);
154:   return 0;
155: }

157: static PetscErrorCode PCDestroy_PATCH_Nonlinear(PC pc)
158: {
159:   PC_PATCH *patch = (PC_PATCH *)pc->data;
160:   PetscInt  i;

162:   if (patch->solver) {
163:     for (i = 0; i < patch->npatch; ++i) SNESDestroy((SNES *)&patch->solver[i]);
164:     PetscFree(patch->solver);
165:   }
166:   return 0;
167: }

169: static PetscErrorCode PCUpdateMultiplicative_PATCH_Nonlinear(PC pc, PetscInt i, PetscInt pStart)
170: {
171:   PC_PATCH *patch = (PC_PATCH *)pc->data;

173:   PCPatch_ScatterLocal_Private(pc, i + pStart, patch->patchUpdate, patch->localState, ADD_VALUES, SCATTER_REVERSE, SCATTER_INTERIOR);
174:   return 0;
175: }

177: static PetscErrorCode SNESSetUp_Patch(SNES snes)
178: {
179:   SNES_Patch *patch = (SNES_Patch *)snes->data;
180:   DM          dm;
181:   Mat         dummy;
182:   Vec         F;
183:   PetscInt    n, N;

185:   SNESGetDM(snes, &dm);
186:   PCSetDM(patch->pc, dm);
187:   SNESGetFunction(snes, &F, NULL, NULL);
188:   VecGetLocalSize(F, &n);
189:   VecGetSize(F, &N);
190:   MatCreateShell(PetscObjectComm((PetscObject)snes), n, n, N, N, (void *)snes, &dummy);
191:   PCSetOperators(patch->pc, dummy, dummy);
192:   MatDestroy(&dummy);
193:   PCSetUp(patch->pc);
194:   /* allocate workspace */
195:   return 0;
196: }

198: static PetscErrorCode SNESReset_Patch(SNES snes)
199: {
200:   SNES_Patch *patch = (SNES_Patch *)snes->data;

202:   PCReset(patch->pc);
203:   return 0;
204: }

206: static PetscErrorCode SNESDestroy_Patch(SNES snes)
207: {
208:   SNES_Patch *patch = (SNES_Patch *)snes->data;

210:   SNESReset_Patch(snes);
211:   PCDestroy(&patch->pc);
212:   PetscFree(snes->data);
213:   return 0;
214: }

216: static PetscErrorCode SNESSetFromOptions_Patch(SNES snes, PetscOptionItems *PetscOptionsObject)
217: {
218:   SNES_Patch *patch = (SNES_Patch *)snes->data;
219:   const char *prefix;

221:   PetscObjectGetOptionsPrefix((PetscObject)snes, &prefix);
222:   PetscObjectSetOptionsPrefix((PetscObject)patch->pc, prefix);
223:   PCSetFromOptions(patch->pc);
224:   return 0;
225: }

227: static PetscErrorCode SNESView_Patch(SNES snes, PetscViewer viewer)
228: {
229:   SNES_Patch *patch = (SNES_Patch *)snes->data;
230:   PetscBool   iascii;

232:   PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &iascii);
233:   if (iascii) PetscViewerASCIIPrintf(viewer, "SNESPATCH\n");
234:   PetscViewerASCIIPushTab(viewer);
235:   PCView(patch->pc, viewer);
236:   PetscViewerASCIIPopTab(viewer);
237:   return 0;
238: }

240: static PetscErrorCode SNESSolve_Patch(SNES snes)
241: {
242:   SNES_Patch        *patch   = (SNES_Patch *)snes->data;
243:   PC_PATCH          *pcpatch = (PC_PATCH *)patch->pc->data;
244:   SNESLineSearch     ls;
245:   Vec                rhs, update, state, residual;
246:   const PetscScalar *globalState = NULL;
247:   PetscScalar       *localState  = NULL;
248:   PetscInt           its         = 0;
249:   PetscReal          xnorm = 0.0, ynorm = 0.0, fnorm = 0.0;

251:   SNESGetSolution(snes, &state);
252:   SNESGetSolutionUpdate(snes, &update);
253:   SNESGetRhs(snes, &rhs);

255:   SNESGetFunction(snes, &residual, NULL, NULL);
256:   SNESGetLineSearch(snes, &ls);

258:   SNESSetConvergedReason(snes, SNES_CONVERGED_ITERATING);
259:   VecSet(update, 0.0);
260:   SNESComputeFunction(snes, state, residual);

262:   VecNorm(state, NORM_2, &xnorm);
263:   VecNorm(residual, NORM_2, &fnorm);
264:   snes->ttol = fnorm * snes->rtol;

266:   if (snes->ops->converged) {
267:     PetscUseTypeMethod(snes, converged, its, xnorm, ynorm, fnorm, &snes->reason, snes->cnvP);
268:   } else {
269:     SNESConvergedSkip(snes, its, xnorm, ynorm, fnorm, &snes->reason, NULL);
270:   }
271:   SNESLogConvergenceHistory(snes, fnorm, 0); /* should we count lits from the patches? */
272:   SNESMonitor(snes, its, fnorm);

274:   /* The main solver loop */
275:   for (its = 0; its < snes->max_its; its++) {
276:     SNESSetIterationNumber(snes, its);

278:     /* Scatter state vector to overlapped vector on all patches.
279:        The vector pcpatch->localState is scattered to each patch
280:        in PCApply_PATCH_Nonlinear. */
281:     VecGetArrayRead(state, &globalState);
282:     VecGetArray(pcpatch->localState, &localState);
283:     PetscSFBcastBegin(pcpatch->sectionSF, MPIU_SCALAR, globalState, localState, MPI_REPLACE);
284:     PetscSFBcastEnd(pcpatch->sectionSF, MPIU_SCALAR, globalState, localState, MPI_REPLACE);
285:     VecRestoreArray(pcpatch->localState, &localState);
286:     VecRestoreArrayRead(state, &globalState);

288:     /* The looping over patches happens here */
289:     PCApply(patch->pc, rhs, update);

291:     /* Apply a line search. This will often be basic with
292:        damping = 1/(max number of patches a dof can be in),
293:        but not always */
294:     VecScale(update, -1.0);
295:     SNESLineSearchApply(ls, state, residual, &fnorm, update);

297:     VecNorm(state, NORM_2, &xnorm);
298:     VecNorm(update, NORM_2, &ynorm);

300:     if (snes->ops->converged) {
301:       PetscUseTypeMethod(snes, converged, its, xnorm, ynorm, fnorm, &snes->reason, snes->cnvP);
302:     } else {
303:       SNESConvergedSkip(snes, its, xnorm, ynorm, fnorm, &snes->reason, NULL);
304:     }
305:     SNESLogConvergenceHistory(snes, fnorm, 0); /* FIXME: should we count lits? */
306:     SNESMonitor(snes, its, fnorm);
307:   }

309:   if (its == snes->max_its) SNESSetConvergedReason(snes, SNES_DIVERGED_MAX_IT);
310:   return 0;
311: }

313: /*MC
314:   SNESPATCH - Solve a nonlinear problem or apply a nonlinear smoother by composing together many nonlinear solvers on (often overlapping) patches

316:   Level: intermediate

318:    References:
319: .  * - Peter R. Brune, Matthew G. Knepley, Barry F. Smith, and Xuemin Tu, "Composing Scalable Nonlinear Algebraic Solvers", SIAM Review, 57(4), 2015

321: .seealso: `SNESFAS`, `SNESCreate()`, `SNESSetType()`, `SNESType`, `SNES`, `PCPATCH`
322: M*/
323: PETSC_EXTERN PetscErrorCode SNESCreate_Patch(SNES snes)
324: {
325:   SNES_Patch    *patch;
326:   PC_PATCH      *patchpc;
327:   SNESLineSearch linesearch;

329:   PetscNew(&patch);

331:   snes->ops->solve          = SNESSolve_Patch;
332:   snes->ops->setup          = SNESSetUp_Patch;
333:   snes->ops->reset          = SNESReset_Patch;
334:   snes->ops->destroy        = SNESDestroy_Patch;
335:   snes->ops->setfromoptions = SNESSetFromOptions_Patch;
336:   snes->ops->view           = SNESView_Patch;

338:   SNESGetLineSearch(snes, &linesearch);
339:   if (!((PetscObject)linesearch)->type_name) SNESLineSearchSetType(linesearch, SNESLINESEARCHBASIC);
340:   snes->usesksp = PETSC_FALSE;

342:   snes->alwayscomputesfinalresidual = PETSC_FALSE;

344:   snes->data = (void *)patch;
345:   PCCreate(PetscObjectComm((PetscObject)snes), &patch->pc);
346:   PCSetType(patch->pc, PCPATCH);

348:   patchpc              = (PC_PATCH *)patch->pc->data;
349:   patchpc->classname   = "snes";
350:   patchpc->isNonlinear = PETSC_TRUE;

352:   patchpc->setupsolver          = PCSetUp_PATCH_Nonlinear;
353:   patchpc->applysolver          = PCApply_PATCH_Nonlinear;
354:   patchpc->resetsolver          = PCReset_PATCH_Nonlinear;
355:   patchpc->destroysolver        = PCDestroy_PATCH_Nonlinear;
356:   patchpc->updatemultiplicative = PCUpdateMultiplicative_PATCH_Nonlinear;

358:   return 0;
359: }

361: PetscErrorCode SNESPatchSetDiscretisationInfo(SNES snes, PetscInt nsubspaces, DM *dms, PetscInt *bs, PetscInt *nodesPerCell, const PetscInt **cellNodeMap, const PetscInt *subspaceOffsets, PetscInt numGhostBcs, const PetscInt *ghostBcNodes, PetscInt numGlobalBcs, const PetscInt *globalBcNodes)
362: {
363:   SNES_Patch *patch = (SNES_Patch *)snes->data;
364:   DM          dm;

366:   SNESGetDM(snes, &dm);
368:   PCSetDM(patch->pc, dm);
369:   PCPatchSetDiscretisationInfo(patch->pc, nsubspaces, dms, bs, nodesPerCell, cellNodeMap, subspaceOffsets, numGhostBcs, ghostBcNodes, numGlobalBcs, globalBcNodes);
370:   return 0;
371: }

373: PetscErrorCode SNESPatchSetComputeOperator(SNES snes, PetscErrorCode (*func)(PC, PetscInt, Vec, Mat, IS, PetscInt, const PetscInt *, const PetscInt *, void *), void *ctx)
374: {
375:   SNES_Patch *patch = (SNES_Patch *)snes->data;

377:   PCPatchSetComputeOperator(patch->pc, func, ctx);
378:   return 0;
379: }

381: PetscErrorCode SNESPatchSetComputeFunction(SNES snes, PetscErrorCode (*func)(PC, PetscInt, Vec, Vec, IS, PetscInt, const PetscInt *, const PetscInt *, void *), void *ctx)
382: {
383:   SNES_Patch *patch = (SNES_Patch *)snes->data;

385:   PCPatchSetComputeFunction(patch->pc, func, ctx);
386:   return 0;
387: }

389: PetscErrorCode SNESPatchSetConstructType(SNES snes, PCPatchConstructType ctype, PetscErrorCode (*func)(PC, PetscInt *, IS **, IS *, void *), void *ctx)
390: {
391:   SNES_Patch *patch = (SNES_Patch *)snes->data;

393:   PCPatchSetConstructType(patch->pc, ctype, func, ctx);
394:   return 0;
395: }

397: PetscErrorCode SNESPatchSetCellNumbering(SNES snes, PetscSection cellNumbering)
398: {
399:   SNES_Patch *patch = (SNES_Patch *)snes->data;

401:   PCPatchSetCellNumbering(patch->pc, cellNumbering);
402:   return 0;
403: }