Actual source code: cupmcontext.hpp

  1: #ifndef PETSCDEVICECONTEXTCUPM_HPP
  2: #define PETSCDEVICECONTEXTCUPM_HPP

  4: #include <petsc/private/deviceimpl.h>
  5: #include <petsc/private/cupmblasinterface.hpp>
  6: #include <petsc/private/logimpl.h>

  8: #include <petsc/private/cpp/array.hpp>

 10: #include "../segmentedmempool.hpp"
 11: #include "cupmallocator.hpp"
 12: #include "cupmstream.hpp"
 13: #include "cupmevent.hpp"

 15: #if defined(__cplusplus)

 17: namespace Petsc
 18: {

 20: namespace device
 21: {

 23: namespace cupm
 24: {

 26: namespace impl
 27: {

 29: template <DeviceType T>
 30: class DeviceContext : BlasInterface<T> {
 31: public:
 32:   PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING(cupmBlasInterface_t, T);

 34: private:
 35:   template <typename H, std::size_t>
 36:   struct HandleTag {
 37:     using type = H;
 38:   };

 40:   using stream_tag = HandleTag<cupmStream_t, 0>;
 41:   using blas_tag   = HandleTag<cupmBlasHandle_t, 1>;
 42:   using solver_tag = HandleTag<cupmSolverHandle_t, 2>;

 44:   using stream_type = CUPMStream<T>;
 45:   using event_type  = CUPMEvent<T>;

 47: public:
 48:   // This is the canonical PETSc "impls" struct that normally resides in a standalone impls
 49:   // header, but since we are using the power of templates it must be declared part of
 50:   // this class to have easy access the same typedefs. Technically one can make a
 51:   // templated struct outside the class but it's more code for the same result.
 52:   struct PetscDeviceContext_IMPLS : memory::PoolAllocated<PetscDeviceContext_IMPLS> {
 53:     stream_type stream{};
 54:     cupmEvent_t event{};
 55:     cupmEvent_t begin{}; // timer-only
 56:     cupmEvent_t end{};   // timer-only
 57:   #if PetscDefined(USE_DEBUG)
 58:     PetscBool timerInUse{};
 59:   #endif
 60:     cupmBlasHandle_t   blas{};
 61:     cupmSolverHandle_t solver{};

 63:     constexpr PetscDeviceContext_IMPLS() noexcept = default;

 65:     PETSC_NODISCARD cupmStream_t get(stream_tag) const noexcept { return this->stream.get_stream(); }

 67:     PETSC_NODISCARD cupmBlasHandle_t get(blas_tag) const noexcept { return this->blas; }

 69:     PETSC_NODISCARD cupmSolverHandle_t get(solver_tag) const noexcept { return this->solver; }
 70:   };

 72: private:
 73:   static bool                                                     initialized_;
 74:   static std::array<cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES>   blashandles_;
 75:   static std::array<cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> solverhandles_;

 77:   PETSC_NODISCARD static constexpr PetscDeviceContext_IMPLS *impls_cast_(PetscDeviceContext ptr) noexcept { return static_cast<PetscDeviceContext_IMPLS *>(ptr->data); }

 79:   PETSC_NODISCARD static constexpr CUPMEvent<T> *event_cast_(PetscEvent event) noexcept { return static_cast<CUPMEvent<T> *>(event->data); }

 81:   PETSC_NODISCARD static PetscLogEvent CUPMBLAS_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUBLAS_HANDLE_CREATE : HIPBLAS_HANDLE_CREATE; }

 83:   PETSC_NODISCARD static PetscLogEvent CUPMSOLVER_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUSOLVER_HANDLE_CREATE : HIPSOLVER_HANDLE_CREATE; }

 85:   // this exists purely to satisfy the compiler so the tag-based dispatch works for the other
 86:   // handles
 87:   PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_handle_(stream_tag, PetscDeviceContext)) { return 0; }

 89:   PETSC_NODISCARD static PetscErrorCode create_handle_(blas_tag, cupmBlasHandle_t &handle) noexcept
 90:   {
 91:     PetscLogEvent event;

 93:     if (PetscLikely(handle)) return 0;
 94:     PetscLogPauseCurrentEvent_Internal(&event);
 95:     PetscLogEventBegin(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0);
 96:     for (auto i = 0; i < 3; ++i) {
 97:       auto cberr = cupmBlasCreate(&handle);
 98:       if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break;
 99:       if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) cberr;
100:       if (i != 2) {
101:         PetscSleep(3);
102:         continue;
103:       }
105:     }
106:     PetscLogEventEnd(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0);
107:     PetscLogEventResume_Internal(event);
108:     return 0;
109:   }

111:   PETSC_NODISCARD static PetscErrorCode initialize_handle_(blas_tag tag, PetscDeviceContext dctx) noexcept
112:   {
113:     const auto dci    = impls_cast_(dctx);
114:     auto      &handle = blashandles_[dctx->device->deviceId];

116:     create_handle_(tag, handle);
117:     cupmBlasSetStream(handle, dci->stream.get_stream());
118:     dci->blas = handle;
119:     return 0;
120:   }

122:   PETSC_CXX_COMPAT_DECL(PetscErrorCode create_handle_(solver_tag, cupmSolverHandle_t &handle))
123:   {
124:     PetscLogEvent event;

126:     PetscLogPauseCurrentEvent_Internal(&event);
127:     PetscLogEventBegin(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0);
128:     cupmBlasInterface_t::InitializeHandle(handle);
129:     PetscLogEventEnd(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0);
130:     PetscLogEventResume_Internal(event);
131:     return 0;
132:   }

134:   PETSC_NODISCARD static PetscErrorCode initialize_handle_(solver_tag tag, PetscDeviceContext dctx) noexcept
135:   {
136:     const auto dci    = impls_cast_(dctx);
137:     auto      &handle = solverhandles_[dctx->device->deviceId];

139:     create_handle_(tag, handle);
140:     cupmBlasInterface_t::SetHandleStream(handle, dci->stream.get_stream());
141:     dci->solver = handle;
142:     return 0;
143:   }

145:   PETSC_NODISCARD static PetscErrorCode check_current_device_(PetscDeviceContext dctxl, PetscDeviceContext dctxr) noexcept
146:   {
147:     const auto devidl = dctxl->device->deviceId, devidr = dctxr->device->deviceId;

150:                PetscObjectCast(dctxl)->id, devidl, PetscObjectCast(dctxr)->id, devidr);
151:     PetscDeviceCheckDeviceCount_Internal(devidl);
152:     PetscDeviceCheckDeviceCount_Internal(devidr);
153:     cupmSetDevice(static_cast<int>(devidl));
154:     return 0;
155:   }

157:   PETSC_NODISCARD static PetscErrorCode check_current_device_(PetscDeviceContext dctx) noexcept { return check_current_device_(dctx, dctx); }

159:   PETSC_NODISCARD static PetscErrorCode finalize_() noexcept
160:   {
161:     for (auto &&handle : blashandles_) {
162:       if (handle) {
163:         cupmBlasDestroy(handle);
164:         handle = nullptr;
165:       }
166:     }
167:     for (auto &&handle : solverhandles_) {
168:       if (handle) {
169:         cupmBlasInterface_t::DestroyHandle(handle);
170:         handle = nullptr;
171:       }
172:     }
173:     initialized_ = false;
174:     return 0;
175:   }

177:   template <typename Allocator, typename PoolType = ::Petsc::memory::SegmentedMemoryPool<typename Allocator::value_type, stream_type, Allocator, 256 * sizeof(PetscScalar)>>
178:   PETSC_NODISCARD static PoolType &default_pool_() noexcept
179:   {
180:     static PoolType pool;
181:     return pool;
182:   }

184:   PETSC_NODISCARD static PetscErrorCode check_memtype_(PetscMemType mtype, const char mess[]) noexcept
185:   {
187:     return 0;
188:   }

190: public:
191:   // All of these functions MUST be static in order to be callable from C, otherwise they
192:   // get the implicit 'this' pointer tacked on
193:   PETSC_CXX_COMPAT_DECL(PetscErrorCode destroy(PetscDeviceContext));
194:   PETSC_CXX_COMPAT_DECL(PetscErrorCode changeStreamType(PetscDeviceContext, PetscStreamType));
195:   PETSC_CXX_COMPAT_DECL(PetscErrorCode setUp(PetscDeviceContext));
196:   PETSC_CXX_COMPAT_DECL(PetscErrorCode query(PetscDeviceContext, PetscBool *));
197:   PETSC_CXX_COMPAT_DECL(PetscErrorCode waitForContext(PetscDeviceContext, PetscDeviceContext));
198:   PETSC_CXX_COMPAT_DECL(PetscErrorCode synchronize(PetscDeviceContext));
199:   template <typename Handle_t>
200:   PETSC_CXX_COMPAT_DECL(PetscErrorCode getHandle(PetscDeviceContext, void *));
201:   PETSC_CXX_COMPAT_DECL(PetscErrorCode beginTimer(PetscDeviceContext));
202:   PETSC_CXX_COMPAT_DECL(PetscErrorCode endTimer(PetscDeviceContext, PetscLogDouble *));
203:   PETSC_CXX_COMPAT_DECL(PetscErrorCode memAlloc(PetscDeviceContext, PetscBool, PetscMemType, std::size_t, std::size_t, void **));
204:   PETSC_CXX_COMPAT_DECL(PetscErrorCode memFree(PetscDeviceContext, PetscMemType, void **));
205:   PETSC_CXX_COMPAT_DECL(PetscErrorCode memCopy(PetscDeviceContext, void *PETSC_RESTRICT, const void *PETSC_RESTRICT, std::size_t, PetscDeviceCopyMode));
206:   PETSC_CXX_COMPAT_DECL(PetscErrorCode memSet(PetscDeviceContext, PetscMemType, void *, PetscInt, std::size_t));
207:   PETSC_CXX_COMPAT_DECL(PetscErrorCode createEvent(PetscDeviceContext, PetscEvent));
208:   PETSC_CXX_COMPAT_DECL(PetscErrorCode recordEvent(PetscDeviceContext, PetscEvent));
209:   PETSC_CXX_COMPAT_DECL(PetscErrorCode waitForEvent(PetscDeviceContext, PetscEvent));

211:   // not a PetscDeviceContext method, this registers the class
212:   PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize());

214:   // clang-format off
215:   const _DeviceContextOps ops = {
216:     destroy,
217:     changeStreamType,
218:     setUp,
219:     query,
220:     waitForContext,
221:     synchronize,
222:     getHandle<blas_tag>,
223:     getHandle<solver_tag>,
224:     getHandle<stream_tag>,
225:     beginTimer,
226:     endTimer,
227:     memAlloc,
228:     memFree,
229:     memCopy,
230:     memSet,
231:     createEvent,
232:     recordEvent,
233:     waitForEvent
234:   };
235:   // clang-format on
236: };

238: // not a PetscDeviceContext method, this initializes the CLASS
239: template <DeviceType T>
240: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::initialize())
241: {
242:   if (PetscUnlikely(!initialized_)) {
243:     cupmMemPool_t mempool;
244:     uint64_t      threshold = UINT64_MAX;

246:     initialized_ = true;
247:     cupmDeviceGetMemPool(&mempool, 0);
248:     cupmMemPoolSetAttribute(mempool, cupmMemPoolAttrReleaseThreshold, &threshold);
249:     blashandles_.fill(nullptr);
250:     solverhandles_.fill(nullptr);
251:     PetscRegisterFinalize(finalize_);
252:   }
253:   return 0;
254: }

256: template <DeviceType T>
257: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx))
258: {
259:   if (const auto dci = impls_cast_(dctx)) {
260:     dci->stream.destroy();
261:     if (dci->event) cupm_fast_event_pool<T>().deallocate(std::move(dci->event));
262:     if (dci->begin) cupmEventDestroy(dci->begin);
263:     if (dci->end) cupmEventDestroy(dci->end);
264:     delete dci;
265:     dctx->data = nullptr;
266:   }
267:   return 0;
268: }

270: template <DeviceType T>
271: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype))
272: {
273:   const auto dci = impls_cast_(dctx);

275:   dci->stream.destroy();
276:   // set these to null so they aren't usable until setup is called again
277:   dci->blas   = nullptr;
278:   dci->solver = nullptr;
279:   return 0;
280: }

282: template <DeviceType T>
283: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx))
284: {
285:   const auto dci   = impls_cast_(dctx);
286:   auto      &event = dci->event;

288:   check_current_device_(dctx);
289:   dci->stream.change_type(dctx->streamType);
290:   if (!event) cupm_fast_event_pool<T>().allocate(&event);
291:   #if PetscDefined(USE_DEBUG)
292:   dci->timerInUse = PETSC_FALSE;
293:   #endif
294:   return 0;
295: }

297: template <DeviceType T>
298: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle))
299: {
300:   check_current_device_(dctx);
301:   switch (auto cerr = cupmStreamQuery(impls_cast_(dctx)->stream.get_stream())) {
302:   case cupmSuccess:
303:     *idle = PETSC_TRUE;
304:     break;
305:   case cupmErrorNotReady:
306:     *idle = PETSC_FALSE;
307:     // reset the error
308:     cerr = cupmGetLastError();
309:     static_cast<void>(cerr);
310:     break;
311:   default:
312:     cerr;
313:     PetscUnreachable();
314:   }
315:   return 0;
316: }

318: template <DeviceType T>
319: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb))
320: {
321:   const auto dcib  = impls_cast_(dctxb);
322:   const auto event = dcib->event;

324:   check_current_device_(dctxa, dctxb);
325:   cupmEventRecord(event, dcib->stream.get_stream());
326:   cupmStreamWaitEvent(impls_cast_(dctxa)->stream.get_stream(), event, 0);
327:   return 0;
328: }

330: template <DeviceType T>
331: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx))
332: {
333:   auto idle = PETSC_TRUE;

335:   query(dctx, &idle);
336:   if (!idle) cupmStreamSynchronize(impls_cast_(dctx)->stream.get_stream());
337:   return 0;
338: }

340: template <DeviceType T>
341: template <typename handle_t>
342: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle))
343: {
344:   initialize_handle_(handle_t{}, dctx);
345:   *static_cast<typename handle_t::type *>(handle) = impls_cast_(dctx)->get(handle_t{});
346:   return 0;
347: }

349: template <DeviceType T>
350: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx))
351: {
352:   const auto dci = impls_cast_(dctx);

354:   check_current_device_(dctx);
355:   #if PetscDefined(USE_DEBUG)
357:   dci->timerInUse = PETSC_TRUE;
358:   #endif
359:   if (!dci->begin) {
360:     PetscAssert(!dci->end, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Don't have a 'begin' event, but somehow have an end event");
361:     cupmEventCreate(&dci->begin);
362:     cupmEventCreate(&dci->end);
363:   }
364:   cupmEventRecord(dci->begin, dci->stream.get_stream());
365:   return 0;
366: }

368: template <DeviceType T>
369: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed))
370: {
371:   float      gtime;
372:   const auto dci = impls_cast_(dctx);
373:   const auto end = dci->end;

375:   check_current_device_(dctx);
376:   #if PetscDefined(USE_DEBUG)
378:   dci->timerInUse = PETSC_FALSE;
379:   #endif
380:   cupmEventRecord(end, dci->stream.get_stream());
381:   cupmEventSynchronize(end);
382:   cupmEventElapsedTime(&gtime, dci->begin, end);
383:   *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime);
384:   return 0;
385: }

387: template <DeviceType T>
388: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::memAlloc(PetscDeviceContext dctx, PetscBool clear, PetscMemType mtype, std::size_t n, std::size_t alignment, void **dest))
389: {
390:   const auto &stream = impls_cast_(dctx)->stream;

392:   check_current_device_(dctx);
393:   check_memtype_(mtype, "allocating");
394:   if (PetscMemTypeHost(mtype)) {
395:     default_pool_<HostAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment);
396:   } else {
397:     default_pool_<DeviceAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment);
398:   }
399:   if (clear) cupmMemsetAsync(*dest, 0, n, stream.get_stream());
400:   return 0;
401: }

403: template <DeviceType T>
404: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::memFree(PetscDeviceContext dctx, PetscMemType mtype, void **ptr))
405: {
406:   const auto &stream = impls_cast_(dctx)->stream;

408:   check_current_device_(dctx);
409:   check_memtype_(mtype, "freeing");
410:   if (!*ptr) return 0;
411:   if (PetscMemTypeHost(mtype)) {
412:     default_pool_<HostAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream);
413:     // if ptr exists still exists the pool didn't own it
414:     if (*ptr) {
415:       auto registered = PETSC_FALSE, managed = PETSC_FALSE;

417:       PetscCUPMGetMemType(*ptr, nullptr, &registered, &managed);
418:       if (registered) {
419:         cupmFreeHost(*ptr);
420:       } else if (managed) {
421:         cupmFreeAsync(*ptr, stream.get_stream());
422:       }
423:     }
424:   } else {
425:     default_pool_<DeviceAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream);
426:     // if ptr exists still exists the pool didn't own it
427:     if (*ptr) cupmFreeAsync(*ptr, stream.get_stream());
428:   }
429:   return 0;
430: }

432: template <DeviceType T>
433: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::memCopy(PetscDeviceContext dctx, void *PETSC_RESTRICT dest, const void *PETSC_RESTRICT src, std::size_t n, PetscDeviceCopyMode mode))
434: {
435:   const auto stream = impls_cast_(dctx)->stream.get_stream();

437:   // can't use PetscCUPMMemcpyAsync here since we don't know sizeof(*src)...
438:   if (mode == PETSC_DEVICE_COPY_HTOH) {
439:     const auto cerr = cupmStreamQuery(stream);

441:     // yes this is faster
442:     if (cerr == cupmSuccess) {
443:       PetscMemcpy(dest, src, n);
444:       return 0;
445:     } else if (cerr == cupmErrorNotReady) {
446:       auto PETSC_UNUSED unused = cupmGetLastError();

448:       static_cast<void>(unused);
449:     } else {
450:       cerr;
451:     }
452:   }
453:   cupmMemcpyAsync(dest, src, n, PetscDeviceCopyModeToCUPMMemcpyKind(mode), stream);
454:   return 0;
455: }

457: template <DeviceType T>
458: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::memSet(PetscDeviceContext dctx, PetscMemType mtype, void *ptr, PetscInt v, std::size_t n))
459: {
460:   check_current_device_(dctx);
461:   check_memtype_(mtype, "zeroing");
462:   cupmMemsetAsync(ptr, static_cast<int>(v), n, impls_cast_(dctx)->stream.get_stream());
463:   return 0;
464: }

466: template <DeviceType T>
467: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::createEvent(PetscDeviceContext dctx, PetscEvent event))
468: {
469:   event->data = new event_type();
470:   event->destroy = [](PetscEvent event) {
471:     delete event_cast_(event);
472:     event->data = nullptr;
473:     return 0;
474:   };
475:   return 0;
476: }

478: template <DeviceType T>
479: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::recordEvent(PetscDeviceContext dctx, PetscEvent event))
480: {
481:   impls_cast_(dctx)->stream.record_event(*event_cast_(event));
482:   return 0;
483: }

485: template <DeviceType T>
486: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::waitForEvent(PetscDeviceContext dctx, PetscEvent event))
487: {
488:   impls_cast_(dctx)->stream.wait_for_event(*event_cast_(event));
489:   return 0;
490: }

492: // initialize the static member variables
// These are the out-of-class definitions of the static data members declared in
// DeviceContext<T>; each instantiation of the template (one per DeviceType) gets its own copy.

// Set to true by initialize() and back to false by finalize_().
493: template <DeviceType T>
494: bool DeviceContext<T>::initialized_ = false;

// BLAS handles, indexed by device id; value-initialized here and filled on demand.
496: template <DeviceType T>
497: std::array<typename DeviceContext<T>::cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {};

// Solver handles, indexed by device id; value-initialized here and filled on demand.
499: template <DeviceType T>
500: std::array<typename DeviceContext<T>::cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {};

502: } // namespace impl

504: // shorten this one up a bit (and instantiate the templates)
// One alias per supported backend, living in namespace Petsc::device::cupm.
505: using CUPMContextCuda = impl::DeviceContext<DeviceType::CUDA>;
506: using CUPMContextHip  = impl::DeviceContext<DeviceType::HIP>;

508:   // shorthand for what is an EXTREMELY long name
// Expands to the fully qualified name of the nested PetscDeviceContext_IMPLS struct for the
// given backend, e.g. PetscDeviceContext_(CUDA). IMPLS must be an enumerator of DeviceType.
509:   #define PetscDeviceContext_(IMPLS) ::Petsc::device::cupm::impl::DeviceContext<::Petsc::device::cupm::DeviceType::IMPLS>::PetscDeviceContext_IMPLS

511: } // namespace cupm

513: } // namespace device

515: } // namespace Petsc

517: #endif // __cplusplus

519: #endif // PETSCDEVICECONTEXTCUPM_HPP