Actual source code: mpimatmatmatmult.c
1: /*
2: Defines matrix-matrix-matrix product routines for MPIAIJ matrices
3: D = A * B * C
4: */
5: #include <../src/mat/impls/aij/mpi/mpiaij.h>
7: #if defined(PETSC_HAVE_HYPRE)
8: PETSC_INTERN PetscErrorCode MatTransposeMatMatMultSymbolic_AIJ_AIJ_AIJ_wHYPRE(Mat, Mat, Mat, PetscReal, Mat);
9: PETSC_INTERN PetscErrorCode MatTransposeMatMatMultNumeric_AIJ_AIJ_AIJ_wHYPRE(Mat, Mat, Mat, Mat);
11: PETSC_INTERN PetscErrorCode MatProductNumeric_ABC_Transpose_AIJ_AIJ(Mat RAP)
12: {
13: Mat_Product *product = RAP->product;
14: Mat Rt, R = product->A, A = product->B, P = product->C;
16: MatTransposeGetMat(R, &Rt);
17: MatTransposeMatMatMultNumeric_AIJ_AIJ_AIJ_wHYPRE(Rt, A, P, RAP);
18: return 0;
19: }
21: PETSC_INTERN PetscErrorCode MatProductSymbolic_ABC_Transpose_AIJ_AIJ(Mat RAP)
22: {
23: Mat_Product *product = RAP->product;
24: Mat Rt, R = product->A, A = product->B, P = product->C;
25: PetscBool flg;
27: /* local sizes of matrices will be checked by the calling subroutines */
28: MatTransposeGetMat(R, &Rt);
29: PetscObjectTypeCompareAny((PetscObject)Rt, &flg, MATSEQAIJ, MATSEQAIJMKL, MATMPIAIJ, NULL);
31: MatTransposeMatMatMultSymbolic_AIJ_AIJ_AIJ_wHYPRE(Rt, A, P, product->fill, RAP);
32: RAP->ops->productnumeric = MatProductNumeric_ABC_Transpose_AIJ_AIJ;
33: return 0;
34: }
36: PETSC_INTERN PetscErrorCode MatProductSetFromOptions_Transpose_AIJ_AIJ(Mat C)
37: {
38: Mat_Product *product = C->product;
40: if (product->type == MATPRODUCT_ABC) {
41: C->ops->productsymbolic = MatProductSymbolic_ABC_Transpose_AIJ_AIJ;
42: } else SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_SUP, "MatProduct type %s is not supported for Transpose, AIJ and AIJ matrices", MatProductTypes[product->type]);
43: return 0;
44: }
45: #endif
47: PetscErrorCode MatMatMatMultSymbolic_MPIAIJ_MPIAIJ_MPIAIJ(Mat A, Mat B, Mat C, PetscReal fill, Mat D)
48: {
49: Mat BC;
50: PetscBool scalable;
51: Mat_Product *product;
53: MatCheckProduct(D, 5);
55: product = D->product;
56: MatProductCreate(B, C, NULL, &BC);
57: MatProductSetType(BC, MATPRODUCT_AB);
58: PetscStrcmp(product->alg, "scalable", &scalable);
59: if (scalable) {
60: MatMatMultSymbolic_MPIAIJ_MPIAIJ(B, C, fill, BC);
61: MatZeroEntries(BC); /* initialize value entries of BC */
62: MatMatMultSymbolic_MPIAIJ_MPIAIJ(A, BC, fill, D);
63: } else {
64: MatMatMultSymbolic_MPIAIJ_MPIAIJ_nonscalable(B, C, fill, BC);
65: MatZeroEntries(BC); /* initialize value entries of BC */
66: MatMatMultSymbolic_MPIAIJ_MPIAIJ_nonscalable(A, BC, fill, D);
67: }
68: MatDestroy(&product->Dwork);
69: product->Dwork = BC;
71: D->ops->matmatmultnumeric = MatMatMatMultNumeric_MPIAIJ_MPIAIJ_MPIAIJ;
72: return 0;
73: }
75: PetscErrorCode MatMatMatMultNumeric_MPIAIJ_MPIAIJ_MPIAIJ(Mat A, Mat B, Mat C, Mat D)
76: {
77: Mat_Product *product;
78: Mat BC;
80: MatCheckProduct(D, 4);
82: product = D->product;
83: BC = product->Dwork;
84: (*BC->ops->matmultnumeric)(B, C, BC);
85: (*D->ops->matmultnumeric)(A, BC, D);
86: return 0;
87: }
89: /* ----------------------------------------------------- */
90: PetscErrorCode MatDestroy_MPIAIJ_RARt(void *data)
91: {
92: Mat_RARt *rart = (Mat_RARt *)data;
94: MatDestroy(&rart->Rt);
95: if (rart->destroy) (*rart->destroy)(rart->data);
96: PetscFree(rart);
97: return 0;
98: }
100: PetscErrorCode MatProductNumeric_RARt_MPIAIJ_MPIAIJ(Mat C)
101: {
102: Mat_RARt *rart;
103: Mat A, R, Rt;
105: MatCheckProduct(C, 1);
107: rart = (Mat_RARt *)C->product->data;
108: A = C->product->A;
109: R = C->product->B;
110: Rt = rart->Rt;
111: MatTranspose(R, MAT_REUSE_MATRIX, &Rt);
112: if (rart->data) C->product->data = rart->data;
113: (*C->ops->matmatmultnumeric)(R, A, Rt, C);
114: C->product->data = rart;
115: return 0;
116: }
118: PetscErrorCode MatProductSymbolic_RARt_MPIAIJ_MPIAIJ(Mat C)
119: {
120: Mat A, R, Rt;
121: Mat_RARt *rart;
123: MatCheckProduct(C, 1);
125: A = C->product->A;
126: R = C->product->B;
127: MatTranspose(R, MAT_INITIAL_MATRIX, &Rt);
128: /* product->Dwork is used to store A*Rt in MatMatMatMultSymbolic_MPIAIJ_MPIAIJ_MPIAIJ() */
129: MatMatMatMultSymbolic_MPIAIJ_MPIAIJ_MPIAIJ(R, A, Rt, C->product->fill, C);
130: C->ops->productnumeric = MatProductNumeric_RARt_MPIAIJ_MPIAIJ;
132: /* create a supporting struct */
133: PetscNew(&rart);
134: rart->Rt = Rt;
135: rart->data = C->product->data;
136: rart->destroy = C->product->destroy;
137: C->product->data = rart;
138: C->product->destroy = MatDestroy_MPIAIJ_RARt;
139: return 0;
140: }