Actual source code: cupmcontext.hpp
1: #ifndef PETSCDEVICECONTEXTCUPM_HPP
2: #define PETSCDEVICECONTEXTCUPM_HPP
4: #include <petsc/private/deviceimpl.h>
5: #include <petsc/private/cupmblasinterface.hpp>
6: #include <petsc/private/logimpl.h>
8: #include <petsc/private/cpp/array.hpp>
10: #include "../segmentedmempool.hpp"
11: #include "cupmallocator.hpp"
12: #include "cupmstream.hpp"
13: #include "cupmevent.hpp"
15: #if defined(__cplusplus)
17: namespace Petsc
18: {
20: namespace device
21: {
23: namespace cupm
24: {
26: namespace impl
27: {
29: template <DeviceType T>
30: class DeviceContext : BlasInterface<T> {
31: public:
32: PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING(cupmBlasInterface_t, T);
34: private:
35: template <typename H, std::size_t>
36: struct HandleTag {
37: using type = H;
38: };
40: using stream_tag = HandleTag<cupmStream_t, 0>;
41: using blas_tag = HandleTag<cupmBlasHandle_t, 1>;
42: using solver_tag = HandleTag<cupmSolverHandle_t, 2>;
44: using stream_type = CUPMStream<T>;
45: using event_type = CUPMEvent<T>;
47: public:
48: // This is the canonical PETSc "impls" struct that normally resides in a standalone impls
49: // header, but since we are using the power of templates it must be declared part of
50: // this class to have easy access the same typedefs. Technically one can make a
51: // templated struct outside the class but it's more code for the same result.
52: struct PetscDeviceContext_IMPLS : memory::PoolAllocated<PetscDeviceContext_IMPLS> {
53: stream_type stream{};
54: cupmEvent_t event{};
55: cupmEvent_t begin{}; // timer-only
56: cupmEvent_t end{}; // timer-only
57: #if PetscDefined(USE_DEBUG)
58: PetscBool timerInUse{};
59: #endif
60: cupmBlasHandle_t blas{};
61: cupmSolverHandle_t solver{};
63: constexpr PetscDeviceContext_IMPLS() noexcept = default;
65: PETSC_NODISCARD cupmStream_t get(stream_tag) const noexcept { return this->stream.get_stream(); }
67: PETSC_NODISCARD cupmBlasHandle_t get(blas_tag) const noexcept { return this->blas; }
69: PETSC_NODISCARD cupmSolverHandle_t get(solver_tag) const noexcept { return this->solver; }
70: };
72: private:
73: static bool initialized_;
74: static std::array<cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> blashandles_;
75: static std::array<cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> solverhandles_;
77: PETSC_NODISCARD static constexpr PetscDeviceContext_IMPLS *impls_cast_(PetscDeviceContext ptr) noexcept { return static_cast<PetscDeviceContext_IMPLS *>(ptr->data); }
79: PETSC_NODISCARD static constexpr CUPMEvent<T> *event_cast_(PetscEvent event) noexcept { return static_cast<CUPMEvent<T> *>(event->data); }
81: PETSC_NODISCARD static PetscLogEvent CUPMBLAS_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUBLAS_HANDLE_CREATE : HIPBLAS_HANDLE_CREATE; }
83: PETSC_NODISCARD static PetscLogEvent CUPMSOLVER_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUSOLVER_HANDLE_CREATE : HIPSOLVER_HANDLE_CREATE; }
85: // this exists purely to satisfy the compiler so the tag-based dispatch works for the other
86: // handles
87: PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize_handle_(stream_tag, PetscDeviceContext)) { return 0; }
89: PETSC_NODISCARD static PetscErrorCode create_handle_(blas_tag, cupmBlasHandle_t &handle) noexcept
90: {
91: PetscLogEvent event;
93: if (PetscLikely(handle)) return 0;
94: PetscLogPauseCurrentEvent_Internal(&event);
95: PetscLogEventBegin(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0);
96: for (auto i = 0; i < 3; ++i) {
97: auto cberr = cupmBlasCreate(&handle);
98: if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break;
99: if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) cberr;
100: if (i != 2) {
101: PetscSleep(3);
102: continue;
103: }
105: }
106: PetscLogEventEnd(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0);
107: PetscLogEventResume_Internal(event);
108: return 0;
109: }
111: PETSC_NODISCARD static PetscErrorCode initialize_handle_(blas_tag tag, PetscDeviceContext dctx) noexcept
112: {
113: const auto dci = impls_cast_(dctx);
114: auto &handle = blashandles_[dctx->device->deviceId];
116: create_handle_(tag, handle);
117: cupmBlasSetStream(handle, dci->stream.get_stream());
118: dci->blas = handle;
119: return 0;
120: }
122: PETSC_CXX_COMPAT_DECL(PetscErrorCode create_handle_(solver_tag, cupmSolverHandle_t &handle))
123: {
124: PetscLogEvent event;
126: PetscLogPauseCurrentEvent_Internal(&event);
127: PetscLogEventBegin(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0);
128: cupmBlasInterface_t::InitializeHandle(handle);
129: PetscLogEventEnd(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0);
130: PetscLogEventResume_Internal(event);
131: return 0;
132: }
134: PETSC_NODISCARD static PetscErrorCode initialize_handle_(solver_tag tag, PetscDeviceContext dctx) noexcept
135: {
136: const auto dci = impls_cast_(dctx);
137: auto &handle = solverhandles_[dctx->device->deviceId];
139: create_handle_(tag, handle);
140: cupmBlasInterface_t::SetHandleStream(handle, dci->stream.get_stream());
141: dci->solver = handle;
142: return 0;
143: }
145: PETSC_NODISCARD static PetscErrorCode check_current_device_(PetscDeviceContext dctxl, PetscDeviceContext dctxr) noexcept
146: {
147: const auto devidl = dctxl->device->deviceId, devidr = dctxr->device->deviceId;
150: PetscObjectCast(dctxl)->id, devidl, PetscObjectCast(dctxr)->id, devidr);
151: PetscDeviceCheckDeviceCount_Internal(devidl);
152: PetscDeviceCheckDeviceCount_Internal(devidr);
153: cupmSetDevice(static_cast<int>(devidl));
154: return 0;
155: }
157: PETSC_NODISCARD static PetscErrorCode check_current_device_(PetscDeviceContext dctx) noexcept { return check_current_device_(dctx, dctx); }
159: PETSC_NODISCARD static PetscErrorCode finalize_() noexcept
160: {
161: for (auto &&handle : blashandles_) {
162: if (handle) {
163: cupmBlasDestroy(handle);
164: handle = nullptr;
165: }
166: }
167: for (auto &&handle : solverhandles_) {
168: if (handle) {
169: cupmBlasInterface_t::DestroyHandle(handle);
170: handle = nullptr;
171: }
172: }
173: initialized_ = false;
174: return 0;
175: }
177: template <typename Allocator, typename PoolType = ::Petsc::memory::SegmentedMemoryPool<typename Allocator::value_type, stream_type, Allocator, 256 * sizeof(PetscScalar)>>
178: PETSC_NODISCARD static PoolType &default_pool_() noexcept
179: {
180: static PoolType pool;
181: return pool;
182: }
184: PETSC_NODISCARD static PetscErrorCode check_memtype_(PetscMemType mtype, const char mess[]) noexcept
185: {
187: return 0;
188: }
190: public:
191: // All of these functions MUST be static in order to be callable from C, otherwise they
192: // get the implicit 'this' pointer tacked on
193: PETSC_CXX_COMPAT_DECL(PetscErrorCode destroy(PetscDeviceContext));
194: PETSC_CXX_COMPAT_DECL(PetscErrorCode changeStreamType(PetscDeviceContext, PetscStreamType));
195: PETSC_CXX_COMPAT_DECL(PetscErrorCode setUp(PetscDeviceContext));
196: PETSC_CXX_COMPAT_DECL(PetscErrorCode query(PetscDeviceContext, PetscBool *));
197: PETSC_CXX_COMPAT_DECL(PetscErrorCode waitForContext(PetscDeviceContext, PetscDeviceContext));
198: PETSC_CXX_COMPAT_DECL(PetscErrorCode synchronize(PetscDeviceContext));
199: template <typename Handle_t>
200: PETSC_CXX_COMPAT_DECL(PetscErrorCode getHandle(PetscDeviceContext, void *));
201: PETSC_CXX_COMPAT_DECL(PetscErrorCode beginTimer(PetscDeviceContext));
202: PETSC_CXX_COMPAT_DECL(PetscErrorCode endTimer(PetscDeviceContext, PetscLogDouble *));
203: PETSC_CXX_COMPAT_DECL(PetscErrorCode memAlloc(PetscDeviceContext, PetscBool, PetscMemType, std::size_t, std::size_t, void **));
204: PETSC_CXX_COMPAT_DECL(PetscErrorCode memFree(PetscDeviceContext, PetscMemType, void **));
205: PETSC_CXX_COMPAT_DECL(PetscErrorCode memCopy(PetscDeviceContext, void *PETSC_RESTRICT, const void *PETSC_RESTRICT, std::size_t, PetscDeviceCopyMode));
206: PETSC_CXX_COMPAT_DECL(PetscErrorCode memSet(PetscDeviceContext, PetscMemType, void *, PetscInt, std::size_t));
207: PETSC_CXX_COMPAT_DECL(PetscErrorCode createEvent(PetscDeviceContext, PetscEvent));
208: PETSC_CXX_COMPAT_DECL(PetscErrorCode recordEvent(PetscDeviceContext, PetscEvent));
209: PETSC_CXX_COMPAT_DECL(PetscErrorCode waitForEvent(PetscDeviceContext, PetscEvent));
211: // not a PetscDeviceContext method, this registers the class
212: PETSC_CXX_COMPAT_DECL(PetscErrorCode initialize());
214: // clang-format off
215: const _DeviceContextOps ops = {
216: destroy,
217: changeStreamType,
218: setUp,
219: query,
220: waitForContext,
221: synchronize,
222: getHandle<blas_tag>,
223: getHandle<solver_tag>,
224: getHandle<stream_tag>,
225: beginTimer,
226: endTimer,
227: memAlloc,
228: memFree,
229: memCopy,
230: memSet,
231: createEvent,
232: recordEvent,
233: waitForEvent
234: };
235: // clang-format on
236: };
238: // not a PetscDeviceContext method, this initializes the CLASS
239: template <DeviceType T>
240: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::initialize())
241: {
242: if (PetscUnlikely(!initialized_)) {
243: cupmMemPool_t mempool;
244: uint64_t threshold = UINT64_MAX;
246: initialized_ = true;
247: cupmDeviceGetMemPool(&mempool, 0);
248: cupmMemPoolSetAttribute(mempool, cupmMemPoolAttrReleaseThreshold, &threshold);
249: blashandles_.fill(nullptr);
250: solverhandles_.fill(nullptr);
251: PetscRegisterFinalize(finalize_);
252: }
253: return 0;
254: }
256: template <DeviceType T>
257: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx))
258: {
259: if (const auto dci = impls_cast_(dctx)) {
260: dci->stream.destroy();
261: if (dci->event) cupm_fast_event_pool<T>().deallocate(std::move(dci->event));
262: if (dci->begin) cupmEventDestroy(dci->begin);
263: if (dci->end) cupmEventDestroy(dci->end);
264: delete dci;
265: dctx->data = nullptr;
266: }
267: return 0;
268: }
270: template <DeviceType T>
271: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype))
272: {
273: const auto dci = impls_cast_(dctx);
275: dci->stream.destroy();
276: // set these to null so they aren't usable until setup is called again
277: dci->blas = nullptr;
278: dci->solver = nullptr;
279: return 0;
280: }
282: template <DeviceType T>
283: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx))
284: {
285: const auto dci = impls_cast_(dctx);
286: auto &event = dci->event;
288: check_current_device_(dctx);
289: dci->stream.change_type(dctx->streamType);
290: if (!event) cupm_fast_event_pool<T>().allocate(&event);
291: #if PetscDefined(USE_DEBUG)
292: dci->timerInUse = PETSC_FALSE;
293: #endif
294: return 0;
295: }
297: template <DeviceType T>
298: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle))
299: {
300: check_current_device_(dctx);
301: switch (auto cerr = cupmStreamQuery(impls_cast_(dctx)->stream.get_stream())) {
302: case cupmSuccess:
303: *idle = PETSC_TRUE;
304: break;
305: case cupmErrorNotReady:
306: *idle = PETSC_FALSE;
307: // reset the error
308: cerr = cupmGetLastError();
309: static_cast<void>(cerr);
310: break;
311: default:
312: cerr;
313: PetscUnreachable();
314: }
315: return 0;
316: }
318: template <DeviceType T>
319: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb))
320: {
321: const auto dcib = impls_cast_(dctxb);
322: const auto event = dcib->event;
324: check_current_device_(dctxa, dctxb);
325: cupmEventRecord(event, dcib->stream.get_stream());
326: cupmStreamWaitEvent(impls_cast_(dctxa)->stream.get_stream(), event, 0);
327: return 0;
328: }
330: template <DeviceType T>
331: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx))
332: {
333: auto idle = PETSC_TRUE;
335: query(dctx, &idle);
336: if (!idle) cupmStreamSynchronize(impls_cast_(dctx)->stream.get_stream());
337: return 0;
338: }
340: template <DeviceType T>
341: template <typename handle_t>
342: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle))
343: {
344: initialize_handle_(handle_t{}, dctx);
345: *static_cast<typename handle_t::type *>(handle) = impls_cast_(dctx)->get(handle_t{});
346: return 0;
347: }
349: template <DeviceType T>
350: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx))
351: {
352: const auto dci = impls_cast_(dctx);
354: check_current_device_(dctx);
355: #if PetscDefined(USE_DEBUG)
357: dci->timerInUse = PETSC_TRUE;
358: #endif
359: if (!dci->begin) {
360: PetscAssert(!dci->end, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Don't have a 'begin' event, but somehow have an end event");
361: cupmEventCreate(&dci->begin);
362: cupmEventCreate(&dci->end);
363: }
364: cupmEventRecord(dci->begin, dci->stream.get_stream());
365: return 0;
366: }
368: template <DeviceType T>
369: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed))
370: {
371: float gtime;
372: const auto dci = impls_cast_(dctx);
373: const auto end = dci->end;
375: check_current_device_(dctx);
376: #if PetscDefined(USE_DEBUG)
378: dci->timerInUse = PETSC_FALSE;
379: #endif
380: cupmEventRecord(end, dci->stream.get_stream());
381: cupmEventSynchronize(end);
382: cupmEventElapsedTime(>ime, dci->begin, end);
383: *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime);
384: return 0;
385: }
387: template <DeviceType T>
388: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::memAlloc(PetscDeviceContext dctx, PetscBool clear, PetscMemType mtype, std::size_t n, std::size_t alignment, void **dest))
389: {
390: const auto &stream = impls_cast_(dctx)->stream;
392: check_current_device_(dctx);
393: check_memtype_(mtype, "allocating");
394: if (PetscMemTypeHost(mtype)) {
395: default_pool_<HostAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment);
396: } else {
397: default_pool_<DeviceAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment);
398: }
399: if (clear) cupmMemsetAsync(*dest, 0, n, stream.get_stream());
400: return 0;
401: }
403: template <DeviceType T>
404: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::memFree(PetscDeviceContext dctx, PetscMemType mtype, void **ptr))
405: {
406: const auto &stream = impls_cast_(dctx)->stream;
408: check_current_device_(dctx);
409: check_memtype_(mtype, "freeing");
410: if (!*ptr) return 0;
411: if (PetscMemTypeHost(mtype)) {
412: default_pool_<HostAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream);
413: // if ptr exists still exists the pool didn't own it
414: if (*ptr) {
415: auto registered = PETSC_FALSE, managed = PETSC_FALSE;
417: PetscCUPMGetMemType(*ptr, nullptr, ®istered, &managed);
418: if (registered) {
419: cupmFreeHost(*ptr);
420: } else if (managed) {
421: cupmFreeAsync(*ptr, stream.get_stream());
422: }
423: }
424: } else {
425: default_pool_<DeviceAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream);
426: // if ptr exists still exists the pool didn't own it
427: if (*ptr) cupmFreeAsync(*ptr, stream.get_stream());
428: }
429: return 0;
430: }
432: template <DeviceType T>
433: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::memCopy(PetscDeviceContext dctx, void *PETSC_RESTRICT dest, const void *PETSC_RESTRICT src, std::size_t n, PetscDeviceCopyMode mode))
434: {
435: const auto stream = impls_cast_(dctx)->stream.get_stream();
437: // can't use PetscCUPMMemcpyAsync here since we don't know sizeof(*src)...
438: if (mode == PETSC_DEVICE_COPY_HTOH) {
439: const auto cerr = cupmStreamQuery(stream);
441: // yes this is faster
442: if (cerr == cupmSuccess) {
443: PetscMemcpy(dest, src, n);
444: return 0;
445: } else if (cerr == cupmErrorNotReady) {
446: auto PETSC_UNUSED unused = cupmGetLastError();
448: static_cast<void>(unused);
449: } else {
450: cerr;
451: }
452: }
453: cupmMemcpyAsync(dest, src, n, PetscDeviceCopyModeToCUPMMemcpyKind(mode), stream);
454: return 0;
455: }
457: template <DeviceType T>
458: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::memSet(PetscDeviceContext dctx, PetscMemType mtype, void *ptr, PetscInt v, std::size_t n))
459: {
460: check_current_device_(dctx);
461: check_memtype_(mtype, "zeroing");
462: cupmMemsetAsync(ptr, static_cast<int>(v), n, impls_cast_(dctx)->stream.get_stream());
463: return 0;
464: }
466: template <DeviceType T>
467: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::createEvent(PetscDeviceContext dctx, PetscEvent event))
468: {
469: event->data = new event_type();
470: event->destroy = [](PetscEvent event) {
471: delete event_cast_(event);
472: event->data = nullptr;
473: return 0;
474: };
475: return 0;
476: }
478: template <DeviceType T>
479: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::recordEvent(PetscDeviceContext dctx, PetscEvent event))
480: {
481: impls_cast_(dctx)->stream.record_event(*event_cast_(event));
482: return 0;
483: }
485: template <DeviceType T>
486: PETSC_CXX_COMPAT_DEFN(PetscErrorCode DeviceContext<T>::waitForEvent(PetscDeviceContext dctx, PetscEvent event))
487: {
488: impls_cast_(dctx)->stream.wait_for_event(*event_cast_(event));
489: return 0;
490: }
492: // initialize the static member variables
493: template <DeviceType T>
494: bool DeviceContext<T>::initialized_ = false;
496: template <DeviceType T>
497: std::array<typename DeviceContext<T>::cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {};
499: template <DeviceType T>
500: std::array<typename DeviceContext<T>::cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {};
502: } // namespace impl
// shorten this one up a bit (and instantiate the templates)
// One alias per supported backend; both name specializations of impl::DeviceContext.
using CUPMContextCuda = impl::DeviceContext<DeviceType::CUDA>;
using CUPMContextHip = impl::DeviceContext<DeviceType::HIP>;
// shorthand for what is an EXTREMELY long name
// Expands to the fully qualified name of the nested impls struct for the given backend;
// IMPLS must be an enumerator of DeviceType (e.g. CUDA or HIP).
#define PetscDeviceContext_(IMPLS) ::Petsc::device::cupm::impl::DeviceContext<::Petsc::device::cupm::DeviceType::IMPLS>::PetscDeviceContext_IMPLS
511: } // namespace cupm
513: } // namespace device
515: } // namespace Petsc
517: #endif // __cplusplus
#endif // PETSCDEVICECONTEXTCUPM_HPP