Actual source code: mpimattransposematmult.c
/*
  Defines matrix-matrix product routines for C = A^T * B, where A is an MPIAIJ
  matrix and B is an MPIDense matrix.
  The routines are slightly modified from MatTransposeMatMultxxx_SeqAIJ_SeqDense().
*/
7: #include <../src/mat/impls/aij/seq/aij.h>
8: #include <../src/mat/impls/aij/mpi/mpiaij.h>
9: #include <../src/mat/impls/dense/mpi/mpidense.h>
11: PetscErrorCode MatDestroy_MPIDense_MatTransMatMult(void *data)
12: {
13: Mat_MatTransMatMult *atb = (Mat_MatTransMatMult *)data;
15: MatDestroy(&atb->mA);
16: VecDestroy(&atb->bt);
17: VecDestroy(&atb->ct);
18: PetscFree(atb);
19: return 0;
20: }
22: static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat, Mat, Mat);
24: PETSC_INTERN PetscErrorCode MatTransposeMatMultSymbolic_MPIAIJ_MPIDense(Mat A, Mat B, PetscReal fill, Mat C)
25: {
26: Mat_MatTransMatMult *atb;
27: PetscBool cisdense;
29: MatCheckProduct(C, 4);
32: /* create output dense matrix C = A^T*B */
33: MatSetSizes(C, A->cmap->n, B->cmap->n, A->cmap->N, B->cmap->N);
34: PetscObjectTypeCompareAny((PetscObject)C, &cisdense, MATMPIDENSE, MATMPIDENSECUDA, "");
35: if (!cisdense) MatSetType(C, ((PetscObject)B)->type_name);
36: MatSetUp(C);
38: /* create additional data structure for the product */
39: PetscNew(&atb);
40: if (B->cmap->N) {
41: MatCreateMAIJ(A, B->cmap->N, &atb->mA);
42: if (!atb->mA->assembled) {
43: MatAssemblyBegin(atb->mA, MAT_FINAL_ASSEMBLY);
44: MatAssemblyEnd(atb->mA, MAT_FINAL_ASSEMBLY);
45: }
46: MatCreateVecs(atb->mA, &atb->ct, &atb->bt);
47: }
48: C->product->data = atb;
49: C->product->destroy = MatDestroy_MPIDense_MatTransMatMult;
51: C->ops->transposematmultnumeric = MatTransposeMatMultNumeric_MPIAIJ_MPIDense;
52: return 0;
53: }
55: static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat A, Mat B, Mat C)
56: {
57: const PetscScalar *Barray, *ctarray;
58: PetscScalar *Carray, *btarray;
59: PetscInt i, j, m = A->rmap->n, n = A->cmap->n, ldb, BN = B->cmap->N, ldc;
60: Mat_MatTransMatMult *atb;
61: Vec bt, ct;
63: MatCheckProduct(C, 3);
64: atb = (Mat_MatTransMatMult *)C->product->data;
66: if (!BN) {
67: MatAssemblyBegin(C, MAT_FINAL_ASSEMBLY);
68: MatAssemblyEnd(C, MAT_FINAL_ASSEMBLY);
69: return 0;
70: }
71: bt = atb->bt;
72: ct = atb->ct;
74: /* transpose local array of B, then copy it to vector bt */
75: MatDenseGetArrayRead(B, &Barray);
76: MatDenseGetLDA(B, &ldb);
77: VecGetArray(bt, &btarray);
78: for (j = 0; j < BN; j++)
79: for (i = 0; i < m; i++) btarray[i * BN + j] = Barray[j * ldb + i];
80: VecRestoreArray(bt, &btarray);
81: MatDenseRestoreArrayRead(B, &Barray);
83: /* compute ct = mA^T * cb */
84: MatMultTranspose(atb->mA, bt, ct);
86: /* transpose local array of ct to matrix C */
87: MatDenseGetArray(C, &Carray);
88: MatDenseGetLDA(C, &ldc);
89: VecGetArrayRead(ct, &ctarray);
90: for (j = 0; j < BN; j++)
91: for (i = 0; i < n; i++) Carray[j * ldc + i] = ctarray[i * BN + j];
92: VecRestoreArrayRead(ct, &ctarray);
93: MatDenseRestoreArray(C, &Carray);
94: MatAssemblyBegin(C, MAT_FINAL_ASSEMBLY);
95: MatAssemblyEnd(C, MAT_FINAL_ASSEMBLY);
96: return 0;
97: }