Actual source code: transm.c


  2: #include <petsc/private/matimpl.h>

  4: typedef struct {
  5:   Mat A;
  6: } Mat_Transpose;

  8: PetscErrorCode MatMult_Transpose(Mat N, Vec x, Vec y)
  9: {
 10:   Mat_Transpose *Na = (Mat_Transpose *)N->data;

 12:   MatMultTranspose(Na->A, x, y);
 13:   return 0;
 14: }

 16: PetscErrorCode MatMultAdd_Transpose(Mat N, Vec v1, Vec v2, Vec v3)
 17: {
 18:   Mat_Transpose *Na = (Mat_Transpose *)N->data;

 20:   MatMultTransposeAdd(Na->A, v1, v2, v3);
 21:   return 0;
 22: }

 24: PetscErrorCode MatMultTranspose_Transpose(Mat N, Vec x, Vec y)
 25: {
 26:   Mat_Transpose *Na = (Mat_Transpose *)N->data;

 28:   MatMult(Na->A, x, y);
 29:   return 0;
 30: }

 32: PetscErrorCode MatMultTransposeAdd_Transpose(Mat N, Vec v1, Vec v2, Vec v3)
 33: {
 34:   Mat_Transpose *Na = (Mat_Transpose *)N->data;

 36:   MatMultAdd(Na->A, v1, v2, v3);
 37:   return 0;
 38: }

 40: PetscErrorCode MatDestroy_Transpose(Mat N)
 41: {
 42:   Mat_Transpose *Na = (Mat_Transpose *)N->data;

 44:   MatDestroy(&Na->A);
 45:   PetscObjectComposeFunction((PetscObject)N, "MatTransposeGetMat_C", NULL);
 46:   PetscObjectComposeFunction((PetscObject)N, "MatProductSetFromOptions_anytype_C", NULL);
 47:   PetscFree(N->data);
 48:   return 0;
 49: }

 51: PetscErrorCode MatDuplicate_Transpose(Mat N, MatDuplicateOption op, Mat *m)
 52: {
 53:   Mat_Transpose *Na = (Mat_Transpose *)N->data;

 55:   if (op == MAT_COPY_VALUES) {
 56:     MatTranspose(Na->A, MAT_INITIAL_MATRIX, m);
 57:   } else if (op == MAT_DO_NOT_COPY_VALUES) {
 58:     MatDuplicate(Na->A, MAT_DO_NOT_COPY_VALUES, m);
 59:     MatTranspose(*m, MAT_INPLACE_MATRIX, m);
 60:   } else SETERRQ(PetscObjectComm((PetscObject)N), PETSC_ERR_SUP, "MAT_SHARE_NONZERO_PATTERN not supported for this matrix type");
 61:   return 0;
 62: }

 64: PetscErrorCode MatCreateVecs_Transpose(Mat A, Vec *r, Vec *l)
 65: {
 66:   Mat_Transpose *Aa = (Mat_Transpose *)A->data;

 68:   MatCreateVecs(Aa->A, l, r);
 69:   return 0;
 70: }

 72: PetscErrorCode MatAXPY_Transpose(Mat Y, PetscScalar a, Mat X, MatStructure str)
 73: {
 74:   Mat_Transpose *Ya = (Mat_Transpose *)Y->data;
 75:   Mat_Transpose *Xa = (Mat_Transpose *)X->data;
 76:   Mat            M  = Ya->A;
 77:   Mat            N  = Xa->A;

 79:   MatAXPY(M, a, N, str);
 80:   return 0;
 81: }

 83: PetscErrorCode MatHasOperation_Transpose(Mat mat, MatOperation op, PetscBool *has)
 84: {
 85:   Mat_Transpose *X = (Mat_Transpose *)mat->data;

 87:   *has = PETSC_FALSE;
 88:   if (op == MATOP_MULT) {
 89:     MatHasOperation(X->A, MATOP_MULT_TRANSPOSE, has);
 90:   } else if (op == MATOP_MULT_TRANSPOSE) {
 91:     MatHasOperation(X->A, MATOP_MULT, has);
 92:   } else if (op == MATOP_MULT_ADD) {
 93:     MatHasOperation(X->A, MATOP_MULT_TRANSPOSE_ADD, has);
 94:   } else if (op == MATOP_MULT_TRANSPOSE_ADD) {
 95:     MatHasOperation(X->A, MATOP_MULT_ADD, has);
 96:   } else if (((void **)mat->ops)[op]) *has = PETSC_TRUE;
 97:   return 0;
 98: }

100: PETSC_INTERN PetscErrorCode MatProductSetFromOptions_Transpose(Mat D)
101: {
102:   Mat            A, B, C, Ain, Bin, Cin;
103:   PetscBool      Aistrans, Bistrans, Cistrans;
104:   PetscInt       Atrans, Btrans, Ctrans;
105:   MatProductType ptype;

107:   MatCheckProduct(D, 1);
108:   A = D->product->A;
109:   B = D->product->B;
110:   C = D->product->C;
111:   PetscObjectTypeCompare((PetscObject)A, MATTRANSPOSEVIRTUAL, &Aistrans);
112:   PetscObjectTypeCompare((PetscObject)B, MATTRANSPOSEVIRTUAL, &Bistrans);
113:   PetscObjectTypeCompare((PetscObject)C, MATTRANSPOSEVIRTUAL, &Cistrans);
115:   Atrans = 0;
116:   Ain    = A;
117:   while (Aistrans) {
118:     Atrans++;
119:     MatTransposeGetMat(Ain, &Ain);
120:     PetscObjectTypeCompare((PetscObject)Ain, MATTRANSPOSEVIRTUAL, &Aistrans);
121:   }
122:   Btrans = 0;
123:   Bin    = B;
124:   while (Bistrans) {
125:     Btrans++;
126:     MatTransposeGetMat(Bin, &Bin);
127:     PetscObjectTypeCompare((PetscObject)Bin, MATTRANSPOSEVIRTUAL, &Bistrans);
128:   }
129:   Ctrans = 0;
130:   Cin    = C;
131:   while (Cistrans) {
132:     Ctrans++;
133:     MatTransposeGetMat(Cin, &Cin);
134:     PetscObjectTypeCompare((PetscObject)Cin, MATTRANSPOSEVIRTUAL, &Cistrans);
135:   }
136:   Atrans = Atrans % 2;
137:   Btrans = Btrans % 2;
138:   Ctrans = Ctrans % 2;
139:   ptype  = D->product->type; /* same product type by default */
140:   if (Ain->symmetric == PETSC_BOOL3_TRUE) Atrans = 0;
141:   if (Bin->symmetric == PETSC_BOOL3_TRUE) Btrans = 0;
142:   if (Cin && Cin->symmetric == PETSC_BOOL3_TRUE) Ctrans = 0;

144:   if (Atrans || Btrans || Ctrans) {
145:     ptype = MATPRODUCT_UNSPECIFIED;
146:     switch (D->product->type) {
147:     case MATPRODUCT_AB:
148:       if (Atrans && Btrans) { /* At * Bt we do not have support for this */
149:         /* TODO custom implementation ? */
150:       } else if (Atrans) { /* At * B */
151:         ptype = MATPRODUCT_AtB;
152:       } else { /* A * Bt */
153:         ptype = MATPRODUCT_ABt;
154:       }
155:       break;
156:     case MATPRODUCT_AtB:
157:       if (Atrans && Btrans) { /* A * Bt */
158:         ptype = MATPRODUCT_ABt;
159:       } else if (Atrans) { /* A * B */
160:         ptype = MATPRODUCT_AB;
161:       } else { /* At * Bt we do not have support for this */
162:         /* TODO custom implementation ? */
163:       }
164:       break;
165:     case MATPRODUCT_ABt:
166:       if (Atrans && Btrans) { /* At * B */
167:         ptype = MATPRODUCT_AtB;
168:       } else if (Atrans) { /* At * Bt we do not have support for this */
169:         /* TODO custom implementation ? */
170:       } else { /* A * B */
171:         ptype = MATPRODUCT_AB;
172:       }
173:       break;
174:     case MATPRODUCT_PtAP:
175:       if (Atrans) { /* PtAtP */
176:         /* TODO custom implementation ? */
177:       } else { /* RARt */
178:         ptype = MATPRODUCT_RARt;
179:       }
180:       break;
181:     case MATPRODUCT_RARt:
182:       if (Atrans) { /* RAtRt */
183:         /* TODO custom implementation ? */
184:       } else { /* PtAP */
185:         ptype = MATPRODUCT_PtAP;
186:       }
187:       break;
188:     case MATPRODUCT_ABC:
189:       /* TODO custom implementation ? */
190:       break;
191:     default:
192:       SETERRQ(PetscObjectComm((PetscObject)D), PETSC_ERR_SUP, "ProductType %s is not supported", MatProductTypes[D->product->type]);
193:     }
194:   }
195:   MatProductReplaceMats(Ain, Bin, Cin, D);
196:   MatProductSetType(D, ptype);
197:   MatProductSetFromOptions(D);
198:   return 0;
199: }

201: PetscErrorCode MatGetDiagonal_Transpose(Mat A, Vec v)
202: {
203:   Mat_Transpose *Aa = (Mat_Transpose *)A->data;

205:   MatGetDiagonal(Aa->A, v);
206:   return 0;
207: }

209: PetscErrorCode MatConvert_Transpose(Mat A, MatType newtype, MatReuse reuse, Mat *newmat)
210: {
211:   Mat_Transpose *Aa = (Mat_Transpose *)A->data;
212:   PetscBool      flg;

214:   MatHasOperation(Aa->A, MATOP_TRANSPOSE, &flg);
215:   if (flg) {
216:     Mat B;

218:     MatTranspose(Aa->A, MAT_INITIAL_MATRIX, &B);
219:     if (reuse != MAT_INPLACE_MATRIX) {
220:       MatConvert(B, newtype, reuse, newmat);
221:       MatDestroy(&B);
222:     } else {
223:       MatConvert(B, newtype, MAT_INPLACE_MATRIX, &B);
224:       MatHeaderReplace(A, &B);
225:     }
226:   } else { /* use basic converter as fallback */
227:     MatConvert_Basic(A, newtype, reuse, newmat);
228:   }
229:   return 0;
230: }

232: PetscErrorCode MatTransposeGetMat_Transpose(Mat A, Mat *M)
233: {
234:   Mat_Transpose *Aa = (Mat_Transpose *)A->data;

236:   *M = Aa->A;
237:   return 0;
238: }

240: /*@
241:       MatTransposeGetMat - Gets the `Mat` object stored inside a `MATTRANSPOSEVIRTUAL`

243:    Logically collective

245:    Input Parameter:
246: .   A  - the `MATTRANSPOSEVIRTUAL` matrix

248:    Output Parameter:
249: .   M - the matrix object stored inside A

251:    Level: intermediate

253: .seealso: `MATTRANSPOSEVIRTUAL`, `MatCreateTranspose()`
254: @*/
255: PetscErrorCode MatTransposeGetMat(Mat A, Mat *M)
256: {
260:   PetscUseMethod(A, "MatTransposeGetMat_C", (Mat, Mat *), (A, M));
261:   return 0;
262: }

264: /*MC
265:    MATTRANSPOSEVIRTUAL - "transpose" - A matrix type that represents a virtual transpose of a matrix

267:   Level: advanced

269: .seealso: `MATHERMITIANTRANSPOSEVIRTUAL`, `Mat`, `MatCreateHermitianTranspose()`, `MatCreateTranspose()`
270: M*/

272: /*@
273:       MatCreateTranspose - Creates a new matrix `MATTRANSPOSEVIRTUAL` object that behaves like A'

275:    Collective

277:    Input Parameter:
278: .   A  - the (possibly rectangular) matrix

280:    Output Parameter:
281: .   N - the matrix that represents A'

283:    Level: intermediate

285:    Note:
286:     The transpose A' is NOT actually formed! Rather the new matrix
287:           object performs the matrix-vector product by using the `MatMultTranspose()` on
288:           the original matrix

290: .seealso: `MATTRANSPOSEVIRTUAL`, `MatCreateNormal()`, `MatMult()`, `MatMultTranspose()`, `MatCreate()`
291: @*/
292: PetscErrorCode MatCreateTranspose(Mat A, Mat *N)
293: {
294:   PetscInt       m, n;
295:   Mat_Transpose *Na;
296:   VecType        vtype;

298:   MatGetLocalSize(A, &m, &n);
299:   MatCreate(PetscObjectComm((PetscObject)A), N);
300:   MatSetSizes(*N, n, m, PETSC_DECIDE, PETSC_DECIDE);
301:   PetscLayoutSetUp((*N)->rmap);
302:   PetscLayoutSetUp((*N)->cmap);
303:   PetscObjectChangeTypeName((PetscObject)*N, MATTRANSPOSEVIRTUAL);

305:   PetscNew(&Na);
306:   (*N)->data = (void *)Na;
307:   PetscObjectReference((PetscObject)A);
308:   Na->A = A;

310:   (*N)->ops->destroy               = MatDestroy_Transpose;
311:   (*N)->ops->mult                  = MatMult_Transpose;
312:   (*N)->ops->multadd               = MatMultAdd_Transpose;
313:   (*N)->ops->multtranspose         = MatMultTranspose_Transpose;
314:   (*N)->ops->multtransposeadd      = MatMultTransposeAdd_Transpose;
315:   (*N)->ops->duplicate             = MatDuplicate_Transpose;
316:   (*N)->ops->getvecs               = MatCreateVecs_Transpose;
317:   (*N)->ops->axpy                  = MatAXPY_Transpose;
318:   (*N)->ops->hasoperation          = MatHasOperation_Transpose;
319:   (*N)->ops->productsetfromoptions = MatProductSetFromOptions_Transpose;
320:   (*N)->ops->getdiagonal           = MatGetDiagonal_Transpose;
321:   (*N)->ops->convert               = MatConvert_Transpose;
322:   (*N)->assembled                  = PETSC_TRUE;

324:   PetscObjectComposeFunction((PetscObject)(*N), "MatTransposeGetMat_C", MatTransposeGetMat_Transpose);
325:   PetscObjectComposeFunction((PetscObject)(*N), "MatProductSetFromOptions_anytype_C", MatProductSetFromOptions_Transpose);
326:   MatSetBlockSizes(*N, PetscAbs(A->cmap->bs), PetscAbs(A->rmap->bs));
327:   MatGetVecType(A, &vtype);
328:   MatSetVecType(*N, vtype);
329: #if defined(PETSC_HAVE_DEVICE)
330:   MatBindToCPU(*N, A->boundtocpu);
331: #endif
332:   MatSetUp(*N);
333:   return 0;
334: }