Actual source code: submat.c


  2: #include <petsc/private/matimpl.h>

  4: typedef struct {
  5:   IS         isrow, iscol;   /* rows and columns in submatrix, only used to check consistency */
  6:   Vec        lwork, rwork;   /* work vectors inside the scatters */
  7:   Vec        lwork2, rwork2; /* work vectors inside the scatters */
  8:   VecScatter lrestrict, rprolong;
  9:   Mat        A;
 10: } Mat_SubVirtual;

 12: static PetscErrorCode MatScale_SubMatrix(Mat N, PetscScalar a)
 13: {
 14:   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;

 16:   MatScale(Na->A, a);
 17:   return 0;
 18: }

 20: static PetscErrorCode MatShift_SubMatrix(Mat N, PetscScalar a)
 21: {
 22:   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;

 24:   MatShift(Na->A, a);
 25:   return 0;
 26: }

 28: static PetscErrorCode MatDiagonalScale_SubMatrix(Mat N, Vec left, Vec right)
 29: {
 30:   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;

 32:   if (right) {
 33:     VecZeroEntries(Na->rwork);
 34:     VecScatterBegin(Na->rprolong, right, Na->rwork, INSERT_VALUES, SCATTER_FORWARD);
 35:     VecScatterEnd(Na->rprolong, right, Na->rwork, INSERT_VALUES, SCATTER_FORWARD);
 36:   }
 37:   if (left) {
 38:     VecZeroEntries(Na->lwork);
 39:     VecScatterBegin(Na->lrestrict, left, Na->lwork, INSERT_VALUES, SCATTER_REVERSE);
 40:     VecScatterEnd(Na->lrestrict, left, Na->lwork, INSERT_VALUES, SCATTER_REVERSE);
 41:   }
 42:   MatDiagonalScale(Na->A, left ? Na->lwork : NULL, right ? Na->rwork : NULL);
 43:   return 0;
 44: }

 46: static PetscErrorCode MatGetDiagonal_SubMatrix(Mat N, Vec d)
 47: {
 48:   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;

 50:   MatGetDiagonal(Na->A, Na->rwork);
 51:   VecScatterBegin(Na->rprolong, Na->rwork, d, INSERT_VALUES, SCATTER_REVERSE);
 52:   VecScatterEnd(Na->rprolong, Na->rwork, d, INSERT_VALUES, SCATTER_REVERSE);
 53:   return 0;
 54: }

 56: static PetscErrorCode MatMult_SubMatrix(Mat N, Vec x, Vec y)
 57: {
 58:   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;

 60:   VecZeroEntries(Na->rwork);
 61:   VecScatterBegin(Na->rprolong, x, Na->rwork, INSERT_VALUES, SCATTER_FORWARD);
 62:   VecScatterEnd(Na->rprolong, x, Na->rwork, INSERT_VALUES, SCATTER_FORWARD);
 63:   MatMult(Na->A, Na->rwork, Na->lwork);
 64:   VecScatterBegin(Na->lrestrict, Na->lwork, y, INSERT_VALUES, SCATTER_FORWARD);
 65:   VecScatterEnd(Na->lrestrict, Na->lwork, y, INSERT_VALUES, SCATTER_FORWARD);
 66:   return 0;
 67: }

 69: static PetscErrorCode MatMultAdd_SubMatrix(Mat N, Vec v1, Vec v2, Vec v3)
 70: {
 71:   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;

 73:   VecZeroEntries(Na->rwork);
 74:   VecScatterBegin(Na->rprolong, v1, Na->rwork, INSERT_VALUES, SCATTER_FORWARD);
 75:   VecScatterEnd(Na->rprolong, v1, Na->rwork, INSERT_VALUES, SCATTER_FORWARD);
 76:   if (v1 == v2) {
 77:     MatMultAdd(Na->A, Na->rwork, Na->rwork, Na->lwork);
 78:   } else if (v2 == v3) {
 79:     VecZeroEntries(Na->lwork);
 80:     VecScatterBegin(Na->lrestrict, v2, Na->lwork, INSERT_VALUES, SCATTER_REVERSE);
 81:     VecScatterEnd(Na->lrestrict, v2, Na->lwork, INSERT_VALUES, SCATTER_REVERSE);
 82:     MatMultAdd(Na->A, Na->rwork, Na->lwork, Na->lwork);
 83:   } else {
 84:     if (!Na->lwork2) {
 85:       VecDuplicate(Na->lwork, &Na->lwork2);
 86:     } else {
 87:       VecZeroEntries(Na->lwork2);
 88:     }
 89:     VecScatterBegin(Na->lrestrict, v2, Na->lwork2, INSERT_VALUES, SCATTER_REVERSE);
 90:     VecScatterEnd(Na->lrestrict, v2, Na->lwork2, INSERT_VALUES, SCATTER_REVERSE);
 91:     MatMultAdd(Na->A, Na->rwork, Na->lwork2, Na->lwork);
 92:   }
 93:   VecScatterBegin(Na->lrestrict, Na->lwork, v3, INSERT_VALUES, SCATTER_FORWARD);
 94:   VecScatterEnd(Na->lrestrict, Na->lwork, v3, INSERT_VALUES, SCATTER_FORWARD);
 95:   return 0;
 96: }

 98: static PetscErrorCode MatMultTranspose_SubMatrix(Mat N, Vec x, Vec y)
 99: {
100:   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;

102:   VecZeroEntries(Na->lwork);
103:   VecScatterBegin(Na->lrestrict, x, Na->lwork, INSERT_VALUES, SCATTER_REVERSE);
104:   VecScatterEnd(Na->lrestrict, x, Na->lwork, INSERT_VALUES, SCATTER_REVERSE);
105:   MatMultTranspose(Na->A, Na->lwork, Na->rwork);
106:   VecScatterBegin(Na->rprolong, Na->rwork, y, INSERT_VALUES, SCATTER_REVERSE);
107:   VecScatterEnd(Na->rprolong, Na->rwork, y, INSERT_VALUES, SCATTER_REVERSE);
108:   return 0;
109: }

111: static PetscErrorCode MatMultTransposeAdd_SubMatrix(Mat N, Vec v1, Vec v2, Vec v3)
112: {
113:   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;

115:   VecZeroEntries(Na->lwork);
116:   VecScatterBegin(Na->lrestrict, v1, Na->lwork, INSERT_VALUES, SCATTER_REVERSE);
117:   VecScatterEnd(Na->lrestrict, v1, Na->lwork, INSERT_VALUES, SCATTER_REVERSE);
118:   if (v1 == v2) {
119:     MatMultTransposeAdd(Na->A, Na->lwork, Na->lwork, Na->rwork);
120:   } else if (v2 == v3) {
121:     VecZeroEntries(Na->rwork);
122:     VecScatterBegin(Na->rprolong, v2, Na->rwork, INSERT_VALUES, SCATTER_FORWARD);
123:     VecScatterEnd(Na->rprolong, v2, Na->rwork, INSERT_VALUES, SCATTER_FORWARD);
124:     MatMultTransposeAdd(Na->A, Na->lwork, Na->rwork, Na->rwork);
125:   } else {
126:     if (!Na->rwork2) {
127:       VecDuplicate(Na->rwork, &Na->rwork2);
128:     } else {
129:       VecZeroEntries(Na->rwork2);
130:     }
131:     VecScatterBegin(Na->rprolong, v2, Na->rwork2, INSERT_VALUES, SCATTER_FORWARD);
132:     VecScatterEnd(Na->rprolong, v2, Na->rwork2, INSERT_VALUES, SCATTER_FORWARD);
133:     MatMultTransposeAdd(Na->A, Na->lwork, Na->rwork2, Na->rwork);
134:   }
135:   VecScatterBegin(Na->rprolong, Na->rwork, v3, INSERT_VALUES, SCATTER_REVERSE);
136:   VecScatterEnd(Na->rprolong, Na->rwork, v3, INSERT_VALUES, SCATTER_REVERSE);
137:   return 0;
138: }

140: static PetscErrorCode MatDestroy_SubMatrix(Mat N)
141: {
142:   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;

144:   ISDestroy(&Na->isrow);
145:   ISDestroy(&Na->iscol);
146:   VecDestroy(&Na->lwork);
147:   VecDestroy(&Na->rwork);
148:   VecDestroy(&Na->lwork2);
149:   VecDestroy(&Na->rwork2);
150:   VecScatterDestroy(&Na->lrestrict);
151:   VecScatterDestroy(&Na->rprolong);
152:   MatDestroy(&Na->A);
153:   PetscFree(N->data);
154:   return 0;
155: }

157: /*@
158:    MatCreateSubMatrixVirtual - Creates a virtual matrix `MATSUBMATRIX` that acts as a submatrix

160:    Collective

162:    Input Parameters:
163: +  A - matrix that we will extract a submatrix of
164: .  isrow - rows to be present in the submatrix
165: -  iscol - columns to be present in the submatrix

167:    Output Parameters:
168: .  newmat - new matrix

170:    Level: developer

172:    Note:
173:    Most will use `MatCreateSubMatrix()` which provides a more efficient representation if it is available.

175:    Developer Note:
176:    The `MatType` is `MATSUBMATRIX` but the routines associated have `SubMatrixVirtual` in them, the `MatType` should likely be changed

178: .seealso: `MATSUBMATRIX`, `MATLOCALREF`, `MatCreateLocalRef()`, `MatCreateSubMatrix()`, `MatSubMatrixVirtualUpdate()`
179: @*/
180: PetscErrorCode MatCreateSubMatrixVirtual(Mat A, IS isrow, IS iscol, Mat *newmat)
181: {
182:   Vec             left, right;
183:   PetscInt        m, n;
184:   Mat             N;
185:   Mat_SubVirtual *Na;

191:   *newmat = NULL;

193:   MatCreate(PetscObjectComm((PetscObject)A), &N);
194:   ISGetLocalSize(isrow, &m);
195:   ISGetLocalSize(iscol, &n);
196:   MatSetSizes(N, m, n, PETSC_DETERMINE, PETSC_DETERMINE);
197:   PetscObjectChangeTypeName((PetscObject)N, MATSUBMATRIX);

199:   PetscNew(&Na);
200:   N->data = (void *)Na;

202:   PetscObjectReference((PetscObject)isrow);
203:   PetscObjectReference((PetscObject)iscol);
204:   Na->isrow = isrow;
205:   Na->iscol = iscol;

207:   PetscFree(N->defaultvectype);
208:   PetscStrallocpy(A->defaultvectype, &N->defaultvectype);
209:   /* Do not use MatConvert directly since MatShell has a duplicate operation which does not increase
210:      the reference count of the context. This is a problem if A is already of type MATSHELL */
211:   MatConvertFrom_Shell(A, MATSHELL, MAT_INITIAL_MATRIX, &Na->A);

213:   N->ops->destroy          = MatDestroy_SubMatrix;
214:   N->ops->mult             = MatMult_SubMatrix;
215:   N->ops->multadd          = MatMultAdd_SubMatrix;
216:   N->ops->multtranspose    = MatMultTranspose_SubMatrix;
217:   N->ops->multtransposeadd = MatMultTransposeAdd_SubMatrix;
218:   N->ops->scale            = MatScale_SubMatrix;
219:   N->ops->diagonalscale    = MatDiagonalScale_SubMatrix;
220:   N->ops->shift            = MatShift_SubMatrix;
221:   N->ops->convert          = MatConvert_Shell;
222:   N->ops->getdiagonal      = MatGetDiagonal_SubMatrix;

224:   MatSetBlockSizesFromMats(N, A, A);
225:   PetscLayoutSetUp(N->rmap);
226:   PetscLayoutSetUp(N->cmap);

228:   MatCreateVecs(A, &Na->rwork, &Na->lwork);
229:   MatCreateVecs(N, &right, &left);
230:   VecScatterCreate(Na->lwork, isrow, left, NULL, &Na->lrestrict);
231:   VecScatterCreate(right, NULL, Na->rwork, iscol, &Na->rprolong);
232:   VecDestroy(&left);
233:   VecDestroy(&right);
234:   MatSetUp(N);

236:   N->assembled = PETSC_TRUE;
237:   *newmat      = N;
238:   return 0;
239: }

/*MC
   MATSUBMATRIX - "submatrix" - A matrix type that represents a virtual submatrix of a matrix

  Level: advanced

  Developer Note:
  This should be named `MATSUBMATRIXVIRTUAL`

.seealso: `Mat`, `MatCreateSubMatrixVirtual()`, `MatSubMatrixVirtualUpdate()`, `MatCreateSubMatrix()`
M*/

252: /*@
253:    MatSubMatrixVirtualUpdate - Updates a `MATSUBMATRIX` virtual submatrix

255:    Collective

257:    Input Parameters:
258: +  N - submatrix to update
259: .  A - full matrix in the submatrix
260: .  isrow - rows in the update (same as the first time the submatrix was created)
261: -  iscol - columns in the update (same as the first time the submatrix was created)

263:    Level: developer

265:    Note:
266:    Most will use `MatCreateSubMatrix()` which provides a more efficient representation if it is available.

268: .seealso: MATSUBMATRIX`, `MatCreateSubMatrixVirtual()`
269: @*/
270: PetscErrorCode MatSubMatrixVirtualUpdate(Mat N, Mat A, IS isrow, IS iscol)
271: {
272:   PetscBool       flg;
273:   Mat_SubVirtual *Na;

279:   PetscObjectTypeCompare((PetscObject)N, MATSUBMATRIX, &flg);

282:   Na = (Mat_SubVirtual *)N->data;
283:   ISEqual(isrow, Na->isrow, &flg);
285:   ISEqual(iscol, Na->iscol, &flg);

288:   PetscFree(N->defaultvectype);
289:   PetscStrallocpy(A->defaultvectype, &N->defaultvectype);
290:   MatDestroy(&Na->A);
291:   /* Do not use MatConvert directly since MatShell has a duplicate operation which does not increase
292:      the reference count of the context. This is a problem if A is already of type MATSHELL */
293:   MatConvertFrom_Shell(A, MATSHELL, MAT_INITIAL_MATRIX, &Na->A);
294:   return 0;
295: }