Actual source code: cudavecimpl.h
1: #ifndef PETSC_CUDAVECIMPL_H
2: #define PETSC_CUDAVECIMPL_H
4: #include <petscvec.h>
5: #include <petscdevice_cuda.h>
6: #include <petsc/private/deviceimpl.h>
7: #include <petsc/private/vecimpl.h>
/*
   Vec_CUDA - CUDA-specific data of a vector: the GPU array (and, when PETSc owns it, the
   allocation behind it), a stream for asynchronous transfers, and the work arrays used by the
   COO assembly interface.
*/
9: typedef struct {
10: PetscScalar *GPUarray; /* the array currently holding the vector's data on the GPU; always the one in use */
11: PetscScalar *GPUarray_allocated; /* if the array was allocated by PETSc, the pointer to that allocation */
12: cudaStream_t stream; /* stream used for asynchronous data transfers */
13: PetscBool nvshmem; /* is GPUarray_allocated allocated in NVSHMEM? Used to allocate Mvctx->lvec in NVSHMEM */
15: /* COO assembly data (the _d suffix presumably marks device-resident arrays -- TODO confirm) */
16: PetscCount *jmap1_d; /* [m+1]: the i-th entry of the vector has jmap1[i+1]-jmap1[i] repeats in the COO arrays */
17: PetscCount *perm1_d; /* [tot1]: permutation array for local entries */
18: PetscCount *imap2_d; /* [nnz2]: the i-th unique entry in recvbuf is the imap2[i]-th entry in the vector */
19: PetscCount *jmap2_d; /* [nnz2+1]: presumably the repeat offsets of the unique recvbuf entries, analogous to jmap1 -- TODO confirm */
20: PetscCount *perm2_d; /* [recvlen]: presumably the permutation array for received entries, analogous to perm1 -- TODO confirm */
21: PetscCount *Cperm_d; /* [sendlen]: permutation array used to fill sendbuf[]; 'C' stands for communication */
22: PetscScalar *sendbuf_d, *recvbuf_d; /* send and receive buffers for remote values in VecSetValuesCOO() */
23: } Vec_CUDA;
/*
   Prototypes of the routines that implement the CUDA vector types (the *_SeqCUDA and *_MPICUDA
   functions plus a few VecCUDA* helpers). Every routine returns a PetscErrorCode.

   Most are declared PETSC_INTERN. VecSet_SeqCUDA, VecCopy_SeqCUDA, VecAXPY_SeqCUDA and
   VecCreate_SeqCUDA are declared PETSC_EXTERN instead, presumably because they are also
   referenced from outside the library that defines them -- TODO confirm before changing either
   qualifier.
*/
25: PETSC_INTERN PetscErrorCode VecCUDAGetArrays_Private(Vec, const PetscScalar **, const PetscScalar **, PetscOffloadMask *);
26: PETSC_INTERN PetscErrorCode VecDotNorm2_SeqCUDA(Vec, Vec, PetscScalar *, PetscScalar *);
27: PETSC_INTERN PetscErrorCode VecPointwiseDivide_SeqCUDA(Vec, Vec, Vec);
28: PETSC_INTERN PetscErrorCode VecWAXPY_SeqCUDA(Vec, PetscScalar, Vec, Vec);
29: PETSC_INTERN PetscErrorCode VecMDot_SeqCUDA(Vec, PetscInt, const Vec[], PetscScalar *);
30: PETSC_EXTERN PetscErrorCode VecSet_SeqCUDA(Vec, PetscScalar);
31: PETSC_INTERN PetscErrorCode VecMAXPY_SeqCUDA(Vec, PetscInt, const PetscScalar *, Vec *);
32: PETSC_INTERN PetscErrorCode VecAXPBYPCZ_SeqCUDA(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec);
33: PETSC_INTERN PetscErrorCode VecPointwiseMult_SeqCUDA(Vec, Vec, Vec);
34: PETSC_INTERN PetscErrorCode VecPlaceArray_SeqCUDA(Vec, const PetscScalar *);
35: PETSC_INTERN PetscErrorCode VecResetArray_SeqCUDA(Vec);
36: PETSC_INTERN PetscErrorCode VecReplaceArray_SeqCUDA(Vec, const PetscScalar *);
37: PETSC_INTERN PetscErrorCode VecDot_SeqCUDA(Vec, Vec, PetscScalar *);
38: PETSC_INTERN PetscErrorCode VecTDot_SeqCUDA(Vec, Vec, PetscScalar *);
39: PETSC_INTERN PetscErrorCode VecScale_SeqCUDA(Vec, PetscScalar);
40: PETSC_EXTERN PetscErrorCode VecCopy_SeqCUDA(Vec, Vec);
41: PETSC_INTERN PetscErrorCode VecSwap_SeqCUDA(Vec, Vec);
42: PETSC_EXTERN PetscErrorCode VecAXPY_SeqCUDA(Vec, PetscScalar, Vec);
43: PETSC_INTERN PetscErrorCode VecAXPBY_SeqCUDA(Vec, PetscScalar, PetscScalar, Vec);
44: PETSC_INTERN PetscErrorCode VecDuplicate_SeqCUDA(Vec, Vec *);
45: PETSC_INTERN PetscErrorCode VecConjugate_SeqCUDA(Vec xin);
46: PETSC_INTERN PetscErrorCode VecNorm_SeqCUDA(Vec, NormType, PetscReal *);
47: PETSC_INTERN PetscErrorCode VecCUDACopyToGPU(Vec);
48: PETSC_INTERN PetscErrorCode VecCUDAAllocateCheck(Vec);
49: PETSC_EXTERN PetscErrorCode VecCreate_SeqCUDA(Vec);
50: PETSC_INTERN PetscErrorCode VecCreate_SeqCUDA_Private(Vec, const PetscScalar *);
51: PETSC_INTERN PetscErrorCode VecCreate_MPICUDA(Vec);
52: PETSC_INTERN PetscErrorCode VecCreate_MPICUDA_Private(Vec, PetscBool, PetscInt, const PetscScalar *);
53: PETSC_INTERN PetscErrorCode VecCreate_CUDA(Vec);
54: PETSC_INTERN PetscErrorCode VecDestroy_SeqCUDA(Vec);
55: PETSC_INTERN PetscErrorCode VecDestroy_MPICUDA(Vec);
56: PETSC_INTERN PetscErrorCode VecAYPX_SeqCUDA(Vec, PetscScalar, Vec);
57: PETSC_INTERN PetscErrorCode VecSetRandom_SeqCUDA(Vec, PetscRandom);
58: PETSC_INTERN PetscErrorCode VecGetLocalVector_SeqCUDA(Vec, Vec);
59: PETSC_INTERN PetscErrorCode VecRestoreLocalVector_SeqCUDA(Vec, Vec);
60: PETSC_INTERN PetscErrorCode VecGetLocalVectorRead_SeqCUDA(Vec, Vec);
61: PETSC_INTERN PetscErrorCode VecRestoreLocalVectorRead_SeqCUDA(Vec, Vec);
62: PETSC_INTERN PetscErrorCode VecGetArrayWrite_SeqCUDA(Vec, PetscScalar **);
63: PETSC_INTERN PetscErrorCode VecGetArray_SeqCUDA(Vec, PetscScalar **);
64: PETSC_INTERN PetscErrorCode VecRestoreArray_SeqCUDA(Vec, PetscScalar **);
65: PETSC_INTERN PetscErrorCode VecGetArrayAndMemType_SeqCUDA(Vec, PetscScalar **, PetscMemType *);
66: PETSC_INTERN PetscErrorCode VecRestoreArrayAndMemType_SeqCUDA(Vec, PetscScalar **);
67: PETSC_INTERN PetscErrorCode VecGetArrayWriteAndMemType_SeqCUDA(Vec, PetscScalar **, PetscMemType *);
68: PETSC_INTERN PetscErrorCode VecCopy_SeqCUDA_Private(Vec, Vec);
69: PETSC_INTERN PetscErrorCode VecDestroy_SeqCUDA_Private(Vec);
70: PETSC_INTERN PetscErrorCode VecResetArray_SeqCUDA_Private(Vec);
71: PETSC_INTERN PetscErrorCode VecMax_SeqCUDA(Vec, PetscInt *, PetscReal *);
72: PETSC_INTERN PetscErrorCode VecMin_SeqCUDA(Vec, PetscInt *, PetscReal *);
73: PETSC_INTERN PetscErrorCode VecReciprocal_SeqCUDA(Vec);
74: PETSC_INTERN PetscErrorCode VecSum_SeqCUDA(Vec, PetscScalar *);
75: PETSC_INTERN PetscErrorCode VecShift_SeqCUDA(Vec, PetscScalar);
/* COO assembly entry points; presumably these operate on the COO fields of Vec_CUDA -- TODO confirm */
76: PETSC_INTERN PetscErrorCode VecSetPreallocationCOO_SeqCUDA(Vec, PetscCount, const PetscInt[]);
77: PETSC_INTERN PetscErrorCode VecSetValuesCOO_SeqCUDA(Vec, const PetscScalar[], InsertMode);
/* NVSHMEM support; these declarations exist only when PETSC_HAVE_NVSHMEM is defined */
79: #if defined(PETSC_HAVE_NVSHMEM)
80: PETSC_EXTERN PetscErrorCode PetscNvshmemInitializeCheck(void);
81: PETSC_EXTERN PetscErrorCode PetscNvshmemMalloc(size_t, void **);
82: PETSC_EXTERN PetscErrorCode PetscNvshmemCalloc(size_t, void **);
83: PETSC_EXTERN PetscErrorCode PetscNvshmemFree_Private(void *);
/*
   PetscNvshmemFree - passes ptr to PetscNvshmemFree_Private() and, on success, resets ptr to NULL.

   The macro is an expression usable wherever a PetscErrorCode is expected:
     - ptr is NULL: nothing is called and the value is 0;
     - PetscNvshmemFree_Private() returns 0: ptr is set to NULL and the value is 0;
     - PetscNvshmemFree_Private() returns nonzero: ptr is left unchanged and the value is nonzero
       (the logical-OR collapses it to 1, so the specific code is not preserved).

   The failure of PetscNvshmemFree_Private() used to be dropped by a comma expression that always
   ended in 0; the short-circuit below lets callers see it.

   ptr must be a modifiable lvalue and is evaluated more than once, so it must have no side effects.
*/
84: #define PetscNvshmemFree(ptr) ((PetscErrorCode)((ptr) && (PetscNvshmemFree_Private(ptr) || ((ptr) = NULL, 0))))
85: PETSC_INTERN PetscErrorCode PetscNvshmemSum(PetscInt, PetscScalar *, const PetscScalar *);
86: PETSC_INTERN PetscErrorCode PetscNvshmemMax(PetscInt, PetscReal *, const PetscReal *);
87: PETSC_INTERN PetscErrorCode VecNormAsync_NVSHMEM(Vec, NormType, PetscReal *);
88: PETSC_INTERN PetscErrorCode VecAllocateNVSHMEM_SeqCUDA(Vec);
89: #endif /* PETSC_HAVE_NVSHMEM */
/*
   cublasX* - precision-generic names for the cuBLAS routines used with PetscScalar data.

   The routine is selected by PETSC_USE_COMPLEX and PETSC_USE_REAL_SINGLE. For complex scalars the
   wrappers are function-like macros that cast every scalar pointer argument to cuComplex * or
   cuDoubleComplex *; for real scalars they are plain aliases of the S/D routines.

   cublasXdot maps to the conjugated dot product (cublas?dotc) and cublasXdotu to the unconjugated
   one (cublas?dotu); for real scalars both map to the same cublas?dot routine.

   Macro parameters are named after the corresponding cuBLAS arguments.
*/
91: /* complex scalars */
92: #if defined(PETSC_USE_COMPLEX)
93: #if defined(PETSC_USE_REAL_SINGLE) /* complex single */
94: #define cublasXaxpy(handle, n, alpha, x, incx, y, incy) cublasCaxpy((handle), (n), (cuComplex *)(alpha), (cuComplex *)(x), (incx), (cuComplex *)(y), (incy))
95: #define cublasXscal(handle, n, alpha, x, incx) cublasCscal((handle), (n), (cuComplex *)(alpha), (cuComplex *)(x), (incx))
96: #define cublasXdotu(handle, n, x, incx, y, incy, result) cublasCdotu((handle), (n), (cuComplex *)(x), (incx), (cuComplex *)(y), (incy), (cuComplex *)(result))
97: #define cublasXdot(handle, n, x, incx, y, incy, result) cublasCdotc((handle), (n), (cuComplex *)(x), (incx), (cuComplex *)(y), (incy), (cuComplex *)(result))
98: #define cublasXswap(handle, n, x, incx, y, incy) cublasCswap((handle), (n), (cuComplex *)(x), (incx), (cuComplex *)(y), (incy))
99: #define cublasXnrm2(handle, n, x, incx, result) cublasScnrm2((handle), (n), (cuComplex *)(x), (incx), (result))
100: #define cublasIXamax(handle, n, x, incx, result) cublasIcamax((handle), (n), (cuComplex *)(x), (incx), (result))
101: #define cublasXasum(handle, n, x, incx, result) cublasScasum((handle), (n), (cuComplex *)(x), (incx), (result))
102: #define cublasXgemv(handle, trans, m, n, alpha, A, lda, x, incx, beta, y, incy) cublasCgemv((handle), (trans), (m), (n), (cuComplex *)(alpha), (cuComplex *)(A), (lda), (cuComplex *)(x), (incx), (cuComplex *)(beta), (cuComplex *)(y), (incy))
103: #define cublasXgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc) cublasCgemm((handle), (transa), (transb), (m), (n), (k), (cuComplex *)(alpha), (cuComplex *)(A), (lda), (cuComplex *)(B), (ldb), (cuComplex *)(beta), (cuComplex *)(C), (ldc))
104: #define cublasXgeam(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc) cublasCgeam((handle), (transa), (transb), (m), (n), (cuComplex *)(alpha), (cuComplex *)(A), (lda), (cuComplex *)(beta), (cuComplex *)(B), (ldb), (cuComplex *)(C), (ldc))
105: #else /* complex double */
106: #define cublasXaxpy(handle, n, alpha, x, incx, y, incy) cublasZaxpy((handle), (n), (cuDoubleComplex *)(alpha), (cuDoubleComplex *)(x), (incx), (cuDoubleComplex *)(y), (incy))
107: #define cublasXscal(handle, n, alpha, x, incx) cublasZscal((handle), (n), (cuDoubleComplex *)(alpha), (cuDoubleComplex *)(x), (incx))
108: #define cublasXdotu(handle, n, x, incx, y, incy, result) cublasZdotu((handle), (n), (cuDoubleComplex *)(x), (incx), (cuDoubleComplex *)(y), (incy), (cuDoubleComplex *)(result))
109: #define cublasXdot(handle, n, x, incx, y, incy, result) cublasZdotc((handle), (n), (cuDoubleComplex *)(x), (incx), (cuDoubleComplex *)(y), (incy), (cuDoubleComplex *)(result))
110: #define cublasXswap(handle, n, x, incx, y, incy) cublasZswap((handle), (n), (cuDoubleComplex *)(x), (incx), (cuDoubleComplex *)(y), (incy))
111: #define cublasXnrm2(handle, n, x, incx, result) cublasDznrm2((handle), (n), (cuDoubleComplex *)(x), (incx), (result))
112: #define cublasIXamax(handle, n, x, incx, result) cublasIzamax((handle), (n), (cuDoubleComplex *)(x), (incx), (result))
113: #define cublasXasum(handle, n, x, incx, result) cublasDzasum((handle), (n), (cuDoubleComplex *)(x), (incx), (result))
114: #define cublasXgemv(handle, trans, m, n, alpha, A, lda, x, incx, beta, y, incy) cublasZgemv((handle), (trans), (m), (n), (cuDoubleComplex *)(alpha), (cuDoubleComplex *)(A), (lda), (cuDoubleComplex *)(x), (incx), (cuDoubleComplex *)(beta), (cuDoubleComplex *)(y), (incy))
115: #define cublasXgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc) cublasZgemm((handle), (transa), (transb), (m), (n), (k), (cuDoubleComplex *)(alpha), (cuDoubleComplex *)(A), (lda), (cuDoubleComplex *)(B), (ldb), (cuDoubleComplex *)(beta), (cuDoubleComplex *)(C), (ldc))
116: #define cublasXgeam(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc) cublasZgeam((handle), (transa), (transb), (m), (n), (cuDoubleComplex *)(alpha), (cuDoubleComplex *)(A), (lda), (cuDoubleComplex *)(beta), (cuDoubleComplex *)(B), (ldb), (cuDoubleComplex *)(C), (ldc))
117: #endif /* PETSC_USE_REAL_SINGLE */
118: #else /* real scalars */
119: #if defined(PETSC_USE_REAL_SINGLE) /* real single */
120: #define cublasXaxpy  cublasSaxpy
121: #define cublasXscal  cublasSscal
122: #define cublasXdotu  cublasSdot
123: #define cublasXdot   cublasSdot
124: #define cublasXswap  cublasSswap
125: #define cublasXnrm2  cublasSnrm2
126: #define cublasIXamax cublasIsamax
127: #define cublasXasum  cublasSasum
128: #define cublasXgemv  cublasSgemv
129: #define cublasXgemm  cublasSgemm
130: #define cublasXgeam  cublasSgeam
131: #else /* real double */
132: #define cublasXaxpy  cublasDaxpy
133: #define cublasXscal  cublasDscal
134: #define cublasXdotu  cublasDdot
135: #define cublasXdot   cublasDdot
136: #define cublasXswap  cublasDswap
137: #define cublasXnrm2  cublasDnrm2
138: #define cublasIXamax cublasIdamax
139: #define cublasXasum  cublasDasum
140: #define cublasXgemv  cublasDgemv
141: #define cublasXgemm  cublasDgemm
142: #define cublasXgeam  cublasDgeam
143: #endif /* PETSC_USE_REAL_SINGLE */
144: #endif /* PETSC_USE_COMPLEX */
146: #endif // PETSC_CUDAVECIMPL_H