Actual source code: mpimattransposematmult.c


/*
  Defines matrix-matrix product routines for an MPIAIJ matrix A and an MPIDense matrix B:
          C = A^T * B
  The routines are slightly modified from MatTransposeMatMultxxx_SeqAIJ_SeqDense().
*/
  7: #include <../src/mat/impls/aij/seq/aij.h>
  8: #include <../src/mat/impls/aij/mpi/mpiaij.h>
  9: #include <../src/mat/impls/dense/mpi/mpidense.h>

 11: PetscErrorCode MatDestroy_MPIDense_MatTransMatMult(void *data)
 12: {
 13:   Mat_MatTransMatMult *atb = (Mat_MatTransMatMult *)data;

 15:   MatDestroy(&atb->mA);
 16:   VecDestroy(&atb->bt);
 17:   VecDestroy(&atb->ct);
 18:   PetscFree(atb);
 19:   return 0;
 20: }

 22: static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat, Mat, Mat);

 24: PETSC_INTERN PetscErrorCode MatTransposeMatMultSymbolic_MPIAIJ_MPIDense(Mat A, Mat B, PetscReal fill, Mat C)
 25: {
 26:   Mat_MatTransMatMult *atb;
 27:   PetscBool            cisdense;

 29:   MatCheckProduct(C, 4);

 32:   /* create output dense matrix C = A^T*B */
 33:   MatSetSizes(C, A->cmap->n, B->cmap->n, A->cmap->N, B->cmap->N);
 34:   PetscObjectTypeCompareAny((PetscObject)C, &cisdense, MATMPIDENSE, MATMPIDENSECUDA, "");
 35:   if (!cisdense) MatSetType(C, ((PetscObject)B)->type_name);
 36:   MatSetUp(C);

 38:   /* create additional data structure for the product */
 39:   PetscNew(&atb);
 40:   if (B->cmap->N) {
 41:     MatCreateMAIJ(A, B->cmap->N, &atb->mA);
 42:     if (!atb->mA->assembled) {
 43:       MatAssemblyBegin(atb->mA, MAT_FINAL_ASSEMBLY);
 44:       MatAssemblyEnd(atb->mA, MAT_FINAL_ASSEMBLY);
 45:     }
 46:     MatCreateVecs(atb->mA, &atb->ct, &atb->bt);
 47:   }
 48:   C->product->data    = atb;
 49:   C->product->destroy = MatDestroy_MPIDense_MatTransMatMult;

 51:   C->ops->transposematmultnumeric = MatTransposeMatMultNumeric_MPIAIJ_MPIDense;
 52:   return 0;
 53: }

 55: static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat A, Mat B, Mat C)
 56: {
 57:   const PetscScalar   *Barray, *ctarray;
 58:   PetscScalar         *Carray, *btarray;
 59:   PetscInt             i, j, m = A->rmap->n, n = A->cmap->n, ldb, BN = B->cmap->N, ldc;
 60:   Mat_MatTransMatMult *atb;
 61:   Vec                  bt, ct;

 63:   MatCheckProduct(C, 3);
 64:   atb = (Mat_MatTransMatMult *)C->product->data;
 66:   if (!BN) {
 67:     MatAssemblyBegin(C, MAT_FINAL_ASSEMBLY);
 68:     MatAssemblyEnd(C, MAT_FINAL_ASSEMBLY);
 69:     return 0;
 70:   }
 71:   bt = atb->bt;
 72:   ct = atb->ct;

 74:   /* transpose local array of B, then copy it to vector bt */
 75:   MatDenseGetArrayRead(B, &Barray);
 76:   MatDenseGetLDA(B, &ldb);
 77:   VecGetArray(bt, &btarray);
 78:   for (j = 0; j < BN; j++)
 79:     for (i = 0; i < m; i++) btarray[i * BN + j] = Barray[j * ldb + i];
 80:   VecRestoreArray(bt, &btarray);
 81:   MatDenseRestoreArrayRead(B, &Barray);

 83:   /* compute ct = mA^T * cb */
 84:   MatMultTranspose(atb->mA, bt, ct);

 86:   /* transpose local array of ct to matrix C */
 87:   MatDenseGetArray(C, &Carray);
 88:   MatDenseGetLDA(C, &ldc);
 89:   VecGetArrayRead(ct, &ctarray);
 90:   for (j = 0; j < BN; j++)
 91:     for (i = 0; i < n; i++) Carray[j * ldc + i] = ctarray[i * BN + j];
 92:   VecRestoreArrayRead(ct, &ctarray);
 93:   MatDenseRestoreArray(C, &Carray);
 94:   MatAssemblyBegin(C, MAT_FINAL_ASSEMBLY);
 95:   MatAssemblyEnd(C, MAT_FINAL_ASSEMBLY);
 96:   return 0;
 97: }