Actual source code: mattransposematmult.c
2: /*
3: Defines matrix-matrix product routines for
4: C = A^T * B and C = A * B^T
5: with A SeqAIJ and B SeqDense
6: */
8: #include <../src/mat/impls/aij/seq/aij.h>
9: #include <../src/mat/impls/dense/seq/dense.h>
11: PetscErrorCode MatDestroy_SeqDense_MatTransMatMult(void *data)
12: {
13: Mat_MatTransMatMult *atb = (Mat_MatTransMatMult *)data;
15: MatDestroy(&atb->mA);
16: VecDestroy(&atb->bt);
17: VecDestroy(&atb->ct);
18: PetscFree(atb);
19: return 0;
20: }
22: static PetscErrorCode MatTMatTMultNumeric_SeqAIJ_SeqDense(Mat, Mat, Mat);
24: PETSC_INTERN PetscErrorCode MatTMatTMultSymbolic_SeqAIJ_SeqDense(Mat A, Mat B, PetscReal fill, Mat C)
25: {
26: Mat_MatTransMatMult *atb;
27: PetscBool cisdense;
28: PetscInt dofm;
30: MatCheckProduct(C, 4);
34: /* create output dense matrix C */
35: if (C->product->type == MATPRODUCT_AtB) {
36: MatSetSizes(C, A->cmap->n, B->cmap->N, A->cmap->n, B->cmap->N);
37: dofm = B->cmap->n;
38: } else {
39: MatSetSizes(C, A->rmap->n, B->rmap->N, A->rmap->n, B->rmap->N);
40: dofm = B->rmap->n;
41: }
42: PetscObjectTypeCompareAny((PetscObject)C, &cisdense, MATSEQDENSE, MATSEQDENSECUDA, "");
43: if (!cisdense) MatSetType(C, ((PetscObject)B)->type_name);
44: MatSetUp(C);
46: /* create additional data structure for the product */
47: PetscNew(&atb);
48: MatCreateMAIJ(A, dofm, &atb->mA);
49: MatCreateVecs(atb->mA, &atb->ct, &atb->bt);
50: C->product->data = atb;
51: C->product->destroy = MatDestroy_SeqDense_MatTransMatMult;
53: if (C->product->type == MATPRODUCT_AtB) {
54: C->ops->transposematmultnumeric = MatTMatTMultNumeric_SeqAIJ_SeqDense;
55: } else {
56: C->ops->mattransposemultnumeric = MatTMatTMultNumeric_SeqAIJ_SeqDense;
57: }
58: return 0;
59: }
61: PetscErrorCode MatTMatTMultNumeric_SeqAIJ_SeqDense(Mat A, Mat B, Mat C)
62: {
63: PetscInt i, j, m = A->rmap->n, n = A->cmap->n, blda, clda;
64: PetscInt mdof = C->cmap->N;
65: const PetscScalar *Barray;
66: PetscScalar *Carray;
67: Mat_MatTransMatMult *atb;
68: Vec bt, ct;
70: MatCheckProduct(C, 3);
72: atb = (Mat_MatTransMatMult *)C->product->data;
74: bt = atb->bt;
75: ct = atb->ct;
77: MatDenseGetArrayRead(B, &Barray);
78: MatDenseGetLDA(B, &blda);
79: MatDenseGetArrayWrite(C, &Carray);
80: MatDenseGetLDA(C, &clda);
81: if (C->product->type == MATPRODUCT_AtB) { /* transpose local array of B, then copy it to vector bt */
82: const PetscScalar *ctarray;
83: PetscScalar *btarray;
85: VecGetArrayWrite(bt, &btarray);
86: for (j = 0; j < mdof; j++) {
87: for (i = 0; i < m; i++) btarray[i * mdof + j] = Barray[j * blda + i];
88: }
89: VecRestoreArrayWrite(bt, &btarray);
91: /* compute ct = mA^T * cb */
92: MatMultTranspose(atb->mA, bt, ct);
94: /* transpose local array of ct to matrix C */
95: VecGetArrayRead(ct, &ctarray);
96: for (j = 0; j < mdof; j++) {
97: for (i = 0; i < n; i++) Carray[j * clda + i] = ctarray[i * mdof + j];
98: }
99: VecRestoreArrayRead(ct, &ctarray);
100: } else {
101: const PetscScalar *btarray;
102: PetscScalar *ctarray;
104: if (blda == B->rmap->n) {
105: VecPlaceArray(ct, Barray);
106: } else {
107: PetscInt bn = B->cmap->n;
108: PetscInt bm = B->rmap->n;
110: VecGetArrayWrite(ct, &ctarray);
111: for (j = 0; j < bn; j++) {
112: for (i = 0; i < bm; i++) ctarray[j * bm + i] = Barray[j * blda + i];
113: }
114: VecRestoreArrayWrite(ct, &ctarray);
115: }
117: MatMult(atb->mA, ct, bt);
118: if (blda == B->rmap->n) VecResetArray(ct);
119: VecGetArrayRead(bt, &btarray);
120: for (j = 0; j < mdof; j++) {
121: for (i = 0; i < m; i++) Carray[j * clda + i] = btarray[i * mdof + j];
122: }
123: VecRestoreArrayRead(bt, &btarray);
124: }
125: MatDenseRestoreArrayRead(B, &Barray);
126: MatDenseRestoreArray(C, &Carray);
127: return 0;
128: }