Actual source code: hipvecimpl.h

  1: #ifndef PETSC_HIPVECIMPL_H
  2: #define PETSC_HIPVECIMPL_H

  4: #include <petscvec.h>
  5: #include <petscdevice_hip.h>
  6: #include <petsc/private/deviceimpl.h>
  7: #include <petsc/private/vecimpl.h>

  9: typedef struct {
 10:   PetscScalar *GPUarray;           /* this always holds the GPU data */
 11:   PetscScalar *GPUarray_allocated; /* if the array was allocated by PETSc this is its pointer */
 12:   hipStream_t  stream;             /* A stream for doing asynchronous data transfers */
 13: } Vec_HIP;

 15: PETSC_INTERN PetscErrorCode VecHIPGetArrays_Private(Vec, const PetscScalar **, const PetscScalar **, PetscOffloadMask *);
 16: PETSC_INTERN PetscErrorCode VecDotNorm2_SeqHIP(Vec, Vec, PetscScalar *, PetscScalar *);
 17: PETSC_INTERN PetscErrorCode VecPointwiseDivide_SeqHIP(Vec, Vec, Vec);
 18: PETSC_INTERN PetscErrorCode VecWAXPY_SeqHIP(Vec, PetscScalar, Vec, Vec);
 19: PETSC_INTERN PetscErrorCode VecMDot_SeqHIP(Vec, PetscInt, const Vec[], PetscScalar *);
 20: PETSC_EXTERN PetscErrorCode VecSet_SeqHIP(Vec, PetscScalar);
 21: PETSC_INTERN PetscErrorCode VecMAXPY_SeqHIP(Vec, PetscInt, const PetscScalar *, Vec *);
 22: PETSC_INTERN PetscErrorCode VecAXPBYPCZ_SeqHIP(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec);
 23: PETSC_INTERN PetscErrorCode VecPointwiseMult_SeqHIP(Vec, Vec, Vec);
 24: PETSC_INTERN PetscErrorCode VecPlaceArray_SeqHIP(Vec, const PetscScalar *);
 25: PETSC_INTERN PetscErrorCode VecResetArray_SeqHIP(Vec);
 26: PETSC_INTERN PetscErrorCode VecReplaceArray_SeqHIP(Vec, const PetscScalar *);
 27: PETSC_INTERN PetscErrorCode VecDot_SeqHIP(Vec, Vec, PetscScalar *);
 28: PETSC_INTERN PetscErrorCode VecTDot_SeqHIP(Vec, Vec, PetscScalar *);
 29: PETSC_INTERN PetscErrorCode VecScale_SeqHIP(Vec, PetscScalar);
 30: PETSC_EXTERN PetscErrorCode VecCopy_SeqHIP(Vec, Vec);
 31: PETSC_INTERN PetscErrorCode VecSwap_SeqHIP(Vec, Vec);
 32: PETSC_EXTERN PetscErrorCode VecAXPY_SeqHIP(Vec, PetscScalar, Vec);
 33: PETSC_INTERN PetscErrorCode VecAXPBY_SeqHIP(Vec, PetscScalar, PetscScalar, Vec);
 34: PETSC_INTERN PetscErrorCode VecDuplicate_SeqHIP(Vec, Vec *);
 35: PETSC_INTERN PetscErrorCode VecConjugate_SeqHIP(Vec xin);
 36: PETSC_INTERN PetscErrorCode VecNorm_SeqHIP(Vec, NormType, PetscReal *);
 37: PETSC_INTERN PetscErrorCode VecHIPCopyToGPU(Vec);
 38: PETSC_INTERN PetscErrorCode VecHIPAllocateCheck(Vec);
 39: PETSC_EXTERN PetscErrorCode VecCreate_SeqHIP(Vec);
 40: PETSC_INTERN PetscErrorCode VecCreate_SeqHIP_Private(Vec, const PetscScalar *);
 41: PETSC_INTERN PetscErrorCode VecCreate_MPIHIP(Vec);
 42: PETSC_INTERN PetscErrorCode VecCreate_MPIHIP_Private(Vec, PetscBool, PetscInt, const PetscScalar *);
 43: PETSC_INTERN PetscErrorCode VecCreate_HIP(Vec);
 44: PETSC_INTERN PetscErrorCode VecDestroy_SeqHIP(Vec);
 45: PETSC_INTERN PetscErrorCode VecDestroy_MPIHIP(Vec);
 46: PETSC_INTERN PetscErrorCode VecAYPX_SeqHIP(Vec, PetscScalar, Vec);
 47: PETSC_INTERN PetscErrorCode VecSetRandom_SeqHIP(Vec, PetscRandom);
 48: PETSC_INTERN PetscErrorCode VecGetLocalVector_SeqHIP(Vec, Vec);
 49: PETSC_INTERN PetscErrorCode VecRestoreLocalVector_SeqHIP(Vec, Vec);
 50: PETSC_INTERN PetscErrorCode VecGetLocalVectorRead_SeqHIP(Vec, Vec);
 51: PETSC_INTERN PetscErrorCode VecRestoreLocalVectorRead_SeqHIP(Vec, Vec);
 52: PETSC_INTERN PetscErrorCode VecGetArrayWrite_SeqHIP(Vec, PetscScalar **);
 53: PETSC_INTERN PetscErrorCode VecGetArray_SeqHIP(Vec, PetscScalar **);
 54: PETSC_INTERN PetscErrorCode VecRestoreArray_SeqHIP(Vec, PetscScalar **);
 55: PETSC_INTERN PetscErrorCode VecGetArrayAndMemType_SeqHIP(Vec, PetscScalar **, PetscMemType *);
 56: PETSC_INTERN PetscErrorCode VecRestoreArrayAndMemType_SeqHIP(Vec, PetscScalar **);
 57: PETSC_INTERN PetscErrorCode VecGetArrayWriteAndMemType_SeqHIP(Vec, PetscScalar **, PetscMemType *);
 58: PETSC_INTERN PetscErrorCode VecCopy_SeqHIP_Private(Vec, Vec);
 59: PETSC_INTERN PetscErrorCode VecDestroy_SeqHIP_Private(Vec);
 60: PETSC_INTERN PetscErrorCode VecResetArray_SeqHIP_Private(Vec);
 61: PETSC_INTERN PetscErrorCode VecMax_SeqHIP(Vec, PetscInt *, PetscReal *);
 62: PETSC_INTERN PetscErrorCode VecMin_SeqHIP(Vec, PetscInt *, PetscReal *);
 63: PETSC_INTERN PetscErrorCode VecReciprocal_SeqHIP(Vec);
 64: PETSC_INTERN PetscErrorCode VecSum_SeqHIP(Vec, PetscScalar *);
 65: PETSC_INTERN PetscErrorCode VecShift_SeqHIP(Vec, PetscScalar);

/* hipblasX* names resolve to the hipBLAS routine for the configured scalar type; first branch: complex single */
 68: #if defined(PETSC_USE_COMPLEX)
 69:   #if defined(PETSC_USE_REAL_SINGLE)
 70:     #define hipblasXaxpy(a, b, c, d, e, f, g)                      hipblasCaxpy((a), (b), (hipblasComplex *)(c), (hipblasComplex *)(d), (e), (hipblasComplex *)(f), (g))
 71:     #define hipblasXscal(a, b, c, d, e)                            hipblasCscal((a), (b), (hipblasComplex *)(c), (hipblasComplex *)(d), (e))
 72:     #define hipblasXdotu(a, b, c, d, e, f, g)                      hipblasCdotu((a), (b), (hipblasComplex *)(c), (d), (hipblasComplex *)(e), (f), (hipblasComplex *)(g))
 73:     #define hipblasXdot(a, b, c, d, e, f, g)                       hipblasCdotc((a), (b), (hipblasComplex *)(c), (d), (hipblasComplex *)(e), (f), (hipblasComplex *)(g))
 74:     #define hipblasXswap(a, b, c, d, e, f)                         hipblasCswap((a), (b), (hipblasComplex *)(c), (d), (hipblasComplex *)(e), (f))
 75:     #define hipblasXnrm2(a, b, c, d, e)                            hipblasScnrm2((a), (b), (hipblasComplex *)(c), (d), (e))
 76:     #define hipblasIXamax(a, b, c, d, e)                           hipblasIcamax((a), (b), (hipblasComplex *)(c), (d), (e))
 77:     #define hipblasXasum(a, b, c, d, e)                            hipblasScasum((a), (b), (hipblasComplex *)(c), (d), (e))
 78:     #define hipblasXgemv(a, b, c, d, e, f, g, h, i, j, k, l)       hipblasCgemv((a), (b), (c), (d), (hipblasComplex *)(e), (hipblasComplex *)(f), (g), (hipblasComplex *)(h), (i), (hipblasComplex *)(j), (hipblasComplex *)(k), (l))
 79:     #define hipblasXgemm(a, b, c, d, e, f, g, h, i, j, k, l, m, n) hipblasCgemm((a), (b), (c), (d), (e), (f), (hipblasComplex *)(g), (hipblasComplex *)(h), (i), (hipblasComplex *)(j), (k), (hipblasComplex *)(l), (hipblasComplex *)(m), (n))
 80:     #define hipblasXgeam(a, b, c, d, e, f, g, h, i, j, k, l, m)    hipblasCgeam((a), (b), (c), (d), (e), (hipblasComplex *)(f), (hipblasComplex *)(g), (h), (hipblasComplex *)(i), (hipblasComplex *)(j), (k), (hipblasComplex *)(l), (m))
 81:   #else /* complex double */
 82:     #define hipblasXaxpy(a, b, c, d, e, f, g) hipblasZaxpy((a), (b), (hipblasDoubleComplex *)(c), (hipblasDoubleComplex *)(d), (e), (hipblasDoubleComplex *)(f), (g))
 83:     #define hipblasXscal(a, b, c, d, e)       hipblasZscal((a), (b), (hipblasDoubleComplex *)(c), (hipblasDoubleComplex *)(d), (e))
 84:     #define hipblasXdotu(a, b, c, d, e, f, g) hipblasZdotu((a), (b), (hipblasDoubleComplex *)(c), (d), (hipblasDoubleComplex *)(e), (f), (hipblasDoubleComplex *)(g))
 85:     #define hipblasXdot(a, b, c, d, e, f, g)  hipblasZdotc((a), (b), (hipblasDoubleComplex *)(c), (d), (hipblasDoubleComplex *)(e), (f), (hipblasDoubleComplex *)(g))
 86:     #define hipblasXswap(a, b, c, d, e, f)    hipblasZswap((a), (b), (hipblasDoubleComplex *)(c), (d), (hipblasDoubleComplex *)(e), (f))
 87:     #define hipblasXnrm2(a, b, c, d, e)       hipblasDznrm2((a), (b), (hipblasDoubleComplex *)(c), (d), (e))
 88:     #define hipblasIXamax(a, b, c, d, e)      hipblasIzamax((a), (b), (hipblasDoubleComplex *)(c), (d), (e))
 89:     #define hipblasXasum(a, b, c, d, e)       hipblasDzasum((a), (b), (hipblasDoubleComplex *)(c), (d), (e))
 90:     #define hipblasXgemv(a, b, c, d, e, f, g, h, i, j, k, l) \
 91:       hipblasZgemv((a), (b), (c), (d), (hipblasDoubleComplex *)(e), (hipblasDoubleComplex *)(f), (g), (hipblasDoubleComplex *)(h), (i), (hipblasDoubleComplex *)(j), (hipblasDoubleComplex *)(k), (l))
 92:     #define hipblasXgemm(a, b, c, d, e, f, g, h, i, j, k, l, m, n) \
 93:       hipblasZgemm((a), (b), (c), (d), (e), (f), (hipblasDoubleComplex *)(g), (hipblasDoubleComplex *)(h), (i), (hipblasDoubleComplex *)(j), (k), (hipblasDoubleComplex *)(l), (hipblasDoubleComplex *)(m), (n))
 94:     #define hipblasXgeam(a, b, c, d, e, f, g, h, i, j, k, l, m) \
 95:       hipblasZgeam((a), (b), (c), (d), (e), (hipblasDoubleComplex *)(f), (hipblasDoubleComplex *)(g), (h), (hipblasDoubleComplex *)(i), (hipblasDoubleComplex *)(j), (k), (hipblasDoubleComplex *)(l), (m))
 96:   #endif
#else /* real scalars: single precision first, then double */
 98:   #if defined(PETSC_USE_REAL_SINGLE)
 99:     #define hipblasXaxpy  hipblasSaxpy
100:     #define hipblasXscal  hipblasSscal
101:     #define hipblasXdotu  hipblasSdot
102:     #define hipblasXdot   hipblasSdot
103:     #define hipblasXswap  hipblasSswap
104:     #define hipblasXnrm2  hipblasSnrm2
105:     #define hipblasIXamax hipblasIsamax
106:     #define hipblasXasum  hipblasSasum
107:     #define hipblasXgemv  hipblasSgemv
108:     #define hipblasXgemm  hipblasSgemm
109:     #define hipblasXgeam  hipblasSgeam
110:   #else /* real double */
111:     #define hipblasXaxpy  hipblasDaxpy
112:     #define hipblasXscal  hipblasDscal
113:     #define hipblasXdotu  hipblasDdot
114:     #define hipblasXdot   hipblasDdot
115:     #define hipblasXswap  hipblasDswap
116:     #define hipblasXnrm2  hipblasDnrm2
117:     #define hipblasIXamax hipblasIdamax
118:     #define hipblasXasum  hipblasDasum
119:     #define hipblasXgemv  hipblasDgemv
120:     #define hipblasXgemm  hipblasDgemm
121:     #define hipblasXgeam  hipblasDgeam
122:   #endif
123: #endif

125: #endif // PETSC_HIPVECIMPL_H