Actual source code: sfnvshmem.cu

  1: #include <petsc/private/cudavecimpl.h>
  2: #include <../src/vec/is/sf/impls/basic/sfpack.h>
  3: #include <mpi.h>
  4: #include <nvshmem.h>
  5: #include <nvshmemx.h>

  7: PetscErrorCode PetscNvshmemInitializeCheck(void)
  8: {
  9:   if (!PetscNvshmemInitialized) { /* Note NVSHMEM does not provide a routine to check whether it is initialized */
 10:     nvshmemx_init_attr_t attr;
 11:     attr.mpi_comm = &PETSC_COMM_WORLD;
 12:     PetscDeviceInitialize(PETSC_DEVICE_CUDA);
 13:     nvshmemx_init_attr(NVSHMEMX_INIT_WITH_MPI_COMM, &attr);
 14:     PetscNvshmemInitialized = PETSC_TRUE;
 15:     PetscBeganNvshmem       = PETSC_TRUE;
 16:   }
 17:   return 0;
 18: }

 20: PetscErrorCode PetscNvshmemMalloc(size_t size, void **ptr)
 21: {
 22:   PetscNvshmemInitializeCheck();
 23:   *ptr = nvshmem_malloc(size);
 25:   return 0;
 26: }

 28: PetscErrorCode PetscNvshmemCalloc(size_t size, void **ptr)
 29: {
 30:   PetscNvshmemInitializeCheck();
 31:   *ptr = nvshmem_calloc(size, 1);
 33:   return 0;
 34: }

 36: PetscErrorCode PetscNvshmemFree_Private(void *ptr)
 37: {
 38:   nvshmem_free(ptr);
 39:   return 0;
 40: }

 42: PetscErrorCode PetscNvshmemFinalize(void)
 43: {
 44:   nvshmem_finalize();
 45:   return 0;
 46: }

 48: /* Free nvshmem related fields in the SF */
 49: PetscErrorCode PetscSFReset_Basic_NVSHMEM(PetscSF sf)
 50: {
 51:   PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;

 53:   PetscFree2(bas->leafsigdisp, bas->leafbufdisp);
 54:   PetscSFFree(sf, PETSC_MEMTYPE_CUDA, bas->leafbufdisp_d);
 55:   PetscSFFree(sf, PETSC_MEMTYPE_CUDA, bas->leafsigdisp_d);
 56:   PetscSFFree(sf, PETSC_MEMTYPE_CUDA, bas->iranks_d);
 57:   PetscSFFree(sf, PETSC_MEMTYPE_CUDA, bas->ioffset_d);

 59:   PetscFree2(sf->rootsigdisp, sf->rootbufdisp);
 60:   PetscSFFree(sf, PETSC_MEMTYPE_CUDA, sf->rootbufdisp_d);
 61:   PetscSFFree(sf, PETSC_MEMTYPE_CUDA, sf->rootsigdisp_d);
 62:   PetscSFFree(sf, PETSC_MEMTYPE_CUDA, sf->ranks_d);
 63:   PetscSFFree(sf, PETSC_MEMTYPE_CUDA, sf->roffset_d);
 64:   return 0;
 65: }

 67: /* Set up NVSHMEM-related fields for an SF of type SFBASIC (only after PetscSFSetUp_Basic() has already set up the fields they depend on) */
 68: static PetscErrorCode PetscSFSetUp_Basic_NVSHMEM(PetscSF sf)
 69: {
 70:   cudaError_t    cerr;
 71:   PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
 72:   PetscInt       i, nRemoteRootRanks, nRemoteLeafRanks;
 73:   PetscMPIInt    tag;
 74:   MPI_Comm       comm;
 75:   MPI_Request   *rootreqs, *leafreqs;
 76:   PetscInt       tmp, stmp[4], rtmp[4]; /* tmps for send/recv buffers */

 78:   PetscObjectGetComm((PetscObject)sf, &comm);
 79:   PetscObjectGetNewTag((PetscObject)sf, &tag);

 81:   nRemoteRootRanks      = sf->nranks - sf->ndranks;
 82:   nRemoteLeafRanks      = bas->niranks - bas->ndiranks;
 83:   sf->nRemoteRootRanks  = nRemoteRootRanks;
 84:   bas->nRemoteLeafRanks = nRemoteLeafRanks;

 86:   PetscMalloc2(nRemoteLeafRanks, &rootreqs, nRemoteRootRanks, &leafreqs);

 88:   stmp[0] = nRemoteRootRanks;
 89:   stmp[1] = sf->leafbuflen[PETSCSF_REMOTE];
 90:   stmp[2] = nRemoteLeafRanks;
 91:   stmp[3] = bas->rootbuflen[PETSCSF_REMOTE];

 93:   MPIU_Allreduce(stmp, rtmp, 4, MPIU_INT, MPI_MAX, comm);

 95:   sf->nRemoteRootRanksMax  = rtmp[0];
 96:   sf->leafbuflen_rmax      = rtmp[1];
 97:   bas->nRemoteLeafRanksMax = rtmp[2];
 98:   bas->rootbuflen_rmax     = rtmp[3];

100:   /* In total, four rounds of MPI communication are used to set up the nvshmem fields */
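       /* The four rounds below are, in order:
            1) each root rank tells each of its remote leaf ranks which slot of the root's signal array is reserved for it;
               leaves receive these into sf->rootsigdisp[]
            2) each root rank sends the displacement of each remote leaf rank's portion of its root buffer;
               leaves receive these into sf->rootbufdisp[]
            3) each leaf rank tells each of its remote root ranks which slot of the leaf's signal array is reserved for it;
               roots receive these into bas->leafsigdisp[]
            4) each leaf rank sends the displacement of each remote root rank's portion of its leaf buffer;
               roots receive these into bas->leafbufdisp[]
          Each round is an MPI_Irecv/MPI_Send/MPI_Waitall exchange over the remote (non-distinguished) ranks only. */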

102:   /* Root ranks to leaf ranks: send info about rootsigdisp[] and rootbufdisp[] */
103:   PetscMalloc2(nRemoteRootRanks, &sf->rootsigdisp, nRemoteRootRanks, &sf->rootbufdisp);
104:   for (i = 0; i < nRemoteRootRanks; i++) MPI_Irecv(&sf->rootsigdisp[i], 1, MPIU_INT, sf->ranks[i + sf->ndranks], tag, comm, &leafreqs[i]); /* Leaves recv */
105:   for (i = 0; i < nRemoteLeafRanks; i++) MPI_Send(&i, 1, MPIU_INT, bas->iranks[i + bas->ndiranks], tag, comm);                             /* Roots send. Note i changes, so we use MPI_Send. */
106:   MPI_Waitall(nRemoteRootRanks, leafreqs, MPI_STATUSES_IGNORE);

108:   for (i = 0; i < nRemoteRootRanks; i++) MPI_Irecv(&sf->rootbufdisp[i], 1, MPIU_INT, sf->ranks[i + sf->ndranks], tag, comm, &leafreqs[i]); /* Leaves recv */
109:   for (i = 0; i < nRemoteLeafRanks; i++) {
110:     tmp = bas->ioffset[i + bas->ndiranks] - bas->ioffset[bas->ndiranks];
111:     MPI_Send(&tmp, 1, MPIU_INT, bas->iranks[i + bas->ndiranks], tag, comm); /* Roots send. Note tmp changes, so we use MPI_Send. */
112:   }
113:   MPI_Waitall(nRemoteRootRanks, leafreqs, MPI_STATUSES_IGNORE);

115:   cudaMalloc((void **)&sf->rootbufdisp_d, nRemoteRootRanks * sizeof(PetscInt));
116:   cudaMalloc((void **)&sf->rootsigdisp_d, nRemoteRootRanks * sizeof(PetscInt));
117:   cudaMalloc((void **)&sf->ranks_d, nRemoteRootRanks * sizeof(PetscMPIInt));
118:   cudaMalloc((void **)&sf->roffset_d, (nRemoteRootRanks + 1) * sizeof(PetscInt));

120:   cudaMemcpyAsync(sf->rootbufdisp_d, sf->rootbufdisp, nRemoteRootRanks * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream);
121:   cudaMemcpyAsync(sf->rootsigdisp_d, sf->rootsigdisp, nRemoteRootRanks * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream);
122:   cudaMemcpyAsync(sf->ranks_d, sf->ranks + sf->ndranks, nRemoteRootRanks * sizeof(PetscMPIInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream);
123:   cudaMemcpyAsync(sf->roffset_d, sf->roffset + sf->ndranks, (nRemoteRootRanks + 1) * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream);

125:   /* Leaf ranks to root ranks: send info about leafsigdisp[] and leafbufdisp[] */
126:   PetscMalloc2(nRemoteLeafRanks, &bas->leafsigdisp, nRemoteLeafRanks, &bas->leafbufdisp);
127:   for (i = 0; i < nRemoteLeafRanks; i++) MPI_Irecv(&bas->leafsigdisp[i], 1, MPIU_INT, bas->iranks[i + bas->ndiranks], tag, comm, &rootreqs[i]);
128:   for (i = 0; i < nRemoteRootRanks; i++) MPI_Send(&i, 1, MPIU_INT, sf->ranks[i + sf->ndranks], tag, comm);
129:   MPI_Waitall(nRemoteLeafRanks, rootreqs, MPI_STATUSES_IGNORE);

131:   for (i = 0; i < nRemoteLeafRanks; i++) MPI_Irecv(&bas->leafbufdisp[i], 1, MPIU_INT, bas->iranks[i + bas->ndiranks], tag, comm, &rootreqs[i]);
132:   for (i = 0; i < nRemoteRootRanks; i++) {
133:     tmp = sf->roffset[i + sf->ndranks] - sf->roffset[sf->ndranks];
134:     MPI_Send(&tmp, 1, MPIU_INT, sf->ranks[i + sf->ndranks], tag, comm);
135:   }
136:   MPI_Waitall(nRemoteLeafRanks, rootreqs, MPI_STATUSES_IGNORE);

138:   cudaMalloc((void **)&bas->leafbufdisp_d, nRemoteLeafRanks * sizeof(PetscInt));
139:   cudaMalloc((void **)&bas->leafsigdisp_d, nRemoteLeafRanks * sizeof(PetscInt));
140:   cudaMalloc((void **)&bas->iranks_d, nRemoteLeafRanks * sizeof(PetscMPIInt));
141:   cudaMalloc((void **)&bas->ioffset_d, (nRemoteLeafRanks + 1) * sizeof(PetscInt));

143:   cudaMemcpyAsync(bas->leafbufdisp_d, bas->leafbufdisp, nRemoteLeafRanks * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream);
144:   cudaMemcpyAsync(bas->leafsigdisp_d, bas->leafsigdisp, nRemoteLeafRanks * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream);
145:   cudaMemcpyAsync(bas->iranks_d, bas->iranks + bas->ndiranks, nRemoteLeafRanks * sizeof(PetscMPIInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream);
146:   cudaMemcpyAsync(bas->ioffset_d, bas->ioffset + bas->ndiranks, (nRemoteLeafRanks + 1) * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream);

148:   PetscFree2(rootreqs, leafreqs);
149:   return 0;
150: }
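
     /* At this point both sides of the SF have device-resident copies of the metadata the NVSHMEM kernels below need:
        {root,leaf}bufdisp_d, {root,leaf}sigdisp_d, ranks_d/iranks_d and roffset_d/ioffset_d. Note the cudaMemcpyAsync()
        calls above are merely queued on PetscDefaultCudaStream. */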

152: PetscErrorCode PetscSFLinkNvshmemCheck(PetscSF sf, PetscMemType rootmtype, const void *rootdata, PetscMemType leafmtype, const void *leafdata, PetscBool *use_nvshmem)
153: {
154:   MPI_Comm    comm;
155:   PetscBool   isBasic;
156:   PetscMPIInt result = MPI_UNEQUAL;

158:   PetscObjectGetComm((PetscObject)sf, &comm);
159:   /* Check if the sf is eligible for NVSHMEM, if we have not checked yet.
160:      Note the check result <use_nvshmem> must be the same over comm, since an SFLink must be collectively either NVSHMEM or MPI.
161:   */
163:   if (sf->use_nvshmem && !sf->checked_nvshmem_eligibility) {
164:     /* Only use NVSHMEM for SFBASIC on PETSC_COMM_WORLD  */
165:     PetscObjectTypeCompare((PetscObject)sf, PETSCSFBASIC, &isBasic);
166:     if (isBasic) MPI_Comm_compare(PETSC_COMM_WORLD, comm, &result);
167:     if (!isBasic || (result != MPI_IDENT && result != MPI_CONGRUENT)) sf->use_nvshmem = PETSC_FALSE; /* If not eligible, clear the flag so that we don't try again */

 169:     /* Do a further check: if on some rank both rootdata and leafdata are NULL, we might deem them PETSC_MEMTYPE_CUDA (or HOST)
 170:        and then use NVSHMEM. But if root/leafmtypes on other ranks are PETSC_MEMTYPE_HOST (or DEVICE), that would make the
 171:        return value <use_nvshmem> inconsistent across ranks. To be safe, we simply disable nvshmem on these rare SFs.
172:     */
173:     if (sf->use_nvshmem) {
174:       PetscInt hasNullRank = (!rootdata && !leafdata) ? 1 : 0;
175:       MPI_Allreduce(MPI_IN_PLACE, &hasNullRank, 1, MPIU_INT, MPI_LOR, comm);
176:       if (hasNullRank) sf->use_nvshmem = PETSC_FALSE;
177:     }
178:     sf->checked_nvshmem_eligibility = PETSC_TRUE; /* If eligible, don't do above check again */
179:   }

181:   /* Check if rootmtype and leafmtype collectively are PETSC_MEMTYPE_CUDA */
182:   if (sf->use_nvshmem) {
183:     PetscInt oneCuda = (!rootdata || PetscMemTypeCUDA(rootmtype)) && (!leafdata || PetscMemTypeCUDA(leafmtype)) ? 1 : 0; /* Do I use cuda for both root&leafmtype? */
184:     PetscInt allCuda = oneCuda;                                                                                          /* Assume the same for all ranks. But if not, in opt mode, return value <use_nvshmem> won't be collective! */
185: #if defined(PETSC_USE_DEBUG)                                                                                             /* Check in debug mode. Note MPI_Allreduce is expensive, so only in debug mode */
186:     MPI_Allreduce(&oneCuda, &allCuda, 1, MPIU_INT, MPI_LAND, comm);
188: #endif
189:     if (allCuda) {
190:       PetscNvshmemInitializeCheck();
191:       if (!sf->setup_nvshmem) { /* Set up nvshmem related fields on this SF on-demand */
192:         PetscSFSetUp_Basic_NVSHMEM(sf);
193:         sf->setup_nvshmem = PETSC_TRUE;
194:       }
195:       *use_nvshmem = PETSC_TRUE;
196:     } else {
197:       *use_nvshmem = PETSC_FALSE;
198:     }
199:   } else {
200:     *use_nvshmem = PETSC_FALSE;
201:   }
202:   return 0;
203: }
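
     /* A sketch of how a caller is expected to use this routine (the variable names below are hypothetical, not the
        actual PetscSF call sites):

          PetscBool use_nvshmem;
          PetscSFLinkNvshmemCheck(sf, rootmtype, rootdata, leafmtype, leafdata, &use_nvshmem);
          if (use_nvshmem) {
            // take the NVSHMEM path, e.g., create/look up a link via PetscSFLinkCreate_NVSHMEM()
          } else {
            // fall back to the MPI communication path
          }

        Since the result is guaranteed to be the same on every rank of the communicator, each rank can branch on it
        independently. */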

205: /* Build dependence between <stream> and <remoteCommStream> at the entry of NVSHMEM communication */
206: static PetscErrorCode PetscSFLinkBuildDependenceBegin(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
207: {
208:   cudaError_t    cerr;
209:   PetscSF_Basic *bas    = (PetscSF_Basic *)sf->data;
 210:   PetscInt       buflen = (direction == PETSCSF_ROOT2LEAF) ? bas->rootbuflen[PETSCSF_REMOTE] : sf->leafbuflen[PETSCSF_REMOTE];

212:   if (buflen) {
213:     cudaEventRecord(link->dataReady, link->stream);
214:     cudaStreamWaitEvent(link->remoteCommStream, link->dataReady, 0);
215:   }
216:   return 0;
217: }

219: /* Build dependence between <stream> and <remoteCommStream> at the exit of NVSHMEM communication */
220: static PetscErrorCode PetscSFLinkBuildDependenceEnd(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
221: {
222:   cudaError_t    cerr;
223:   PetscSF_Basic *bas    = (PetscSF_Basic *)sf->data;
 224:   PetscInt       buflen = (direction == PETSCSF_ROOT2LEAF) ? sf->leafbuflen[PETSCSF_REMOTE] : bas->rootbuflen[PETSCSF_REMOTE];

226:   /* If unpack to non-null device buffer, build the endRemoteComm dependence */
227:   if (buflen) {
228:     cudaEventRecord(link->endRemoteComm, link->remoteCommStream);
229:     cudaStreamWaitEvent(link->stream, link->endRemoteComm, 0);
230:   }
231:   return 0;
232: }
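
     /* Together, PetscSFLinkBuildDependenceBegin/End() bracket the NVSHMEM work: Begin makes link->remoteCommStream
        wait (via the dataReady event) for work previously issued on link->stream (e.g., packing), and End makes
        link->stream wait (via the endRemoteComm event) for the communication issued on link->remoteCommStream, so a
        later unpack sees completed data. No host-side synchronization is involved. */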

234: /* Send/Put signals to remote ranks

236:  Input parameters:
237:   + n        - Number of remote ranks
238:   . sig      - Signal address in symmetric heap
239:   . sigdisp  - To i-th rank, use its signal at offset sigdisp[i]
240:   . ranks    - remote ranks
241:   - newval   - Set signals to this value
242: */
243: __global__ static void NvshmemSendSignals(PetscInt n, uint64_t *sig, PetscInt *sigdisp, PetscMPIInt *ranks, uint64_t newval)
244: {
245:   int i = blockIdx.x * blockDim.x + threadIdx.x;

247:   /* Each thread puts one remote signal */
248:   if (i < n) nvshmemx_uint64_signal(sig + sigdisp[i], newval, ranks[i]);
249: }
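
     /* Typical launch in this file, with one thread per destination rank; e.g. (as in the put/get paths below)

          NvshmemSendSignals<<<(ndstranks + 255) / 256, 256, 0, link->remoteCommStream>>>(ndstranks, dstsig, dstsigdisp_d, dstranks_d, 1);

        sets, on each of the ndstranks remote PEs, the signal slot this PE owns to 1. */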

 251: /* Wait until the local signals equal the expected value, then set them to a new value

253:  Input parameters:
254:   + n        - Number of signals
255:   . sig      - Local signal address
256:   . expval   - expected value
257:   - newval   - Set signals to this new value
258: */
259: __global__ static void NvshmemWaitSignals(PetscInt n, uint64_t *sig, uint64_t expval, uint64_t newval)
260: {
261: #if 0
262:   /* Akhil Langer@NVIDIA said using 1 thread and nvshmem_uint64_wait_until_all is better */
263:   int i = blockIdx.x*blockDim.x + threadIdx.x;
264:   if (i < n) {
265:     nvshmem_signal_wait_until(sig+i,NVSHMEM_CMP_EQ,expval);
266:     sig[i] = newval;
267:   }
268: #else
269:   nvshmem_uint64_wait_until_all(sig, n, NULL /*no mask*/, NVSHMEM_CMP_EQ, expval);
270:   for (int i = 0; i < n; i++) sig[i] = newval;
271: #endif
272: }
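
     /* Callers in this file launch this kernel with a single thread, e.g.
        NvshmemWaitSignals<<<1, 1, 0, link->remoteCommStream>>>(n, sig, 0, 1), since
        nvshmem_uint64_wait_until_all() above waits on all n signals from one thread anyway. */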

274: /* ===========================================================================================================

276:    A set of routines to support receiver initiated communication using the get method

278:     The getting protocol is:

280:     Sender has a send buf (sbuf) and a signal variable (ssig);  Receiver has a recv buf (rbuf) and a signal variable (rsig);
281:     All signal variables have an initial value 0.

283:     Sender:                                 |  Receiver:
 284:   1.  Wait until ssig is 0, then set it to 1|
 285:   2.  Pack data into stand-alone sbuf       |
 286:   3.  Put 1 to receiver's rsig              |   1. Wait until rsig is 1, then set it to 0
287:                                             |   2. Get data from remote sbuf to local rbuf
 288:                                             |   3. Put 0 to sender's ssig
289:                                             |   4. Unpack data from local rbuf
290:    ===========================================================================================================*/
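     /* In this file the getting protocol maps onto the link callbacks (set in PetscSFLinkCreate_NVSHMEM()) roughly as follows:
          sender step 1                      -> PetscSFLinkWaitSignalsOfCompletionOfGettingData_NVSHMEM (PrePack)
          sender step 3 + receiver steps 1-2 -> PetscSFLinkGetDataBegin_NVSHMEM (StartCommunication)
          receiver step 3                    -> PetscSFLinkGetDataEnd_NVSHMEM (FinishCommunication)
        Packing (sender step 2) and unpacking (receiver step 4) use the pack routines set up by PetscSFLinkSetUp_CUDA(). */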
 291: /* PrePack operation -- since the sender will overwrite the send buffer, which receivers might still be getting data from,
 292:    the sender waits for signals (from the receivers) indicating that they have finished getting the data
293: */
294: PetscErrorCode PetscSFLinkWaitSignalsOfCompletionOfGettingData_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
295: {
296:   PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
297:   uint64_t      *sig;
298:   PetscInt       n;

 300:   if (direction == PETSCSF_ROOT2LEAF) { /* leaf ranks are getting data */
 301:     sig = link->rootSendSig;            /* leaf ranks set my rootSendSig */
302:     n   = bas->nRemoteLeafRanks;
 303:   } else { /* LEAF2ROOT */
304:     sig = link->leafSendSig;
305:     n   = sf->nRemoteRootRanks;
306:   }

308:   if (n) {
309:     NvshmemWaitSignals<<<1, 1, 0, link->remoteCommStream>>>(n, sig, 0, 1); /* wait the signals to be 0, then set them to 1 */
310:     cudaGetLastError();
311:   }
312:   return 0;
313: }

 315: /* n thread blocks; each handles one remote rank */
316: __global__ static void GetDataFromRemotelyAccessible(PetscInt nsrcranks, PetscMPIInt *srcranks, const char *src, PetscInt *srcdisp, char *dst, PetscInt *dstdisp, PetscInt unitbytes)
317: {
318:   int         bid = blockIdx.x;
319:   PetscMPIInt pe  = srcranks[bid];

321:   if (!nvshmem_ptr(src, pe)) {
322:     PetscInt nelems = (dstdisp[bid + 1] - dstdisp[bid]) * unitbytes;
323:     nvshmem_getmem_nbi(dst + (dstdisp[bid] - dstdisp[0]) * unitbytes, src + srcdisp[bid] * unitbytes, nelems, pe);
324:   }
325: }

327: /* Start communication -- Get data in the given direction */
328: PetscErrorCode PetscSFLinkGetDataBegin_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
329: {
330:   cudaError_t    cerr;
331:   PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;

333:   PetscInt nsrcranks, ndstranks, nLocallyAccessible = 0;

335:   char        *src, *dst;
336:   PetscInt    *srcdisp_h, *dstdisp_h;
337:   PetscInt    *srcdisp_d, *dstdisp_d;
338:   PetscMPIInt *srcranks_h;
339:   PetscMPIInt *srcranks_d, *dstranks_d;
340:   uint64_t    *dstsig;
341:   PetscInt    *dstsigdisp_d;

343:   PetscSFLinkBuildDependenceBegin(sf, link, direction);
 344:   if (direction == PETSCSF_ROOT2LEAF) { /* src is root, dst is leaf; we will move data from src to dst */
345:     nsrcranks = sf->nRemoteRootRanks;
346:     src       = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* root buf is the send buf; it is in symmetric heap */

348:     srcdisp_h  = sf->rootbufdisp; /* for my i-th remote root rank, I will access its buf at offset rootbufdisp[i] */
349:     srcdisp_d  = sf->rootbufdisp_d;
350:     srcranks_h = sf->ranks + sf->ndranks; /* my (remote) root ranks */
351:     srcranks_d = sf->ranks_d;

353:     ndstranks = bas->nRemoteLeafRanks;
354:     dst       = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* recv buf is the local leaf buf, also in symmetric heap */

356:     dstdisp_h  = sf->roffset + sf->ndranks; /* offsets of the local leaf buf. Note dstdisp[0] is not necessarily 0 */
357:     dstdisp_d  = sf->roffset_d;
358:     dstranks_d = bas->iranks_d; /* my (remote) leaf ranks */

360:     dstsig       = link->leafRecvSig;
361:     dstsigdisp_d = bas->leafsigdisp_d;
362:   } else { /* src is leaf, dst is root; we will move data from src to dst */
363:     nsrcranks = bas->nRemoteLeafRanks;
364:     src       = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* leaf buf is the send buf */

 366:     srcdisp_h  = bas->leafbufdisp; /* for my i-th remote leaf rank, I will access its buf at offset leafbufdisp[i] */
367:     srcdisp_d  = bas->leafbufdisp_d;
 368:     srcranks_h = bas->iranks + bas->ndiranks; /* my (remote) leaf ranks */
369:     srcranks_d = bas->iranks_d;

371:     ndstranks = sf->nRemoteRootRanks;
372:     dst       = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* the local root buf is the recv buf */

374:     dstdisp_h  = bas->ioffset + bas->ndiranks; /* offsets of the local root buf. Note dstdisp[0] is not necessarily 0 */
375:     dstdisp_d  = bas->ioffset_d;
376:     dstranks_d = sf->ranks_d; /* my (remote) root ranks */

378:     dstsig       = link->rootRecvSig;
379:     dstsigdisp_d = sf->rootsigdisp_d;
380:   }

382:   /* After Pack operation -- src tells dst ranks that they are allowed to get data */
383:   if (ndstranks) {
384:     NvshmemSendSignals<<<(ndstranks + 255) / 256, 256, 0, link->remoteCommStream>>>(ndstranks, dstsig, dstsigdisp_d, dstranks_d, 1); /* set signals to 1 */
385:     cudaGetLastError();
386:   }

388:   /* dst waits for signals (permissions) from src ranks to start getting data */
389:   if (nsrcranks) {
390:     NvshmemWaitSignals<<<1, 1, 0, link->remoteCommStream>>>(nsrcranks, dstsig, 1, 0); /* wait the signals to be 1, then set them to 0 */
391:     cudaGetLastError();
392:   }

394:   /* dst gets data from src ranks using non-blocking nvshmem_gets, which are finished in PetscSFLinkGetDataEnd_NVSHMEM() */

396:   /* Count number of locally accessible src ranks, which should be a small number */
397:   for (int i = 0; i < nsrcranks; i++) {
398:     if (nvshmem_ptr(src, srcranks_h[i])) nLocallyAccessible++;
399:   }

401:   /* Get data from remotely accessible PEs */
402:   if (nLocallyAccessible < nsrcranks) {
403:     GetDataFromRemotelyAccessible<<<nsrcranks, 1, 0, link->remoteCommStream>>>(nsrcranks, srcranks_d, src, srcdisp_d, dst, dstdisp_d, link->unitbytes);
404:     cudaGetLastError();
405:   }

407:   /* Get data from locally accessible PEs */
408:   if (nLocallyAccessible) {
409:     for (int i = 0; i < nsrcranks; i++) {
410:       int pe = srcranks_h[i];
411:       if (nvshmem_ptr(src, pe)) {
412:         size_t nelems = (dstdisp_h[i + 1] - dstdisp_h[i]) * link->unitbytes;
413:         nvshmemx_getmem_nbi_on_stream(dst + (dstdisp_h[i] - dstdisp_h[0]) * link->unitbytes, src + srcdisp_h[i] * link->unitbytes, nelems, pe, link->remoteCommStream);
414:       }
415:     }
416:   }
417:   return 0;
418: }

420: /* Finish the communication (can be done before Unpack)
 421:    The receiver tells its senders that they may reuse their send buffers (since the receiver has already gotten the data from them)
422: */
423: PetscErrorCode PetscSFLinkGetDataEnd_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
424: {
425:   cudaError_t    cerr;
426:   PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
427:   uint64_t      *srcsig;
428:   PetscInt       nsrcranks, *srcsigdisp;
429:   PetscMPIInt   *srcranks;

 431:   if (direction == PETSCSF_ROOT2LEAF) { /* leaf ranks are getting data */
432:     nsrcranks  = sf->nRemoteRootRanks;
433:     srcsig     = link->rootSendSig; /* I want to set their root signal */
434:     srcsigdisp = sf->rootsigdisp_d; /* offset of each root signal */
435:     srcranks   = sf->ranks_d;       /* ranks of the n root ranks */
 436:   } else {                          /* LEAF2ROOT, root ranks are getting data */
437:     nsrcranks  = bas->nRemoteLeafRanks;
438:     srcsig     = link->leafSendSig;
439:     srcsigdisp = bas->leafsigdisp_d;
440:     srcranks   = bas->iranks_d;
441:   }

443:   if (nsrcranks) {
444:     nvshmemx_quiet_on_stream(link->remoteCommStream); /* Finish the nonblocking get, so that we can unpack afterwards */
445:     cudaGetLastError();
446:     NvshmemSendSignals<<<(nsrcranks + 511) / 512, 512, 0, link->remoteCommStream>>>(nsrcranks, srcsig, srcsigdisp, srcranks, 0); /* set signals to 0 */
447:     cudaGetLastError();
448:   }
449:   PetscSFLinkBuildDependenceEnd(sf, link, direction);
450:   return 0;
451: }

453: /* ===========================================================================================================

455:    A set of routines to support sender initiated communication using the put-based method (the default)

457:     The putting protocol is:

459:     Sender has a send buf (sbuf) and a send signal var (ssig);  Receiver has a stand-alone recv buf (rbuf)
460:     and a recv signal var (rsig); All signal variables have an initial value 0. rbuf is allocated by SF and
461:     is in nvshmem space.

463:     Sender:                                 |  Receiver:
464:                                             |
465:   1.  Pack data into sbuf                   |
 466:   2.  Wait until ssig is 0, then set it to 1|
467:   3.  Put data to remote stand-alone rbuf   |
468:   4.  Fence // make sure 5 happens after 3  |
 469:   5.  Put 1 to receiver's rsig              |   1. Wait until rsig is 1, then set it to 0
470:                                             |   2. Unpack data from local rbuf
471:                                             |   3. Put 0 to sender's ssig
472:    ===========================================================================================================*/
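     /* In this file the putting protocol maps onto the link callbacks (set in PetscSFLinkCreate_NVSHMEM()) roughly as follows:
          sender steps 2-4                -> PetscSFLinkPutDataBegin_NVSHMEM (StartCommunication)
          sender step 5 + receiver step 1 -> PetscSFLinkPutDataEnd_NVSHMEM (FinishCommunication)
          receiver step 3                 -> PetscSFLinkSendSignalsToAllowPuttingData_NVSHMEM (PostUnpack)
        Packing (sender step 1) and unpacking (receiver step 2) use the pack routines set up by PetscSFLinkSetUp_CUDA(). */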

 475: /* n thread blocks; each handles one remote rank */
475: __global__ static void WaitAndPutDataToRemotelyAccessible(PetscInt ndstranks, PetscMPIInt *dstranks, char *dst, PetscInt *dstdisp, const char *src, PetscInt *srcdisp, uint64_t *srcsig, PetscInt unitbytes)
476: {
477:   int         bid = blockIdx.x;
478:   PetscMPIInt pe  = dstranks[bid];

480:   if (!nvshmem_ptr(dst, pe)) {
481:     PetscInt nelems = (srcdisp[bid + 1] - srcdisp[bid]) * unitbytes;
482:     nvshmem_uint64_wait_until(srcsig + bid, NVSHMEM_CMP_EQ, 0); /* Wait until the sig = 0 */
483:     srcsig[bid] = 1;
484:     nvshmem_putmem_nbi(dst + dstdisp[bid] * unitbytes, src + (srcdisp[bid] - srcdisp[0]) * unitbytes, nelems, pe);
485:   }
486: }

 489: /* A one-thread kernel that handles all locally accessible PEs */
489: __global__ static void WaitSignalsFromLocallyAccessible(PetscInt ndstranks, PetscMPIInt *dstranks, uint64_t *srcsig, const char *dst)
490: {
491:   for (int i = 0; i < ndstranks; i++) {
492:     int pe = dstranks[i];
493:     if (nvshmem_ptr(dst, pe)) {
494:       nvshmem_uint64_wait_until(srcsig + i, NVSHMEM_CMP_EQ, 0); /* Wait until the sig = 0 */
495:       srcsig[i] = 1;
496:     }
497:   }
498: }

500: /* Put data in the given direction  */
501: PetscErrorCode PetscSFLinkPutDataBegin_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
502: {
503:   cudaError_t    cerr;
504:   PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
505:   PetscInt       ndstranks, nLocallyAccessible = 0;
506:   char          *src, *dst;
507:   PetscInt      *srcdisp_h, *dstdisp_h;
508:   PetscInt      *srcdisp_d, *dstdisp_d;
509:   PetscMPIInt   *dstranks_h;
510:   PetscMPIInt   *dstranks_d;
511:   uint64_t      *srcsig;

513:   PetscSFLinkBuildDependenceBegin(sf, link, direction);
 514:   if (direction == PETSCSF_ROOT2LEAF) {                              /* put data in rootbuf to leafbuf  */
515:     ndstranks = bas->nRemoteLeafRanks;                               /* number of (remote) leaf ranks */
516:     src       = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* Both src & dst must be symmetric */
517:     dst       = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];

519:     srcdisp_h = bas->ioffset + bas->ndiranks; /* offsets of rootbuf. srcdisp[0] is not necessarily zero */
520:     srcdisp_d = bas->ioffset_d;
521:     srcsig    = link->rootSendSig;

523:     dstdisp_h  = bas->leafbufdisp; /* for my i-th remote leaf rank, I will access its leaf buf at offset leafbufdisp[i] */
524:     dstdisp_d  = bas->leafbufdisp_d;
525:     dstranks_h = bas->iranks + bas->ndiranks; /* remote leaf ranks */
526:     dstranks_d = bas->iranks_d;
527:   } else { /* put data in leafbuf to rootbuf */
528:     ndstranks = sf->nRemoteRootRanks;
529:     src       = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
530:     dst       = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];

532:     srcdisp_h = sf->roffset + sf->ndranks; /* offsets of leafbuf */
533:     srcdisp_d = sf->roffset_d;
534:     srcsig    = link->leafSendSig;

536:     dstdisp_h  = sf->rootbufdisp; /* for my i-th remote root rank, I will access its root buf at offset rootbufdisp[i] */
537:     dstdisp_d  = sf->rootbufdisp_d;
538:     dstranks_h = sf->ranks + sf->ndranks; /* remote root ranks */
539:     dstranks_d = sf->ranks_d;
540:   }

542:   /* Wait for signals and then put data to dst ranks using non-blocking nvshmem_put, which are finished in PetscSFLinkPutDataEnd_NVSHMEM */

544:   /* Count number of locally accessible neighbors, which should be a small number */
545:   for (int i = 0; i < ndstranks; i++) {
546:     if (nvshmem_ptr(dst, dstranks_h[i])) nLocallyAccessible++;
547:   }

549:   /* For remotely accessible PEs, send data to them in one kernel call */
550:   if (nLocallyAccessible < ndstranks) {
551:     WaitAndPutDataToRemotelyAccessible<<<ndstranks, 1, 0, link->remoteCommStream>>>(ndstranks, dstranks_d, dst, dstdisp_d, src, srcdisp_d, srcsig, link->unitbytes);
552:     cudaGetLastError();
553:   }

 555:   /* For locally accessible PEs, use the host API, which uses the CUDA copy engines and is much faster than the device API */
556:   if (nLocallyAccessible) {
557:     WaitSignalsFromLocallyAccessible<<<1, 1, 0, link->remoteCommStream>>>(ndstranks, dstranks_d, srcsig, dst);
558:     for (int i = 0; i < ndstranks; i++) {
559:       int pe = dstranks_h[i];
 560:       if (nvshmem_ptr(dst, pe)) { /* If nvshmem_ptr() returns a non-null pointer, then <pe> is locally accessible */
561:         size_t nelems = (srcdisp_h[i + 1] - srcdisp_h[i]) * link->unitbytes;
562:         /* Initiate the nonblocking communication */
563:         nvshmemx_putmem_nbi_on_stream(dst + dstdisp_h[i] * link->unitbytes, src + (srcdisp_h[i] - srcdisp_h[0]) * link->unitbytes, nelems, pe, link->remoteCommStream);
564:       }
565:     }
566:   }

568:   if (nLocallyAccessible) { nvshmemx_quiet_on_stream(link->remoteCommStream); /* Calling nvshmem_fence/quiet() does not fence the above nvshmemx_putmem_nbi_on_stream! */ }
569:   return 0;
570: }

 573: /* A one-thread kernel; the thread handles all remote PEs */
573: __global__ static void PutDataEnd(PetscInt nsrcranks, PetscInt ndstranks, PetscMPIInt *dstranks, uint64_t *dstsig, PetscInt *dstsigdisp)
574: {
 575:   /* TODO: Should we finish the non-blocking remote puts here? */

577:   /* 1. Send a signal to each dst rank */

 579:   /* According to Akhil@NVIDIA, IB is ordered, so no fence is needed for remote PEs.
580:      For local PEs, we already called nvshmemx_quiet_on_stream(). Therefore, we are good to send signals to all dst ranks now.
581:   */
582:   for (int i = 0; i < ndstranks; i++) nvshmemx_uint64_signal(dstsig + dstsigdisp[i], 1, dstranks[i]); /* set sig to 1 */

584:   /* 2. Wait for signals from src ranks (if any) */
585:   if (nsrcranks) {
586:     nvshmem_uint64_wait_until_all(dstsig, nsrcranks, NULL /*no mask*/, NVSHMEM_CMP_EQ, 1); /* wait sigs to be 1, then set them to 0 */
587:     for (int i = 0; i < nsrcranks; i++) dstsig[i] = 0;
588:   }
589: }

591: /* Finish the communication -- A receiver waits until it can access its receive buffer */
592: PetscErrorCode PetscSFLinkPutDataEnd_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
593: {
594:   cudaError_t    cerr;
595:   PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
596:   PetscMPIInt   *dstranks;
597:   uint64_t      *dstsig;
598:   PetscInt       nsrcranks, ndstranks, *dstsigdisp;

 600:   if (direction == PETSCSF_ROOT2LEAF) { /* put root data to leaf */
601:     nsrcranks = sf->nRemoteRootRanks;

603:     ndstranks  = bas->nRemoteLeafRanks;
604:     dstranks   = bas->iranks_d;      /* leaf ranks */
 605:     dstsig     = link->leafRecvSig;  /* I will set my leaf ranks' RecvSig */
606:     dstsigdisp = bas->leafsigdisp_d; /* for my i-th remote leaf rank, I will access its signal at offset leafsigdisp[i] */
 607:   } else {                           /* LEAF2ROOT */
608:     nsrcranks = bas->nRemoteLeafRanks;

610:     ndstranks  = sf->nRemoteRootRanks;
611:     dstranks   = sf->ranks_d;
612:     dstsig     = link->rootRecvSig;
613:     dstsigdisp = sf->rootsigdisp_d;
614:   }

616:   if (nsrcranks || ndstranks) {
617:     PutDataEnd<<<1, 1, 0, link->remoteCommStream>>>(nsrcranks, ndstranks, dstranks, dstsig, dstsigdisp);
618:     cudaGetLastError();
619:   }
620:   PetscSFLinkBuildDependenceEnd(sf, link, direction);
621:   return 0;
622: }

 624: /* PostUnpack operation -- a receiver tells its senders that they are allowed to put data here (i.e., its recv buf is free to take new data) */
625: PetscErrorCode PetscSFLinkSendSignalsToAllowPuttingData_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
626: {
627:   PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
628:   uint64_t      *srcsig;
629:   PetscInt       nsrcranks, *srcsigdisp_d;
630:   PetscMPIInt   *srcranks_d;

 632:   if (direction == PETSCSF_ROOT2LEAF) { /* I allow my root ranks to put data to me */
633:     nsrcranks    = sf->nRemoteRootRanks;
634:     srcsig       = link->rootSendSig; /* I want to set their send signals */
635:     srcsigdisp_d = sf->rootsigdisp_d; /* offset of each root signal */
636:     srcranks_d   = sf->ranks_d;       /* ranks of the n root ranks */
 637:   } else {                            /* LEAF2ROOT */
638:     nsrcranks    = bas->nRemoteLeafRanks;
639:     srcsig       = link->leafSendSig;
640:     srcsigdisp_d = bas->leafsigdisp_d;
641:     srcranks_d   = bas->iranks_d;
642:   }

644:   if (nsrcranks) {
645:     NvshmemSendSignals<<<(nsrcranks + 255) / 256, 256, 0, link->remoteCommStream>>>(nsrcranks, srcsig, srcsigdisp_d, srcranks_d, 0); /* Set remote signals to 0 */
646:     cudaGetLastError();
647:   }
648:   return 0;
649: }

651: /* Destructor when the link uses nvshmem for communication */
652: static PetscErrorCode PetscSFLinkDestroy_NVSHMEM(PetscSF sf, PetscSFLink link)
653: {
654:   cudaError_t cerr;

656:   cudaEventDestroy(link->dataReady);
657:   cudaEventDestroy(link->endRemoteComm);
658:   cudaStreamDestroy(link->remoteCommStream);

660:   /* nvshmem does not need buffers on host, which should be NULL */
661:   PetscNvshmemFree(link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]);
662:   PetscNvshmemFree(link->leafSendSig);
663:   PetscNvshmemFree(link->leafRecvSig);
664:   PetscNvshmemFree(link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]);
665:   PetscNvshmemFree(link->rootSendSig);
666:   PetscNvshmemFree(link->rootRecvSig);
667:   return 0;
668: }

670: PetscErrorCode PetscSFLinkCreate_NVSHMEM(PetscSF sf, MPI_Datatype unit, PetscMemType rootmtype, const void *rootdata, PetscMemType leafmtype, const void *leafdata, MPI_Op op, PetscSFOperation sfop, PetscSFLink *mylink)
671: {
672:   cudaError_t    cerr;
673:   PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
674:   PetscSFLink   *p, link;
675:   PetscBool      match, rootdirect[2], leafdirect[2];
676:   int            greatestPriority;

678:   /* Check to see if we can directly send/recv root/leafdata with the given sf, sfop and op.
 679:      We only care about root/leafdirect[PETSCSF_REMOTE], since we never need intermediate buffers in local communication with NVSHMEM.
680:   */
681:   if (sfop == PETSCSF_BCAST) { /* Move data from rootbuf to leafbuf */
682:     if (sf->use_nvshmem_get) {
683:       rootdirect[PETSCSF_REMOTE] = PETSC_FALSE; /* send buffer has to be stand-alone (can't be rootdata) */
684:       leafdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(leafmtype) && sf->leafcontig[PETSCSF_REMOTE] && op == MPI_REPLACE) ? PETSC_TRUE : PETSC_FALSE;
685:     } else {
686:       rootdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(rootmtype) && bas->rootcontig[PETSCSF_REMOTE]) ? PETSC_TRUE : PETSC_FALSE;
687:       leafdirect[PETSCSF_REMOTE] = PETSC_FALSE; /* Our put-protocol always needs a nvshmem alloc'ed recv buffer */
688:     }
689:   } else if (sfop == PETSCSF_REDUCE) { /* Move data from leafbuf to rootbuf */
690:     if (sf->use_nvshmem_get) {
691:       rootdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(rootmtype) && bas->rootcontig[PETSCSF_REMOTE] && op == MPI_REPLACE) ? PETSC_TRUE : PETSC_FALSE;
692:       leafdirect[PETSCSF_REMOTE] = PETSC_FALSE;
693:     } else {
694:       rootdirect[PETSCSF_REMOTE] = PETSC_FALSE;
695:       leafdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(leafmtype) && sf->leafcontig[PETSCSF_REMOTE]) ? PETSC_TRUE : PETSC_FALSE;
696:     }
697:   } else {                                    /* PETSCSF_FETCH */
 698:     rootdirect[PETSCSF_REMOTE] = PETSC_FALSE; /* FETCH always needs a separate rootbuf */
699:     leafdirect[PETSCSF_REMOTE] = PETSC_FALSE; /* We also force allocating a separate leafbuf so that leafdata and leafupdate can share mpi requests */
700:   }

702:   /* Look for free nvshmem links in cache */
703:   for (p = &bas->avail; (link = *p); p = &link->next) {
704:     if (link->use_nvshmem) {
705:       MPIPetsc_Type_compare(unit, link->unit, &match);
706:       if (match) {
707:         *p = link->next; /* Remove from available list */
708:         goto found;
709:       }
710:     }
711:   }
712:   PetscNew(&link);
713:   PetscSFLinkSetUp_Host(sf, link, unit);                                          /* Compute link->unitbytes, dup link->unit etc. */
714:   if (sf->backend == PETSCSF_BACKEND_CUDA) PetscSFLinkSetUp_CUDA(sf, link, unit); /* Setup pack routines, streams etc */
715: #if defined(PETSC_HAVE_KOKKOS)
716:   else if (sf->backend == PETSCSF_BACKEND_KOKKOS) PetscSFLinkSetUp_Kokkos(sf, link, unit);
717: #endif

719:   link->rootdirect[PETSCSF_LOCAL] = PETSC_TRUE; /* For the local part we directly use root/leafdata */
720:   link->leafdirect[PETSCSF_LOCAL] = PETSC_TRUE;

722:   /* Init signals to zero */
723:   if (!link->rootSendSig) PetscNvshmemCalloc(bas->nRemoteLeafRanksMax * sizeof(uint64_t), (void **)&link->rootSendSig);
724:   if (!link->rootRecvSig) PetscNvshmemCalloc(bas->nRemoteLeafRanksMax * sizeof(uint64_t), (void **)&link->rootRecvSig);
725:   if (!link->leafSendSig) PetscNvshmemCalloc(sf->nRemoteRootRanksMax * sizeof(uint64_t), (void **)&link->leafSendSig);
726:   if (!link->leafRecvSig) PetscNvshmemCalloc(sf->nRemoteRootRanksMax * sizeof(uint64_t), (void **)&link->leafRecvSig);

728:   link->use_nvshmem = PETSC_TRUE;
729:   link->rootmtype   = PETSC_MEMTYPE_DEVICE; /* Only need 0/1-based mtype from now on */
730:   link->leafmtype   = PETSC_MEMTYPE_DEVICE;
731:   /* Overwrite some function pointers set by PetscSFLinkSetUp_CUDA */
732:   link->Destroy = PetscSFLinkDestroy_NVSHMEM;
733:   if (sf->use_nvshmem_get) { /* get-based protocol */
734:     link->PrePack             = PetscSFLinkWaitSignalsOfCompletionOfGettingData_NVSHMEM;
735:     link->StartCommunication  = PetscSFLinkGetDataBegin_NVSHMEM;
736:     link->FinishCommunication = PetscSFLinkGetDataEnd_NVSHMEM;
737:   } else { /* put-based protocol */
738:     link->StartCommunication  = PetscSFLinkPutDataBegin_NVSHMEM;
739:     link->FinishCommunication = PetscSFLinkPutDataEnd_NVSHMEM;
740:     link->PostUnpack          = PetscSFLinkSendSignalsToAllowPuttingData_NVSHMEM;
741:   }

743:   cudaDeviceGetStreamPriorityRange(NULL, &greatestPriority);
744:   cudaStreamCreateWithPriority(&link->remoteCommStream, cudaStreamNonBlocking, greatestPriority);

746:   cudaEventCreateWithFlags(&link->dataReady, cudaEventDisableTiming);
747:   cudaEventCreateWithFlags(&link->endRemoteComm, cudaEventDisableTiming);

749: found:
750:   if (rootdirect[PETSCSF_REMOTE]) {
751:     link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = (char *)rootdata + bas->rootstart[PETSCSF_REMOTE] * link->unitbytes;
752:   } else {
753:     if (!link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]) PetscNvshmemMalloc(bas->rootbuflen_rmax * link->unitbytes, (void **)&link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]);
754:     link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
755:   }

757:   if (leafdirect[PETSCSF_REMOTE]) {
758:     link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = (char *)leafdata + sf->leafstart[PETSCSF_REMOTE] * link->unitbytes;
759:   } else {
760:     if (!link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]) PetscNvshmemMalloc(sf->leafbuflen_rmax * link->unitbytes, (void **)&link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]);
761:     link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
762:   }

764:   link->rootdirect[PETSCSF_REMOTE] = rootdirect[PETSCSF_REMOTE];
765:   link->leafdirect[PETSCSF_REMOTE] = leafdirect[PETSCSF_REMOTE];
766:   link->rootdata                   = rootdata; /* root/leafdata are keys to look up links in PetscSFXxxEnd */
767:   link->leafdata                   = leafdata;
768:   link->next                       = bas->inuse;
769:   bas->inuse                       = link;
770:   *mylink                          = link;
771:   return 0;
772: }

774: #if defined(PETSC_USE_REAL_SINGLE)
775: PetscErrorCode PetscNvshmemSum(PetscInt count, float *dst, const float *src)
776: {
777:   PetscMPIInt num; /* Assume nvshmem's int is MPI's int */

779:   PetscMPIIntCast(count, &num);
780:   nvshmemx_float_sum_reduce_on_stream(NVSHMEM_TEAM_WORLD, dst, src, num, PetscDefaultCudaStream);
781:   return 0;
782: }

784: PetscErrorCode PetscNvshmemMax(PetscInt count, float *dst, const float *src)
785: {
786:   PetscMPIInt num;

788:   PetscMPIIntCast(count, &num);
789:   nvshmemx_float_max_reduce_on_stream(NVSHMEM_TEAM_WORLD, dst, src, num, PetscDefaultCudaStream);
790:   return 0;
791: }
792: #elif defined(PETSC_USE_REAL_DOUBLE)
793: PetscErrorCode PetscNvshmemSum(PetscInt count, double *dst, const double *src)
794: {
795:   PetscMPIInt num;

797:   PetscMPIIntCast(count, &num);
798:   nvshmemx_double_sum_reduce_on_stream(NVSHMEM_TEAM_WORLD, dst, src, num, PetscDefaultCudaStream);
799:   return 0;
800: }

802: PetscErrorCode PetscNvshmemMax(PetscInt count, double *dst, const double *src)
803: {
804:   PetscMPIInt num;

806:   PetscMPIIntCast(count, &num);
807:   nvshmemx_double_max_reduce_on_stream(NVSHMEM_TEAM_WORLD, dst, src, num, PetscDefaultCudaStream);
808:   return 0;
809: }
810: #endif