Actual source code: sfkok.kokkos.cxx

  1: #include <../src/vec/is/sf/impls/basic/sfpack.h>

  3: #include <Kokkos_Core.hpp>

  5: using DeviceExecutionSpace = Kokkos::DefaultExecutionSpace;
  6: using DeviceMemorySpace    = typename DeviceExecutionSpace::memory_space;
  7: using HostMemorySpace      = Kokkos::HostSpace;

  9: typedef Kokkos::View<char *, DeviceMemorySpace> deviceBuffer_t;
 10: typedef Kokkos::View<char *, HostMemorySpace>   HostBuffer_t;

 12: typedef Kokkos::View<const char *, DeviceMemorySpace> deviceConstBuffer_t;
 13: typedef Kokkos::View<const char *, HostMemorySpace>   HostConstBuffer_t;

 15: /*====================================================================================*/
 16: /*                             Regular operations                           */
 17: /*====================================================================================*/
 18: template <typename Type>
 19: struct Insert {
 20:   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
 21:   {
 22:     Type old = x;
 23:     x        = y;
 24:     return old;
 25:   }
 26: };
 27: template <typename Type>
 28: struct Add {
 29:   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
 30:   {
 31:     Type old = x;
 32:     x += y;
 33:     return old;
 34:   }
 35: };
 36: template <typename Type>
 37: struct Mult {
 38:   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
 39:   {
 40:     Type old = x;
 41:     x *= y;
 42:     return old;
 43:   }
 44: };
 45: template <typename Type>
 46: struct Min {
 47:   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
 48:   {
 49:     Type old = x;
 50:     x        = PetscMin(x, y);
 51:     return old;
 52:   }
 53: };
 54: template <typename Type>
 55: struct Max {
 56:   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
 57:   {
 58:     Type old = x;
 59:     x        = PetscMax(x, y);
 60:     return old;
 61:   }
 62: };
 63: template <typename Type>
 64: struct LAND {
 65:   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
 66:   {
 67:     Type old = x;
 68:     x        = x && y;
 69:     return old;
 70:   }
 71: };
 72: template <typename Type>
 73: struct LOR {
 74:   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
 75:   {
 76:     Type old = x;
 77:     x        = x || y;
 78:     return old;
 79:   }
 80: };
 81: template <typename Type>
 82: struct LXOR {
 83:   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
 84:   {
 85:     Type old = x;
 86:     x        = !x != !y;
 87:     return old;
 88:   }
 89: };
 90: template <typename Type>
 91: struct BAND {
 92:   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
 93:   {
 94:     Type old = x;
 95:     x        = x & y;
 96:     return old;
 97:   }
 98: };
 99: template <typename Type>
100: struct BOR {
101:   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
102:   {
103:     Type old = x;
104:     x        = x | y;
105:     return old;
106:   }
107: };
108: template <typename Type>
109: struct BXOR {
110:   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
111:   {
112:     Type old = x;
113:     x        = x ^ y;
114:     return old;
115:   }
116: };
117: template <typename PairType>
118: struct Minloc {
119:   KOKKOS_INLINE_FUNCTION PairType operator()(PairType &x, PairType y) const
120:   {
121:     PairType old = x;
122:     if (y.first < x.first) x = y;
123:     else if (y.first == x.first) x.second = PetscMin(x.second, y.second);
124:     return old;
125:   }
126: };
127: template <typename PairType>
128: struct Maxloc {
129:   KOKKOS_INLINE_FUNCTION PairType operator()(PairType &x, PairType y) const
130:   {
131:     PairType old = x;
132:     if (y.first > x.first) x = y;
133:     else if (y.first == x.first) x.second = PetscMin(x.second, y.second); /* See MPI MAXLOC */
134:     return old;
135:   }
136: };

138: /*====================================================================================*/
139: /*                             Atomic operations                            */
140: /*====================================================================================*/
141: template <typename Type>
142: struct AtomicInsert {
143:   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_assign(&x, y); }
144: };
145: template <typename Type>
146: struct AtomicAdd {
147:   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_add(&x, y); }
148: };
149: template <typename Type>
150: struct AtomicBAND {
151:   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_and(&x, y); }
152: };
153: template <typename Type>
154: struct AtomicBOR {
155:   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_or(&x, y); }
156: };
157: template <typename Type>
158: struct AtomicBXOR {
159:   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_xor(&x, y); }
160: };
161: template <typename Type>
162: struct AtomicLAND {
163:   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const
164:   {
165:     const Type zero = 0, one = ~0;
166:     Kokkos::atomic_and(&x, y ? one : zero);
167:   }
168: };
169: template <typename Type>
170: struct AtomicLOR {
171:   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const
172:   {
173:     const Type zero = 0, one = 1;
174:     Kokkos::atomic_or(&x, y ? one : zero);
175:   }
176: };
177: template <typename Type>
178: struct AtomicMult {
179:   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_mul(&x, y); }
180: };
181: template <typename Type>
182: struct AtomicMin {
183:   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_min(&x, y); }
184: };
185: template <typename Type>
186: struct AtomicMax {
187:   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_max(&x, y); }
188: };
189: /* TODO: struct AtomicLXOR  */
190: template <typename Type>
191: struct AtomicFetchAdd {
192:   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const { return Kokkos::atomic_fetch_add(&x, y); }
193: };

195: /* Map a thread id to an index in root/leaf space through a series of 3D subdomains. See PetscSFPackOpt. */
196: static KOKKOS_INLINE_FUNCTION PetscInt MapTidToIndex(const PetscInt *opt, PetscInt tid)
197: {
198:   PetscInt        i, j, k, m, n, r;
199:   const PetscInt *offset, *start, *dx, *dy, *X, *Y;

201:   n      = opt[0];
202:   offset = opt + 1;
203:   start  = opt + n + 2;
204:   dx     = opt + 2 * n + 2;
205:   dy     = opt + 3 * n + 2;
206:   X      = opt + 5 * n + 2;
207:   Y      = opt + 6 * n + 2;
208:   for (r = 0; r < n; r++) {
209:     if (tid < offset[r + 1]) break;
210:   }
211:   m = (tid - offset[r]);
212:   k = m / (dx[r] * dy[r]);
213:   j = (m - k * dx[r] * dy[r]) / dx[r];
214:   i = m - k * dx[r] * dy[r] - j * dx[r];

216:   return (start[r] + k * X[r] * Y[r] + j * X[r] + i);
217: }

219: /*====================================================================================*/
220: /*  Wrappers for Pack/Unpack/Scatter kernels. Function pointers are stored in 'link'         */
221: /*====================================================================================*/

223: /* Suppose user calls PetscSFReduce(sf,unit,...) and <unit> is an MPI data type made of 16 PetscReals, then
224:    <Type> is PetscReal, which is the primitive type we operate on.
225:    <bs>   is 16, which says <unit> contains 16 primitive types.
226:    <BS>   is 8, which is the maximal SIMD width we will try to vectorize operations on <unit>.
227:    <EQ>   is 0, which is (bs == BS ? 1 : 0)

229:   If instead, <unit> has 8 PetscReals, then bs=8, BS=8, EQ=1, rendering MBS below to a compile time constant.
230:   For the common case in VecScatter, bs=1, BS=1, EQ=1, MBS=1, the inner for-loops below will be totally unrolled.
231: */
232: template <typename Type, PetscInt BS, PetscInt EQ>
233: static PetscErrorCode Pack(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, const void *data_, void *buf_)
234: {
235:   const PetscInt      *iopt = opt ? opt->array : NULL;
236:   const PetscInt       M = EQ ? 1 : link->bs / BS, MBS = M * BS; /* If EQ, then MBS will be a compile-time const */
237:   const Type          *data = static_cast<const Type *>(data_);
238:   Type                *buf  = static_cast<Type *>(buf_);
239:   DeviceExecutionSpace exec;

241:   Kokkos::parallel_for(
242:     Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
243:       /* iopt != NULL ==> idx == NULL, i.e., the indices have patterns but not contiguous;
244:        iopt == NULL && idx == NULL ==> the indices are contiguous;
245:      */
246:       PetscInt t = (iopt ? MapTidToIndex(iopt, tid) : (idx ? idx[tid] : start + tid)) * MBS;
247:       PetscInt s = tid * MBS;
248:       for (int i = 0; i < MBS; i++) buf[s + i] = data[t + i];
249:     });
250:   return 0;
251: }

253: template <typename Type, class Op, PetscInt BS, PetscInt EQ>
254: static PetscErrorCode UnpackAndOp(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, void *data_, const void *buf_)
255: {
256:   Op                   op;
257:   const PetscInt      *iopt = opt ? opt->array : NULL;
258:   const PetscInt       M = EQ ? 1 : link->bs / BS, MBS = M * BS;
259:   Type                *data = static_cast<Type *>(data_);
260:   const Type          *buf  = static_cast<const Type *>(buf_);
261:   DeviceExecutionSpace exec;

263:   Kokkos::parallel_for(
264:     Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
265:       PetscInt t = (iopt ? MapTidToIndex(iopt, tid) : (idx ? idx[tid] : start + tid)) * MBS;
266:       PetscInt s = tid * MBS;
267:       for (int i = 0; i < MBS; i++) op(data[t + i], buf[s + i]);
268:     });
269:   return 0;
270: }

272: template <typename Type, class Op, PetscInt BS, PetscInt EQ>
273: static PetscErrorCode FetchAndOp(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, void *data, void *buf)
274: {
275:   Op                   op;
276:   const PetscInt      *ropt = opt ? opt->array : NULL;
277:   const PetscInt       M = EQ ? 1 : link->bs / BS, MBS = M * BS;
278:   Type                *rootdata = static_cast<Type *>(data), *leafbuf = static_cast<Type *>(buf);
279:   DeviceExecutionSpace exec;

281:   Kokkos::parallel_for(
282:     Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
283:       PetscInt r = (ropt ? MapTidToIndex(ropt, tid) : (idx ? idx[tid] : start + tid)) * MBS;
284:       PetscInt l = tid * MBS;
285:       for (int i = 0; i < MBS; i++) leafbuf[l + i] = op(rootdata[r + i], leafbuf[l + i]);
286:     });
287:   return 0;
288: }

290: template <typename Type, class Op, PetscInt BS, PetscInt EQ>
291: static PetscErrorCode ScatterAndOp(PetscSFLink link, PetscInt count, PetscInt srcStart, PetscSFPackOpt srcOpt, const PetscInt *srcIdx, const void *src_, PetscInt dstStart, PetscSFPackOpt dstOpt, const PetscInt *dstIdx, void *dst_)
292: {
293:   PetscInt             srcx = 0, srcy = 0, srcX = 0, srcY = 0, dstx = 0, dsty = 0, dstX = 0, dstY = 0;
294:   const PetscInt       M = (EQ) ? 1 : link->bs / BS, MBS = M * BS;
295:   const Type          *src = static_cast<const Type *>(src_);
296:   Type                *dst = static_cast<Type *>(dst_);
297:   DeviceExecutionSpace exec;

299:   /* The 3D shape of source subdomain may be different than that of the destination, which makes it difficult to use CUDA 3D grid and block */
300:   if (srcOpt) {
301:     srcx     = srcOpt->dx[0];
302:     srcy     = srcOpt->dy[0];
303:     srcX     = srcOpt->X[0];
304:     srcY     = srcOpt->Y[0];
305:     srcStart = srcOpt->start[0];
306:     srcIdx   = NULL;
307:   } else if (!srcIdx) {
308:     srcx = srcX = count;
309:     srcy = srcY = 1;
310:   }

312:   if (dstOpt) {
313:     dstx     = dstOpt->dx[0];
314:     dsty     = dstOpt->dy[0];
315:     dstX     = dstOpt->X[0];
316:     dstY     = dstOpt->Y[0];
317:     dstStart = dstOpt->start[0];
318:     dstIdx   = NULL;
319:   } else if (!dstIdx) {
320:     dstx = dstX = count;
321:     dsty = dstY = 1;
322:   }

324:   Kokkos::parallel_for(
325:     Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
326:       PetscInt i, j, k, s, t;
327:       Op       op;
328:       if (!srcIdx) { /* src is in 3D */
329:         k = tid / (srcx * srcy);
330:         j = (tid - k * srcx * srcy) / srcx;
331:         i = tid - k * srcx * srcy - j * srcx;
332:         s = srcStart + k * srcX * srcY + j * srcX + i;
333:       } else { /* src is contiguous */
334:         s = srcIdx[tid];
335:       }

337:       if (!dstIdx) { /* 3D */
338:         k = tid / (dstx * dsty);
339:         j = (tid - k * dstx * dsty) / dstx;
340:         i = tid - k * dstx * dsty - j * dstx;
341:         t = dstStart + k * dstX * dstY + j * dstX + i;
342:       } else { /* contiguous */
343:         t = dstIdx[tid];
344:       }

346:       s *= MBS;
347:       t *= MBS;
348:       for (i = 0; i < MBS; i++) op(dst[t + i], src[s + i]);
349:     });
350:   return 0;
351: }

353: /* Specialization for Insert since we may use memcpy */
354: template <typename Type, PetscInt BS, PetscInt EQ>
355: static PetscErrorCode ScatterAndInsert(PetscSFLink link, PetscInt count, PetscInt srcStart, PetscSFPackOpt srcOpt, const PetscInt *srcIdx, const void *src_, PetscInt dstStart, PetscSFPackOpt dstOpt, const PetscInt *dstIdx, void *dst_)
356: {
357:   const Type          *src = static_cast<const Type *>(src_);
358:   Type                *dst = static_cast<Type *>(dst_);
359:   DeviceExecutionSpace exec;

361:   if (!count) return 0;
362:   /*src and dst are contiguous */
363:   if ((!srcOpt && !srcIdx) && (!dstOpt && !dstIdx) && src != dst) {
364:     size_t              sz = count * link->unitbytes;
365:     deviceBuffer_t      dbuf(reinterpret_cast<char *>(dst + dstStart * link->bs), sz);
366:     deviceConstBuffer_t sbuf(reinterpret_cast<const char *>(src + srcStart * link->bs), sz);
367:     Kokkos::deep_copy(exec, dbuf, sbuf);
368:   } else {
369:     ScatterAndOp<Type, Insert<Type>, BS, EQ>(link, count, srcStart, srcOpt, srcIdx, src, dstStart, dstOpt, dstIdx, dst);
370:   }
371:   return 0;
372: }

374: template <typename Type, class Op, PetscInt BS, PetscInt EQ>
375: static PetscErrorCode FetchAndOpLocal(PetscSFLink link, PetscInt count, PetscInt rootstart, PetscSFPackOpt rootopt, const PetscInt *rootidx, void *rootdata_, PetscInt leafstart, PetscSFPackOpt leafopt, const PetscInt *leafidx, const void *leafdata_, void *leafupdate_)
376: {
377:   Op                   op;
378:   const PetscInt       M = (EQ) ? 1 : link->bs / BS, MBS = M * BS;
379:   const PetscInt      *ropt     = rootopt ? rootopt->array : NULL;
380:   const PetscInt      *lopt     = leafopt ? leafopt->array : NULL;
381:   Type                *rootdata = static_cast<Type *>(rootdata_), *leafupdate = static_cast<Type *>(leafupdate_);
382:   const Type          *leafdata = static_cast<const Type *>(leafdata_);
383:   DeviceExecutionSpace exec;

385:   Kokkos::parallel_for(
386:     Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
387:       PetscInt r = (ropt ? MapTidToIndex(ropt, tid) : (rootidx ? rootidx[tid] : rootstart + tid)) * MBS;
388:       PetscInt l = (lopt ? MapTidToIndex(lopt, tid) : (leafidx ? leafidx[tid] : leafstart + tid)) * MBS;
389:       for (int i = 0; i < MBS; i++) leafupdate[l + i] = op(rootdata[r + i], leafdata[l + i]);
390:     });
391:   return 0;
392: }

394: /*====================================================================================*/
395: /*  Init various types and instantiate pack/unpack function pointers                  */
396: /*====================================================================================*/
397: template <typename Type, PetscInt BS, PetscInt EQ>
398: static void PackInit_RealType(PetscSFLink link)
399: {
400:   /* Pack/unpack for remote communication */
401:   link->d_Pack            = Pack<Type, BS, EQ>;
402:   link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, BS, EQ>;
403:   link->d_UnpackAndAdd    = UnpackAndOp<Type, Add<Type>, BS, EQ>;
404:   link->d_UnpackAndMult   = UnpackAndOp<Type, Mult<Type>, BS, EQ>;
405:   link->d_UnpackAndMin    = UnpackAndOp<Type, Min<Type>, BS, EQ>;
406:   link->d_UnpackAndMax    = UnpackAndOp<Type, Max<Type>, BS, EQ>;
407:   link->d_FetchAndAdd     = FetchAndOp<Type, Add<Type>, BS, EQ>;
408:   /* Scatter for local communication */
409:   link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>; /* Has special optimizations */
410:   link->d_ScatterAndAdd    = ScatterAndOp<Type, Add<Type>, BS, EQ>;
411:   link->d_ScatterAndMult   = ScatterAndOp<Type, Mult<Type>, BS, EQ>;
412:   link->d_ScatterAndMin    = ScatterAndOp<Type, Min<Type>, BS, EQ>;
413:   link->d_ScatterAndMax    = ScatterAndOp<Type, Max<Type>, BS, EQ>;
414:   link->d_FetchAndAddLocal = FetchAndOpLocal<Type, Add<Type>, BS, EQ>;
415:   /* Atomic versions when there are data-race possibilities */
416:   link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>;
417:   link->da_UnpackAndAdd    = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>;
418:   link->da_UnpackAndMult   = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>;
419:   link->da_UnpackAndMin    = UnpackAndOp<Type, AtomicMin<Type>, BS, EQ>;
420:   link->da_UnpackAndMax    = UnpackAndOp<Type, AtomicMax<Type>, BS, EQ>;
421:   link->da_FetchAndAdd     = FetchAndOp<Type, AtomicFetchAdd<Type>, BS, EQ>;

423:   link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>;
424:   link->da_ScatterAndAdd    = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>;
425:   link->da_ScatterAndMult   = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>;
426:   link->da_ScatterAndMin    = ScatterAndOp<Type, AtomicMin<Type>, BS, EQ>;
427:   link->da_ScatterAndMax    = ScatterAndOp<Type, AtomicMax<Type>, BS, EQ>;
428:   link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicFetchAdd<Type>, BS, EQ>;
429: }

431: template <typename Type, PetscInt BS, PetscInt EQ>
432: static void PackInit_IntegerType(PetscSFLink link)
433: {
434:   link->d_Pack            = Pack<Type, BS, EQ>;
435:   link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, BS, EQ>;
436:   link->d_UnpackAndAdd    = UnpackAndOp<Type, Add<Type>, BS, EQ>;
437:   link->d_UnpackAndMult   = UnpackAndOp<Type, Mult<Type>, BS, EQ>;
438:   link->d_UnpackAndMin    = UnpackAndOp<Type, Min<Type>, BS, EQ>;
439:   link->d_UnpackAndMax    = UnpackAndOp<Type, Max<Type>, BS, EQ>;
440:   link->d_UnpackAndLAND   = UnpackAndOp<Type, LAND<Type>, BS, EQ>;
441:   link->d_UnpackAndLOR    = UnpackAndOp<Type, LOR<Type>, BS, EQ>;
442:   link->d_UnpackAndLXOR   = UnpackAndOp<Type, LXOR<Type>, BS, EQ>;
443:   link->d_UnpackAndBAND   = UnpackAndOp<Type, BAND<Type>, BS, EQ>;
444:   link->d_UnpackAndBOR    = UnpackAndOp<Type, BOR<Type>, BS, EQ>;
445:   link->d_UnpackAndBXOR   = UnpackAndOp<Type, BXOR<Type>, BS, EQ>;
446:   link->d_FetchAndAdd     = FetchAndOp<Type, Add<Type>, BS, EQ>;

448:   link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>;
449:   link->d_ScatterAndAdd    = ScatterAndOp<Type, Add<Type>, BS, EQ>;
450:   link->d_ScatterAndMult   = ScatterAndOp<Type, Mult<Type>, BS, EQ>;
451:   link->d_ScatterAndMin    = ScatterAndOp<Type, Min<Type>, BS, EQ>;
452:   link->d_ScatterAndMax    = ScatterAndOp<Type, Max<Type>, BS, EQ>;
453:   link->d_ScatterAndLAND   = ScatterAndOp<Type, LAND<Type>, BS, EQ>;
454:   link->d_ScatterAndLOR    = ScatterAndOp<Type, LOR<Type>, BS, EQ>;
455:   link->d_ScatterAndLXOR   = ScatterAndOp<Type, LXOR<Type>, BS, EQ>;
456:   link->d_ScatterAndBAND   = ScatterAndOp<Type, BAND<Type>, BS, EQ>;
457:   link->d_ScatterAndBOR    = ScatterAndOp<Type, BOR<Type>, BS, EQ>;
458:   link->d_ScatterAndBXOR   = ScatterAndOp<Type, BXOR<Type>, BS, EQ>;
459:   link->d_FetchAndAddLocal = FetchAndOpLocal<Type, Add<Type>, BS, EQ>;

461:   link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>;
462:   link->da_UnpackAndAdd    = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>;
463:   link->da_UnpackAndMult   = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>;
464:   link->da_UnpackAndMin    = UnpackAndOp<Type, AtomicMin<Type>, BS, EQ>;
465:   link->da_UnpackAndMax    = UnpackAndOp<Type, AtomicMax<Type>, BS, EQ>;
466:   link->da_UnpackAndLAND   = UnpackAndOp<Type, AtomicLAND<Type>, BS, EQ>;
467:   link->da_UnpackAndLOR    = UnpackAndOp<Type, AtomicLOR<Type>, BS, EQ>;
468:   link->da_UnpackAndBAND   = UnpackAndOp<Type, AtomicBAND<Type>, BS, EQ>;
469:   link->da_UnpackAndBOR    = UnpackAndOp<Type, AtomicBOR<Type>, BS, EQ>;
470:   link->da_UnpackAndBXOR   = UnpackAndOp<Type, AtomicBXOR<Type>, BS, EQ>;
471:   link->da_FetchAndAdd     = FetchAndOp<Type, AtomicFetchAdd<Type>, BS, EQ>;

473:   link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>;
474:   link->da_ScatterAndAdd    = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>;
475:   link->da_ScatterAndMult   = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>;
476:   link->da_ScatterAndMin    = ScatterAndOp<Type, AtomicMin<Type>, BS, EQ>;
477:   link->da_ScatterAndMax    = ScatterAndOp<Type, AtomicMax<Type>, BS, EQ>;
478:   link->da_ScatterAndLAND   = ScatterAndOp<Type, AtomicLAND<Type>, BS, EQ>;
479:   link->da_ScatterAndLOR    = ScatterAndOp<Type, AtomicLOR<Type>, BS, EQ>;
480:   link->da_ScatterAndBAND   = ScatterAndOp<Type, AtomicBAND<Type>, BS, EQ>;
481:   link->da_ScatterAndBOR    = ScatterAndOp<Type, AtomicBOR<Type>, BS, EQ>;
482:   link->da_ScatterAndBXOR   = ScatterAndOp<Type, AtomicBXOR<Type>, BS, EQ>;
483:   link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicFetchAdd<Type>, BS, EQ>;
484: }

#if defined(PETSC_HAVE_COMPLEX)
/* Install the device kernels for a complex primitive <Type>: only Insert, Add (incl. fetch-and-add)
   and Mult, since complex numbers have no ordering or logical/bitwise operations. */
template <typename Type, PetscInt BS, PetscInt EQ>
static void PackInit_ComplexType(PetscSFLink link)
{
  link->d_Pack = Pack<Type, BS, EQ>;

  /* Insert */
  link->d_UnpackAndInsert   = UnpackAndOp<Type, Insert<Type>, BS, EQ>;
  link->d_ScatterAndInsert  = ScatterAndInsert<Type, BS, EQ>;
  link->da_UnpackAndInsert  = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>;
  link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>;

  /* Add, including fetch-and-add */
  link->d_UnpackAndAdd      = UnpackAndOp<Type, Add<Type>, BS, EQ>;
  link->d_ScatterAndAdd     = ScatterAndOp<Type, Add<Type>, BS, EQ>;
  link->d_FetchAndAdd       = FetchAndOp<Type, Add<Type>, BS, EQ>;
  link->d_FetchAndAddLocal  = FetchAndOpLocal<Type, Add<Type>, BS, EQ>;
  link->da_UnpackAndAdd     = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>;
  link->da_ScatterAndAdd    = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>;
  link->da_FetchAndAdd      = FetchAndOp<Type, AtomicFetchAdd<Type>, BS, EQ>;
  link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicFetchAdd<Type>, BS, EQ>;

  /* Mult */
  link->d_UnpackAndMult   = UnpackAndOp<Type, Mult<Type>, BS, EQ>;
  link->d_ScatterAndMult  = ScatterAndOp<Type, Mult<Type>, BS, EQ>;
  link->da_UnpackAndMult  = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>;
  link->da_ScatterAndMult = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>;
}
#endif

513: template <typename Type>
514: static void PackInit_PairType(PetscSFLink link)
515: {
516:   link->d_Pack            = Pack<Type, 1, 1>;
517:   link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, 1, 1>;
518:   link->d_UnpackAndMaxloc = UnpackAndOp<Type, Maxloc<Type>, 1, 1>;
519:   link->d_UnpackAndMinloc = UnpackAndOp<Type, Minloc<Type>, 1, 1>;

521:   link->d_ScatterAndInsert = ScatterAndOp<Type, Insert<Type>, 1, 1>;
522:   link->d_ScatterAndMaxloc = ScatterAndOp<Type, Maxloc<Type>, 1, 1>;
523:   link->d_ScatterAndMinloc = ScatterAndOp<Type, Minloc<Type>, 1, 1>;
524:   /* Atomics for pair types are not implemented yet */
525: }

527: template <typename Type, PetscInt BS, PetscInt EQ>
528: static void PackInit_DumbType(PetscSFLink link)
529: {
530:   link->d_Pack             = Pack<Type, BS, EQ>;
531:   link->d_UnpackAndInsert  = UnpackAndOp<Type, Insert<Type>, BS, EQ>;
532:   link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>;
533:   /* Atomics for dumb types are not implemented yet */
534: }

/*
  Kokkos::DefaultExecutionSpace(stream) is a reference-counted pointer object. It has a bug that
  prevents one from repeatedly creating and destroying the object. SF's original design was that each
  SFLink has a stream (NULL or not) and hence an execution space object. The bug prevents us from
  destroying multiple SFLinks that use a NULL stream and the default execution space object. To avoid
  memory leaks, SF_Kokkos only supports the NULL stream, which is also PETSc's default scheme. SF_Kokkos
  does not do its own new/delete. It just uses Kokkos::DefaultExecutionSpace(), which is a singleton
  object in Kokkos.
*/
544: */
545: /*
546: static PetscErrorCode PetscSFLinkDestroy_Kokkos(PetscSFLink link)
547: {
548:   return 0;
549: }
550: */

552: /* Some device-specific utilities */
553: static PetscErrorCode PetscSFLinkSyncDevice_Kokkos(PetscSFLink PETSC_UNUSED link)
554: {
555:   Kokkos::fence();
556:   return 0;
557: }

559: static PetscErrorCode PetscSFLinkSyncStream_Kokkos(PetscSFLink PETSC_UNUSED link)
560: {
561:   DeviceExecutionSpace exec;
562:   exec.fence();
563:   return 0;
564: }

566: static PetscErrorCode PetscSFLinkMemcpy_Kokkos(PetscSFLink PETSC_UNUSED link, PetscMemType dstmtype, void *dst, PetscMemType srcmtype, const void *src, size_t n)
567: {
568:   DeviceExecutionSpace exec;

570:   if (!n) return 0;
571:   if (PetscMemTypeHost(dstmtype) && PetscMemTypeHost(srcmtype)) {
572:     PetscMemcpy(dst, src, n);
573:   } else {
574:     if (PetscMemTypeDevice(dstmtype) && PetscMemTypeHost(srcmtype)) {
575:       deviceBuffer_t    dbuf(static_cast<char *>(dst), n);
576:       HostConstBuffer_t sbuf(static_cast<const char *>(src), n);
577:       Kokkos::deep_copy(exec, dbuf, sbuf);
578:       PetscLogCpuToGpu(n);
579:     } else if (PetscMemTypeHost(dstmtype) && PetscMemTypeDevice(srcmtype)) {
580:       HostBuffer_t        dbuf(static_cast<char *>(dst), n);
581:       deviceConstBuffer_t sbuf(static_cast<const char *>(src), n);
582:       Kokkos::deep_copy(exec, dbuf, sbuf);
583:       PetscLogGpuToCpu(n);
584:     } else if (PetscMemTypeDevice(dstmtype) && PetscMemTypeDevice(srcmtype)) {
585:       deviceBuffer_t      dbuf(static_cast<char *>(dst), n);
586:       deviceConstBuffer_t sbuf(static_cast<const char *>(src), n);
587:       Kokkos::deep_copy(exec, dbuf, sbuf);
588:     }
589:   }
590:   return 0;
591: }

593: PetscErrorCode PetscSFMalloc_Kokkos(PetscMemType mtype, size_t size, void **ptr)
594: {
595:   if (PetscMemTypeHost(mtype)) PetscMalloc(size, ptr);
596:   else if (PetscMemTypeDevice(mtype)) {
597:     if (!PetscKokkosInitialized) PetscKokkosInitializeCheck();
598:     *ptr = Kokkos::kokkos_malloc<DeviceMemorySpace>(size);
599:   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Wrong PetscMemType %d", (int)mtype);
600:   return 0;
601: }

603: PetscErrorCode PetscSFFree_Kokkos(PetscMemType mtype, void *ptr)
604: {
605:   if (PetscMemTypeHost(mtype)) PetscFree(ptr);
606:   else if (PetscMemTypeDevice(mtype)) {
607:     Kokkos::kokkos_free<DeviceMemorySpace>(ptr);
608:   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Wrong PetscMemType %d", (int)mtype);
609:   return 0;
610: }

612: /* Destructor when the link uses MPI for communication */
613: static PetscErrorCode PetscSFLinkDestroy_Kokkos(PetscSF sf, PetscSFLink link)
614: {
615:   for (int i = PETSCSF_LOCAL; i <= PETSCSF_REMOTE; i++) {
616:     PetscSFFree(sf, PETSC_MEMTYPE_DEVICE, link->rootbuf_alloc[i][PETSC_MEMTYPE_DEVICE]);
617:     PetscSFFree(sf, PETSC_MEMTYPE_DEVICE, link->leafbuf_alloc[i][PETSC_MEMTYPE_DEVICE]);
618:   }
619:   return 0;
620: }

622: /* Some fields of link are initialized by PetscSFPackSetUp_Host. This routine only does what needed on device */
623: PetscErrorCode PetscSFLinkSetUp_Kokkos(PetscSF PETSC_UNUSED sf, PetscSFLink link, MPI_Datatype unit)
624: {
625:   PetscInt  nSignedChar = 0, nUnsignedChar = 0, nInt = 0, nPetscInt = 0, nPetscReal = 0;
626:   PetscBool is2Int, is2PetscInt;
627: #if defined(PETSC_HAVE_COMPLEX)
628:   PetscInt nPetscComplex = 0;
629: #endif

631:   if (link->deviceinited) return 0;
632:   PetscKokkosInitializeCheck();
633:   MPIPetsc_Type_compare_contig(unit, MPI_SIGNED_CHAR, &nSignedChar);
634:   MPIPetsc_Type_compare_contig(unit, MPI_UNSIGNED_CHAR, &nUnsignedChar);
635:   /* MPI_CHAR is treated below as a dumb type that does not support reduction according to MPI standard */
636:   MPIPetsc_Type_compare_contig(unit, MPI_INT, &nInt);
637:   MPIPetsc_Type_compare_contig(unit, MPIU_INT, &nPetscInt);
638:   MPIPetsc_Type_compare_contig(unit, MPIU_REAL, &nPetscReal);
639: #if defined(PETSC_HAVE_COMPLEX)
640:   MPIPetsc_Type_compare_contig(unit, MPIU_COMPLEX, &nPetscComplex);
641: #endif
642:   MPIPetsc_Type_compare(unit, MPI_2INT, &is2Int);
643:   MPIPetsc_Type_compare(unit, MPIU_2INT, &is2PetscInt);

645:   if (is2Int) {
646:     PackInit_PairType<Kokkos::pair<int, int>>(link);
647:   } else if (is2PetscInt) { /* TODO: when is2PetscInt and nPetscInt=2, we don't know which path to take. The two paths support different ops. */
648:     PackInit_PairType<Kokkos::pair<PetscInt, PetscInt>>(link);
649:   } else if (nPetscReal) {
650: #if !defined(PETSC_HAVE_DEVICE) /* Skip the unimportant stuff to speed up SF device compilation time */
651:     if (nPetscReal == 8) PackInit_RealType<PetscReal, 8, 1>(link);
652:     else if (nPetscReal % 8 == 0) PackInit_RealType<PetscReal, 8, 0>(link);
653:     else if (nPetscReal == 4) PackInit_RealType<PetscReal, 4, 1>(link);
654:     else if (nPetscReal % 4 == 0) PackInit_RealType<PetscReal, 4, 0>(link);
655:     else if (nPetscReal == 2) PackInit_RealType<PetscReal, 2, 1>(link);
656:     else if (nPetscReal % 2 == 0) PackInit_RealType<PetscReal, 2, 0>(link);
657:     else if (nPetscReal == 1) PackInit_RealType<PetscReal, 1, 1>(link);
658:     else if (nPetscReal % 1 == 0)
659: #endif
660:       PackInit_RealType<PetscReal, 1, 0>(link);
661:   } else if (nPetscInt && sizeof(PetscInt) == sizeof(llint)) {
662: #if !defined(PETSC_HAVE_DEVICE)
663:     if (nPetscInt == 8) PackInit_IntegerType<llint, 8, 1>(link);
664:     else if (nPetscInt % 8 == 0) PackInit_IntegerType<llint, 8, 0>(link);
665:     else if (nPetscInt == 4) PackInit_IntegerType<llint, 4, 1>(link);
666:     else if (nPetscInt % 4 == 0) PackInit_IntegerType<llint, 4, 0>(link);
667:     else if (nPetscInt == 2) PackInit_IntegerType<llint, 2, 1>(link);
668:     else if (nPetscInt % 2 == 0) PackInit_IntegerType<llint, 2, 0>(link);
669:     else if (nPetscInt == 1) PackInit_IntegerType<llint, 1, 1>(link);
670:     else if (nPetscInt % 1 == 0)
671: #endif
672:       PackInit_IntegerType<llint, 1, 0>(link);
673:   } else if (nInt) {
674: #if !defined(PETSC_HAVE_DEVICE)
675:     if (nInt == 8) PackInit_IntegerType<int, 8, 1>(link);
676:     else if (nInt % 8 == 0) PackInit_IntegerType<int, 8, 0>(link);
677:     else if (nInt == 4) PackInit_IntegerType<int, 4, 1>(link);
678:     else if (nInt % 4 == 0) PackInit_IntegerType<int, 4, 0>(link);
679:     else if (nInt == 2) PackInit_IntegerType<int, 2, 1>(link);
680:     else if (nInt % 2 == 0) PackInit_IntegerType<int, 2, 0>(link);
681:     else if (nInt == 1) PackInit_IntegerType<int, 1, 1>(link);
682:     else if (nInt % 1 == 0)
683: #endif
684:       PackInit_IntegerType<int, 1, 0>(link);
685:   } else if (nSignedChar) {
686: #if !defined(PETSC_HAVE_DEVICE)
687:     if (nSignedChar == 8) PackInit_IntegerType<char, 8, 1>(link);
688:     else if (nSignedChar % 8 == 0) PackInit_IntegerType<char, 8, 0>(link);
689:     else if (nSignedChar == 4) PackInit_IntegerType<char, 4, 1>(link);
690:     else if (nSignedChar % 4 == 0) PackInit_IntegerType<char, 4, 0>(link);
691:     else if (nSignedChar == 2) PackInit_IntegerType<char, 2, 1>(link);
692:     else if (nSignedChar % 2 == 0) PackInit_IntegerType<char, 2, 0>(link);
693:     else if (nSignedChar == 1) PackInit_IntegerType<char, 1, 1>(link);
694:     else if (nSignedChar % 1 == 0)
695: #endif
696:       PackInit_IntegerType<char, 1, 0>(link);
697:   } else if (nUnsignedChar) {
698: #if !defined(PETSC_HAVE_DEVICE)
699:     if (nUnsignedChar == 8) PackInit_IntegerType<unsigned char, 8, 1>(link);
700:     else if (nUnsignedChar % 8 == 0) PackInit_IntegerType<unsigned char, 8, 0>(link);
701:     else if (nUnsignedChar == 4) PackInit_IntegerType<unsigned char, 4, 1>(link);
702:     else if (nUnsignedChar % 4 == 0) PackInit_IntegerType<unsigned char, 4, 0>(link);
703:     else if (nUnsignedChar == 2) PackInit_IntegerType<unsigned char, 2, 1>(link);
704:     else if (nUnsignedChar % 2 == 0) PackInit_IntegerType<unsigned char, 2, 0>(link);
705:     else if (nUnsignedChar == 1) PackInit_IntegerType<unsigned char, 1, 1>(link);
706:     else if (nUnsignedChar % 1 == 0)
707: #endif
708:       PackInit_IntegerType<unsigned char, 1, 0>(link);
709: #if defined(PETSC_HAVE_COMPLEX)
710:   } else if (nPetscComplex) {
711:   #if !defined(PETSC_HAVE_DEVICE)
712:     if (nPetscComplex == 8) PackInit_ComplexType<Kokkos::complex<PetscReal>, 8, 1>(link);
713:     else if (nPetscComplex % 8 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>, 8, 0>(link);
714:     else if (nPetscComplex == 4) PackInit_ComplexType<Kokkos::complex<PetscReal>, 4, 1>(link);
715:     else if (nPetscComplex % 4 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>, 4, 0>(link);
716:     else if (nPetscComplex == 2) PackInit_ComplexType<Kokkos::complex<PetscReal>, 2, 1>(link);
717:     else if (nPetscComplex % 2 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>, 2, 0>(link);
718:     else if (nPetscComplex == 1) PackInit_ComplexType<Kokkos::complex<PetscReal>, 1, 1>(link);
719:     else if (nPetscComplex % 1 == 0)
720:   #endif
721:       PackInit_ComplexType<Kokkos::complex<PetscReal>, 1, 0>(link);
722: #endif
723:   } else {
724:     MPI_Aint lb, nbyte;
725:     MPI_Type_get_extent(unit, &lb, &nbyte);
727:     if (nbyte % sizeof(int)) { /* If the type size is not multiple of int */
728: #if !defined(PETSC_HAVE_DEVICE)
729:       if (nbyte == 4) PackInit_DumbType<char, 4, 1>(link);
730:       else if (nbyte % 4 == 0) PackInit_DumbType<char, 4, 0>(link);
731:       else if (nbyte == 2) PackInit_DumbType<char, 2, 1>(link);
732:       else if (nbyte % 2 == 0) PackInit_DumbType<char, 2, 0>(link);
733:       else if (nbyte == 1) PackInit_DumbType<char, 1, 1>(link);
734:       else if (nbyte % 1 == 0)
735: #endif
736:         PackInit_DumbType<char, 1, 0>(link);
737:     } else {
738:       nInt = nbyte / sizeof(int);
739: #if !defined(PETSC_HAVE_DEVICE)
740:       if (nInt == 8) PackInit_DumbType<int, 8, 1>(link);
741:       else if (nInt % 8 == 0) PackInit_DumbType<int, 8, 0>(link);
742:       else if (nInt == 4) PackInit_DumbType<int, 4, 1>(link);
743:       else if (nInt % 4 == 0) PackInit_DumbType<int, 4, 0>(link);
744:       else if (nInt == 2) PackInit_DumbType<int, 2, 1>(link);
745:       else if (nInt % 2 == 0) PackInit_DumbType<int, 2, 0>(link);
746:       else if (nInt == 1) PackInit_DumbType<int, 1, 1>(link);
747:       else if (nInt % 1 == 0)
748: #endif
749:         PackInit_DumbType<int, 1, 0>(link);
750:     }
751:   }

753:   link->SyncDevice   = PetscSFLinkSyncDevice_Kokkos;
754:   link->SyncStream   = PetscSFLinkSyncStream_Kokkos;
755:   link->Memcpy       = PetscSFLinkMemcpy_Kokkos;
756:   link->Destroy      = PetscSFLinkDestroy_Kokkos;
757:   link->deviceinited = PETSC_TRUE;
758:   return 0;
759: }