Actual source code: submat.c
2: #include <petsc/private/matimpl.h>
4: typedef struct {
5: IS isrow, iscol; /* rows and columns in submatrix, only used to check consistency */
6: Vec lwork, rwork; /* work vectors inside the scatters */
7: Vec lwork2, rwork2; /* work vectors inside the scatters */
8: VecScatter lrestrict, rprolong;
9: Mat A;
10: } Mat_SubVirtual;
12: static PetscErrorCode MatScale_SubMatrix(Mat N, PetscScalar a)
13: {
14: Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;
16: MatScale(Na->A, a);
17: return 0;
18: }
20: static PetscErrorCode MatShift_SubMatrix(Mat N, PetscScalar a)
21: {
22: Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;
24: MatShift(Na->A, a);
25: return 0;
26: }
28: static PetscErrorCode MatDiagonalScale_SubMatrix(Mat N, Vec left, Vec right)
29: {
30: Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;
32: if (right) {
33: VecZeroEntries(Na->rwork);
34: VecScatterBegin(Na->rprolong, right, Na->rwork, INSERT_VALUES, SCATTER_FORWARD);
35: VecScatterEnd(Na->rprolong, right, Na->rwork, INSERT_VALUES, SCATTER_FORWARD);
36: }
37: if (left) {
38: VecZeroEntries(Na->lwork);
39: VecScatterBegin(Na->lrestrict, left, Na->lwork, INSERT_VALUES, SCATTER_REVERSE);
40: VecScatterEnd(Na->lrestrict, left, Na->lwork, INSERT_VALUES, SCATTER_REVERSE);
41: }
42: MatDiagonalScale(Na->A, left ? Na->lwork : NULL, right ? Na->rwork : NULL);
43: return 0;
44: }
46: static PetscErrorCode MatGetDiagonal_SubMatrix(Mat N, Vec d)
47: {
48: Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;
50: MatGetDiagonal(Na->A, Na->rwork);
51: VecScatterBegin(Na->rprolong, Na->rwork, d, INSERT_VALUES, SCATTER_REVERSE);
52: VecScatterEnd(Na->rprolong, Na->rwork, d, INSERT_VALUES, SCATTER_REVERSE);
53: return 0;
54: }
56: static PetscErrorCode MatMult_SubMatrix(Mat N, Vec x, Vec y)
57: {
58: Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;
60: VecZeroEntries(Na->rwork);
61: VecScatterBegin(Na->rprolong, x, Na->rwork, INSERT_VALUES, SCATTER_FORWARD);
62: VecScatterEnd(Na->rprolong, x, Na->rwork, INSERT_VALUES, SCATTER_FORWARD);
63: MatMult(Na->A, Na->rwork, Na->lwork);
64: VecScatterBegin(Na->lrestrict, Na->lwork, y, INSERT_VALUES, SCATTER_FORWARD);
65: VecScatterEnd(Na->lrestrict, Na->lwork, y, INSERT_VALUES, SCATTER_FORWARD);
66: return 0;
67: }
69: static PetscErrorCode MatMultAdd_SubMatrix(Mat N, Vec v1, Vec v2, Vec v3)
70: {
71: Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;
73: VecZeroEntries(Na->rwork);
74: VecScatterBegin(Na->rprolong, v1, Na->rwork, INSERT_VALUES, SCATTER_FORWARD);
75: VecScatterEnd(Na->rprolong, v1, Na->rwork, INSERT_VALUES, SCATTER_FORWARD);
76: if (v1 == v2) {
77: MatMultAdd(Na->A, Na->rwork, Na->rwork, Na->lwork);
78: } else if (v2 == v3) {
79: VecZeroEntries(Na->lwork);
80: VecScatterBegin(Na->lrestrict, v2, Na->lwork, INSERT_VALUES, SCATTER_REVERSE);
81: VecScatterEnd(Na->lrestrict, v2, Na->lwork, INSERT_VALUES, SCATTER_REVERSE);
82: MatMultAdd(Na->A, Na->rwork, Na->lwork, Na->lwork);
83: } else {
84: if (!Na->lwork2) {
85: VecDuplicate(Na->lwork, &Na->lwork2);
86: } else {
87: VecZeroEntries(Na->lwork2);
88: }
89: VecScatterBegin(Na->lrestrict, v2, Na->lwork2, INSERT_VALUES, SCATTER_REVERSE);
90: VecScatterEnd(Na->lrestrict, v2, Na->lwork2, INSERT_VALUES, SCATTER_REVERSE);
91: MatMultAdd(Na->A, Na->rwork, Na->lwork2, Na->lwork);
92: }
93: VecScatterBegin(Na->lrestrict, Na->lwork, v3, INSERT_VALUES, SCATTER_FORWARD);
94: VecScatterEnd(Na->lrestrict, Na->lwork, v3, INSERT_VALUES, SCATTER_FORWARD);
95: return 0;
96: }
98: static PetscErrorCode MatMultTranspose_SubMatrix(Mat N, Vec x, Vec y)
99: {
100: Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;
102: VecZeroEntries(Na->lwork);
103: VecScatterBegin(Na->lrestrict, x, Na->lwork, INSERT_VALUES, SCATTER_REVERSE);
104: VecScatterEnd(Na->lrestrict, x, Na->lwork, INSERT_VALUES, SCATTER_REVERSE);
105: MatMultTranspose(Na->A, Na->lwork, Na->rwork);
106: VecScatterBegin(Na->rprolong, Na->rwork, y, INSERT_VALUES, SCATTER_REVERSE);
107: VecScatterEnd(Na->rprolong, Na->rwork, y, INSERT_VALUES, SCATTER_REVERSE);
108: return 0;
109: }
111: static PetscErrorCode MatMultTransposeAdd_SubMatrix(Mat N, Vec v1, Vec v2, Vec v3)
112: {
113: Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;
115: VecZeroEntries(Na->lwork);
116: VecScatterBegin(Na->lrestrict, v1, Na->lwork, INSERT_VALUES, SCATTER_REVERSE);
117: VecScatterEnd(Na->lrestrict, v1, Na->lwork, INSERT_VALUES, SCATTER_REVERSE);
118: if (v1 == v2) {
119: MatMultTransposeAdd(Na->A, Na->lwork, Na->lwork, Na->rwork);
120: } else if (v2 == v3) {
121: VecZeroEntries(Na->rwork);
122: VecScatterBegin(Na->rprolong, v2, Na->rwork, INSERT_VALUES, SCATTER_FORWARD);
123: VecScatterEnd(Na->rprolong, v2, Na->rwork, INSERT_VALUES, SCATTER_FORWARD);
124: MatMultTransposeAdd(Na->A, Na->lwork, Na->rwork, Na->rwork);
125: } else {
126: if (!Na->rwork2) {
127: VecDuplicate(Na->rwork, &Na->rwork2);
128: } else {
129: VecZeroEntries(Na->rwork2);
130: }
131: VecScatterBegin(Na->rprolong, v2, Na->rwork2, INSERT_VALUES, SCATTER_FORWARD);
132: VecScatterEnd(Na->rprolong, v2, Na->rwork2, INSERT_VALUES, SCATTER_FORWARD);
133: MatMultTransposeAdd(Na->A, Na->lwork, Na->rwork2, Na->rwork);
134: }
135: VecScatterBegin(Na->rprolong, Na->rwork, v3, INSERT_VALUES, SCATTER_REVERSE);
136: VecScatterEnd(Na->rprolong, Na->rwork, v3, INSERT_VALUES, SCATTER_REVERSE);
137: return 0;
138: }
140: static PetscErrorCode MatDestroy_SubMatrix(Mat N)
141: {
142: Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;
144: ISDestroy(&Na->isrow);
145: ISDestroy(&Na->iscol);
146: VecDestroy(&Na->lwork);
147: VecDestroy(&Na->rwork);
148: VecDestroy(&Na->lwork2);
149: VecDestroy(&Na->rwork2);
150: VecScatterDestroy(&Na->lrestrict);
151: VecScatterDestroy(&Na->rprolong);
152: MatDestroy(&Na->A);
153: PetscFree(N->data);
154: return 0;
155: }
157: /*@
158: MatCreateSubMatrixVirtual - Creates a virtual matrix `MATSUBMATRIX` that acts as a submatrix
160: Collective
162: Input Parameters:
163: + A - matrix that we will extract a submatrix of
164: . isrow - rows to be present in the submatrix
165: - iscol - columns to be present in the submatrix
167: Output Parameter:
168: . newmat - new matrix
170: Level: developer
172: Note:
173: Most will use `MatCreateSubMatrix()` which provides a more efficient representation if it is available.
175: Developer Note:
176: The `MatType` is `MATSUBMATRIX` but the routines associated have `SubMatrixVirtual` in them, the `MatType` should likely be changed
178: .seealso: `MATSUBMATRIX`, `MATLOCALREF`, `MatCreateLocalRef()`, `MatCreateSubMatrix()`, `MatSubMatrixVirtualUpdate()`
179: @*/
180: PetscErrorCode MatCreateSubMatrixVirtual(Mat A, IS isrow, IS iscol, Mat *newmat)
181: {
182: Vec left, right;
183: PetscInt m, n;
184: Mat N;
185: Mat_SubVirtual *Na;
191: *newmat = NULL;
193: MatCreate(PetscObjectComm((PetscObject)A), &N);
194: ISGetLocalSize(isrow, &m);
195: ISGetLocalSize(iscol, &n);
196: MatSetSizes(N, m, n, PETSC_DETERMINE, PETSC_DETERMINE);
197: PetscObjectChangeTypeName((PetscObject)N, MATSUBMATRIX);
199: PetscNew(&Na);
200: N->data = (void *)Na;
202: PetscObjectReference((PetscObject)isrow);
203: PetscObjectReference((PetscObject)iscol);
204: Na->isrow = isrow;
205: Na->iscol = iscol;
207: PetscFree(N->defaultvectype);
208: PetscStrallocpy(A->defaultvectype, &N->defaultvectype);
209: /* Do not use MatConvert directly since MatShell has a duplicate operation which does not increase
210: the reference count of the context. This is a problem if A is already of type MATSHELL */
211: MatConvertFrom_Shell(A, MATSHELL, MAT_INITIAL_MATRIX, &Na->A);
213: N->ops->destroy = MatDestroy_SubMatrix;
214: N->ops->mult = MatMult_SubMatrix;
215: N->ops->multadd = MatMultAdd_SubMatrix;
216: N->ops->multtranspose = MatMultTranspose_SubMatrix;
217: N->ops->multtransposeadd = MatMultTransposeAdd_SubMatrix;
218: N->ops->scale = MatScale_SubMatrix;
219: N->ops->diagonalscale = MatDiagonalScale_SubMatrix;
220: N->ops->shift = MatShift_SubMatrix;
221: N->ops->convert = MatConvert_Shell;
222: N->ops->getdiagonal = MatGetDiagonal_SubMatrix;
224: MatSetBlockSizesFromMats(N, A, A);
225: PetscLayoutSetUp(N->rmap);
226: PetscLayoutSetUp(N->cmap);
228: MatCreateVecs(A, &Na->rwork, &Na->lwork);
229: MatCreateVecs(N, &right, &left);
230: VecScatterCreate(Na->lwork, isrow, left, NULL, &Na->lrestrict);
231: VecScatterCreate(right, NULL, Na->rwork, iscol, &Na->rprolong);
232: VecDestroy(&left);
233: VecDestroy(&right);
234: MatSetUp(N);
236: N->assembled = PETSC_TRUE;
237: *newmat = N;
238: return 0;
239: }
241: /*MC
242: MATSUBMATRIX - "submatrix" - A matrix type that represents a virtual submatrix of a matrix
244: Level: advanced
246: Developer Note:
247: This should be named `MATSUBMATRIXVIRTUAL`
249: .seealso: `Mat`, `MatCreateSubMatrixVirtual()`, `MatSubMatrixVirtualUpdate()`, `MatCreateSubMatrix()`
250: M*/
252: /*@
253: MatSubMatrixVirtualUpdate - Updates a `MATSUBMATRIX` virtual submatrix
255: Collective
257: Input Parameters:
258: + N - submatrix to update
259: . A - full matrix in the submatrix
260: . isrow - rows in the update (same as the first time the submatrix was created)
261: - iscol - columns in the update (same as the first time the submatrix was created)
263: Level: developer
265: Note:
266: Most will use `MatCreateSubMatrix()` which provides a more efficient representation if it is available.
268: .seealso: `MATSUBMATRIX`, `MatCreateSubMatrixVirtual()`
269: @*/
270: PetscErrorCode MatSubMatrixVirtualUpdate(Mat N, Mat A, IS isrow, IS iscol)
271: {
272: PetscBool flg;
273: Mat_SubVirtual *Na;
279: PetscObjectTypeCompare((PetscObject)N, MATSUBMATRIX, &flg);
282: Na = (Mat_SubVirtual *)N->data;
283: ISEqual(isrow, Na->isrow, &flg);
285: ISEqual(iscol, Na->iscol, &flg);
288: PetscFree(N->defaultvectype);
289: PetscStrallocpy(A->defaultvectype, &N->defaultvectype);
290: MatDestroy(&Na->A);
291: /* Do not use MatConvert directly since MatShell has a duplicate operation which does not increase
292: the reference count of the context. This is a problem if A is already of type MATSHELL */
293: MatConvertFrom_Shell(A, MATSHELL, MAT_INITIAL_MATRIX, &Na->A);
294: return 0;
295: }