Actual source code: mpimatmatmatmult.c

  1: /*
  2:   Defines matrix-matrix-matrix product routines for MPIAIJ matrices
  3:           D = A * B * C
  4: */
  5: #include <../src/mat/impls/aij/mpi/mpiaij.h>

  7: #if defined(PETSC_HAVE_HYPRE)
  8: PETSC_INTERN PetscErrorCode MatTransposeMatMatMultSymbolic_AIJ_AIJ_AIJ_wHYPRE(Mat, Mat, Mat, PetscReal, Mat);
  9: PETSC_INTERN PetscErrorCode MatTransposeMatMatMultNumeric_AIJ_AIJ_AIJ_wHYPRE(Mat, Mat, Mat, Mat);

 11: PETSC_INTERN PetscErrorCode MatProductNumeric_ABC_Transpose_AIJ_AIJ(Mat RAP)
 12: {
 13:   Mat_Product *product = RAP->product;
 14:   Mat          Rt, R = product->A, A = product->B, P = product->C;

 16:   MatTransposeGetMat(R, &Rt);
 17:   MatTransposeMatMatMultNumeric_AIJ_AIJ_AIJ_wHYPRE(Rt, A, P, RAP);
 18:   return 0;
 19: }

 21: PETSC_INTERN PetscErrorCode MatProductSymbolic_ABC_Transpose_AIJ_AIJ(Mat RAP)
 22: {
 23:   Mat_Product *product = RAP->product;
 24:   Mat          Rt, R = product->A, A = product->B, P = product->C;
 25:   PetscBool    flg;

 27:   /* local sizes of matrices will be checked by the calling subroutines */
 28:   MatTransposeGetMat(R, &Rt);
 29:   PetscObjectTypeCompareAny((PetscObject)Rt, &flg, MATSEQAIJ, MATSEQAIJMKL, MATMPIAIJ, NULL);
 31:   MatTransposeMatMatMultSymbolic_AIJ_AIJ_AIJ_wHYPRE(Rt, A, P, product->fill, RAP);
 32:   RAP->ops->productnumeric = MatProductNumeric_ABC_Transpose_AIJ_AIJ;
 33:   return 0;
 34: }

 36: PETSC_INTERN PetscErrorCode MatProductSetFromOptions_Transpose_AIJ_AIJ(Mat C)
 37: {
 38:   Mat_Product *product = C->product;

 40:   if (product->type == MATPRODUCT_ABC) {
 41:     C->ops->productsymbolic = MatProductSymbolic_ABC_Transpose_AIJ_AIJ;
 42:   } else SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_SUP, "MatProduct type %s is not supported for Transpose, AIJ and AIJ matrices", MatProductTypes[product->type]);
 43:   return 0;
 44: }
 45: #endif

 47: PetscErrorCode MatMatMatMultSymbolic_MPIAIJ_MPIAIJ_MPIAIJ(Mat A, Mat B, Mat C, PetscReal fill, Mat D)
 48: {
 49:   Mat          BC;
 50:   PetscBool    scalable;
 51:   Mat_Product *product;

 53:   MatCheckProduct(D, 5);
 55:   product = D->product;
 56:   MatProductCreate(B, C, NULL, &BC);
 57:   MatProductSetType(BC, MATPRODUCT_AB);
 58:   PetscStrcmp(product->alg, "scalable", &scalable);
 59:   if (scalable) {
 60:     MatMatMultSymbolic_MPIAIJ_MPIAIJ(B, C, fill, BC);
 61:     MatZeroEntries(BC); /* initialize value entries of BC */
 62:     MatMatMultSymbolic_MPIAIJ_MPIAIJ(A, BC, fill, D);
 63:   } else {
 64:     MatMatMultSymbolic_MPIAIJ_MPIAIJ_nonscalable(B, C, fill, BC);
 65:     MatZeroEntries(BC); /* initialize value entries of BC */
 66:     MatMatMultSymbolic_MPIAIJ_MPIAIJ_nonscalable(A, BC, fill, D);
 67:   }
 68:   MatDestroy(&product->Dwork);
 69:   product->Dwork = BC;

 71:   D->ops->matmatmultnumeric = MatMatMatMultNumeric_MPIAIJ_MPIAIJ_MPIAIJ;
 72:   return 0;
 73: }

 75: PetscErrorCode MatMatMatMultNumeric_MPIAIJ_MPIAIJ_MPIAIJ(Mat A, Mat B, Mat C, Mat D)
 76: {
 77:   Mat_Product *product;
 78:   Mat          BC;

 80:   MatCheckProduct(D, 4);
 82:   product = D->product;
 83:   BC      = product->Dwork;
 84:   (*BC->ops->matmultnumeric)(B, C, BC);
 85:   (*D->ops->matmultnumeric)(A, BC, D);
 86:   return 0;
 87: }

 89: /* ----------------------------------------------------- */
 90: PetscErrorCode MatDestroy_MPIAIJ_RARt(void *data)
 91: {
 92:   Mat_RARt *rart = (Mat_RARt *)data;

 94:   MatDestroy(&rart->Rt);
 95:   if (rart->destroy) (*rart->destroy)(rart->data);
 96:   PetscFree(rart);
 97:   return 0;
 98: }

100: PetscErrorCode MatProductNumeric_RARt_MPIAIJ_MPIAIJ(Mat C)
101: {
102:   Mat_RARt *rart;
103:   Mat       A, R, Rt;

105:   MatCheckProduct(C, 1);
107:   rart = (Mat_RARt *)C->product->data;
108:   A    = C->product->A;
109:   R    = C->product->B;
110:   Rt   = rart->Rt;
111:   MatTranspose(R, MAT_REUSE_MATRIX, &Rt);
112:   if (rart->data) C->product->data = rart->data;
113:   (*C->ops->matmatmultnumeric)(R, A, Rt, C);
114:   C->product->data = rart;
115:   return 0;
116: }

118: PetscErrorCode MatProductSymbolic_RARt_MPIAIJ_MPIAIJ(Mat C)
119: {
120:   Mat       A, R, Rt;
121:   Mat_RARt *rart;

123:   MatCheckProduct(C, 1);
125:   A = C->product->A;
126:   R = C->product->B;
127:   MatTranspose(R, MAT_INITIAL_MATRIX, &Rt);
128:   /* product->Dwork is used to store A*Rt in MatMatMatMultSymbolic_MPIAIJ_MPIAIJ_MPIAIJ() */
129:   MatMatMatMultSymbolic_MPIAIJ_MPIAIJ_MPIAIJ(R, A, Rt, C->product->fill, C);
130:   C->ops->productnumeric = MatProductNumeric_RARt_MPIAIJ_MPIAIJ;

132:   /* create a supporting struct */
133:   PetscNew(&rart);
134:   rart->Rt            = Rt;
135:   rart->data          = C->product->data;
136:   rart->destroy       = C->product->destroy;
137:   C->product->data    = rart;
138:   C->product->destroy = MatDestroy_MPIAIJ_RARt;
139:   return 0;
140: }