Actual source code: cupmblasinterface.hpp
1: #ifndef PETSCCUPMBLASINTERFACE_HPP
2: #define PETSCCUPMBLASINTERFACE_HPP
4: #if defined(__cplusplus)
5: #include <petsc/private/cupminterface.hpp>
6: #include <petsc/private/petscadvancedmacros.h>
8: namespace Petsc
9: {
11: namespace device
12: {
14: namespace cupm
15: {
17: namespace impl
18: {
// PetscCallCUPMBLAS() - Check the status returned by a CUDA/HIP BLAS call and raise a PETSc
// error if it is not CUPMBLAS_STATUS_SUCCESS.
//
// notes:
// expects cupmBlasError_t, the CUPMBLAS_STATUS_* values, cupmBlasName(),
// cupmBlasGetErrorName() and PETSC_DEVICE_CUPM() to be in scope at the point of use.
//
// CUPMBLAS_STATUS_NOT_INITIALIZED / CUPMBLAS_STATUS_ALLOC_FAILED on an already initialized
// device are reported as PETSC_ERR_GPU_RESOURCE; every other failure as PETSC_ERR_GPU.
#define PetscCallCUPMBLAS(...) \
  do { \
    const cupmBlasError_t cberr_p_ = __VA_ARGS__; \
    if (PetscUnlikely(cberr_p_ != CUPMBLAS_STATUS_SUCCESS)) { \
      if (((cberr_p_ == CUPMBLAS_STATUS_NOT_INITIALIZED) || (cberr_p_ == CUPMBLAS_STATUS_ALLOC_FAILED)) && PetscDeviceInitialized(PETSC_DEVICE_CUPM())) { \
        SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, \
                "%s error %d (%s). Reports not initialized or alloc failed; " \
                "this indicates the GPU may have run out of resources", \
                cupmBlasName(), static_cast<PetscErrorCode>(cberr_p_), cupmBlasGetErrorName(cberr_p_)); \
      } \
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU, "%s error %d (%s)", cupmBlasName(), static_cast<PetscErrorCode>(cberr_p_), cupmBlasGetErrorName(cberr_p_)); \
    } \
  } while (0)
34: // given cupmBlas<T>axpy() then
35: // T = PETSC_CUPBLAS_FP_TYPE
36: // given cupmBlas<T><u>nrm2() then
37: // T = PETSC_CUPMBLAS_FP_INPUT_TYPE
38: // u = PETSC_CUPMBLAS_FP_RETURN_TYPE
39: #if PetscDefined(USE_COMPLEX)
40: #if PetscDefined(USE_REAL_SINGLE)
41: #define PETSC_CUPMBLAS_FP_TYPE_U C
42: #define PETSC_CUPMBLAS_FP_TYPE_L c
43: #define PETSC_CUPMBLAS_FP_INPUT_TYPE_U S
44: #define PETSC_CUPMBLAS_FP_INPUT_TYPE_L s
45: #elif PetscDefined(USE_REAL_DOUBLE)
46: #define PETSC_CUPMBLAS_FP_TYPE_U Z
47: #define PETSC_CUPMBLAS_FP_TYPE_L z
48: #define PETSC_CUPMBLAS_FP_INPUT_TYPE_U D
49: #define PETSC_CUPMBLAS_FP_INPUT_TYPE_L d
50: #endif
51: #define PETSC_CUPMBLAS_FP_RETURN_TYPE_U PETSC_CUPMBLAS_FP_TYPE_U
52: #define PETSC_CUPMBLAS_FP_RETURN_TYPE_L PETSC_CUPMBLAS_FP_TYPE_L
53: #else
54: #if PetscDefined(USE_REAL_SINGLE)
55: #define PETSC_CUPMBLAS_FP_TYPE_U S
56: #define PETSC_CUPMBLAS_FP_TYPE_L s
57: #elif PetscDefined(USE_REAL_DOUBLE)
58: #define PETSC_CUPMBLAS_FP_TYPE_U D
59: #define PETSC_CUPMBLAS_FP_TYPE_L d
60: #endif
61: #define PETSC_CUPMBLAS_FP_INPUT_TYPE_U PETSC_CUPMBLAS_FP_TYPE_U
62: #define PETSC_CUPMBLAS_FP_INPUT_TYPE_L PETSC_CUPMBLAS_FP_TYPE_L
63: #define PETSC_CUPMBLAS_FP_RETURN_TYPE_U
64: #define PETSC_CUPMBLAS_FP_RETURN_TYPE_L
65: #endif // USE_COMPLEX
67: #if !defined(PETSC_CUPMBLAS_FP_TYPE_U) && !PetscDefined(USE_REAL___FLOAT128)
68: #error "Unsupported floating-point type for CUDA/HIP BLAS"
69: #endif
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED() - Helper macro to build a "modified"
// blas function whose return type does not match the input type
//
// input param:
// func - base suffix of the blas function, e.g. nrm2
//
// notes:
// requires PETSC_CUPMBLAS_FP_INPUT_TYPE to be defined as the blas floating point input type
// letter ("S" for real/complex single, "D" for real/complex double).
//
// requires PETSC_CUPMBLAS_FP_RETURN_TYPE to be defined as the blas floating point output type
// letter ("c" for complex single, "z" for complex double and <absolutely nothing> for real
// single/double).
//
// The vendor function names mix upper- and lower-case type letters within a single name
// (e.g. Dznrm2, Isamax), which is why every type-letter macro comes in an upper-case (_U) and
// a lower-case (_L) variant.
//
// example usage:
// #define PETSC_CUPMBLAS_FP_INPUT_TYPE S
// #define PETSC_CUPMBLAS_FP_RETURN_TYPE
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED(nrm2) -> Snrm2
//
// #define PETSC_CUPMBLAS_FP_INPUT_TYPE D
// #define PETSC_CUPMBLAS_FP_RETURN_TYPE z
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED(nrm2) -> Dznrm2
#define PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED(func) PetscConcat(PetscConcat(PETSC_CUPMBLAS_FP_INPUT_TYPE, PETSC_CUPMBLAS_FP_RETURN_TYPE), func)

// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE() - Helper macro to build Iamax and Iamin
// because they are both extra special
//
// input param:
// func - base suffix of the blas function, either amax or amin
//
// notes:
// The macro name stands for "I" ## "floating point type", which is exactly the name it builds.
//
// requires PETSC_CUPMBLAS_FP_TYPE_L to be defined as the lower-case blas floating point type
// letter ("c" for complex single, "z" for complex double, "s" for real single, and "d" for
// real double).
//
// example usage:
// #define PETSC_CUPMBLAS_FP_TYPE_L s
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE(amax) -> Isamax
//
// #define PETSC_CUPMBLAS_FP_TYPE_L c
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE(amax) -> Icamax
//
// #define PETSC_CUPMBLAS_FP_TYPE_L z
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE(amin) -> Izamin
#define PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE(func) PetscConcat(I, PetscConcat(PETSC_CUPMBLAS_FP_TYPE_L, func))

// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD() - Helper macro to build a "standard"
// blas function name
//
// input param:
// func - base suffix of the blas function, e.g. axpy, scal
//
// notes:
// requires PETSC_CUPMBLAS_FP_TYPE to be defined as the blas floating-point letter ("C" for
// complex single, "Z" for complex double, "S" for real single, "D" for real double).
//
// example usage:
// #define PETSC_CUPMBLAS_FP_TYPE S
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD(axpy) -> Saxpy
//
// #define PETSC_CUPMBLAS_FP_TYPE Z
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD(axpy) -> Zaxpy
#define PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD(func) PetscConcat(PETSC_CUPMBLAS_FP_TYPE, func)
// PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT() - In case CUDA/HIP don't agree with our suffix
// one can provide both here
//
// input params:
// MACRO_SUFFIX - suffix to one of the above blas function builder macros, e.g. STANDARD or
//                IFPTYPE
// our_suffix   - the suffix of the alias function
// their_suffix - the suffix of the function being aliased
//
// notes:
// requires PETSC_CUPMBLAS_PREFIX to be defined as the specific CUDA/HIP blas function
// prefix. requires any other specific definitions required by the specific builder macro to
// also be defined. See PETSC_CUPM_ALIAS_FUNCTION() for the exact expansion of the function
// alias.
//
// example usage:
// #define PETSC_CUPMBLAS_PREFIX cublas
// #define PETSC_CUPMBLAS_FP_TYPE C
// PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD,dot,dotc) ->
// template <typename... T>
// static constexpr auto cupmBlasXdot(T&&... args) *noexcept and returntype detection*
// {
//   return cublasCdotc(std::forward<T>(args)...);
// }
#define PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(MACRO_SUFFIX, our_suffix, their_suffix) \
  PETSC_CUPM_ALIAS_FUNCTION(PetscConcat(cupmBlasX, our_suffix), PetscConcat(PETSC_CUPMBLAS_PREFIX, PetscConcat(PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_, MACRO_SUFFIX)(their_suffix)))

// PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION() - Alias a CUDA/HIP blas function
//
// input params:
// MACRO_SUFFIX - suffix to one of the above blas function builder macros, e.g. STANDARD or
//                IFPTYPE
// suffix       - the common suffix between CUDA and HIP of the alias function
//
// notes:
// see PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(); this macro just calls that one with "suffix"
// as both "our_suffix" and "their_suffix"
#define PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MACRO_SUFFIX, suffix) PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(MACRO_SUFFIX, suffix, suffix)

// PETSC_CUPMBLAS_ALIAS_FUNCTION() - Alias a CUDA/HIP library function
//
// input params:
// suffix - the common suffix between CUDA and HIP of the alias function
//
// notes:
// requires PETSC_CUPMBLAS_PREFIX to be defined as the specific CUDA/HIP blas library
// prefix. see PETSC_CUPM_ALIAS_FUNCTION() for the precise expansion of this macro.
//
// example usage:
// #define PETSC_CUPMBLAS_PREFIX hipblas
// PETSC_CUPMBLAS_ALIAS_FUNCTION(Create) ->
// template <typename... T>
// static constexpr auto cupmBlasCreate(T&&... args) *noexcept and returntype detection*
// {
//   return hipblasCreate(std::forward<T>(args)...);
// }
#define PETSC_CUPMBLAS_ALIAS_FUNCTION(suffix) PETSC_CUPM_ALIAS_FUNCTION(PetscConcat(cupmBlas, suffix), PetscConcat(PETSC_CUPMBLAS_PREFIX, suffix))
196: template <DeviceType T>
197: struct BlasInterfaceBase : Interface<T> {
198: PETSC_NODISCARD static constexpr const char *cupmBlasName() noexcept { return T == DeviceType::CUDA ? "cuBLAS" : "hipBLAS"; }
199: };
201: #define PETSC_CUPMBLAS_BASE_CLASS_HEADER(DEV_TYPE) \
202: using base_type = ::Petsc::device::cupm::impl::BlasInterfaceBase<DEV_TYPE>; \
203: using base_type::cupmBlasName; \
204: PETSC_CUPM_ALIAS_FUNCTION(cupmBlasGetErrorName, PetscConcat(PetscConcat(Petsc, PETSC_CUPMBLAS_PREFIX_U), GetErrorName)) \
205: PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(interface_type, DEV_TYPE)
207: template <DeviceType>
208: struct BlasInterfaceImpl;
210: #if PetscDefined(HAVE_CUDA)
211: #define PETSC_CUPMBLAS_PREFIX cublas
212: #define PETSC_CUPMBLAS_PREFIX_U CUBLAS
213: #define PETSC_CUPMBLAS_FP_TYPE PETSC_CUPMBLAS_FP_TYPE_U
214: #define PETSC_CUPMBLAS_FP_INPUT_TYPE PETSC_CUPMBLAS_FP_INPUT_TYPE_U
215: #define PETSC_CUPMBLAS_FP_RETURN_TYPE PETSC_CUPMBLAS_FP_RETURN_TYPE_L
216: template <>
217: struct BlasInterfaceImpl<DeviceType::CUDA> : BlasInterfaceBase<DeviceType::CUDA> {
218: PETSC_CUPMBLAS_BASE_CLASS_HEADER(DeviceType::CUDA);
220: // typedefs
221: using cupmBlasHandle_t = cublasHandle_t;
222: using cupmBlasError_t = cublasStatus_t;
223: using cupmBlasInt_t = int;
224: using cupmSolverHandle_t = cusolverDnHandle_t;
225: using cupmSolverError_t = cusolverStatus_t;
226: using cupmBlasPointerMode_t = cublasPointerMode_t;
228: // values
229: static const auto CUPMBLAS_STATUS_SUCCESS = CUBLAS_STATUS_SUCCESS;
230: static const auto CUPMBLAS_STATUS_NOT_INITIALIZED = CUBLAS_STATUS_NOT_INITIALIZED;
231: static const auto CUPMBLAS_STATUS_ALLOC_FAILED = CUBLAS_STATUS_ALLOC_FAILED;
232: static const auto CUPMBLAS_POINTER_MODE_HOST = CUBLAS_POINTER_MODE_HOST;
233: static const auto CUPMBLAS_POINTER_MODE_DEVICE = CUBLAS_POINTER_MODE_DEVICE;
235: // utility functions
236: PETSC_CUPMBLAS_ALIAS_FUNCTION(Create)
237: PETSC_CUPMBLAS_ALIAS_FUNCTION(Destroy)
238: PETSC_CUPMBLAS_ALIAS_FUNCTION(GetStream)
239: PETSC_CUPMBLAS_ALIAS_FUNCTION(SetStream)
240: PETSC_CUPMBLAS_ALIAS_FUNCTION(GetPointerMode)
241: PETSC_CUPMBLAS_ALIAS_FUNCTION(SetPointerMode)
243: // level 1 BLAS
244: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, axpy)
245: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, scal)
246: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dot, PetscIfPetscDefined(USE_COMPLEX, dotc, dot))
247: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dotu, PetscIfPetscDefined(USE_COMPLEX, dotu, dot))
248: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, swap)
249: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, nrm2)
250: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(IFPTYPE, amax)
251: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, asum)
253: // level 2 BLAS
254: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemv)
256: // level 3 BLAS
257: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemm)
259: // BLAS extensions
260: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, geam)
262: PETSC_NODISCARD static PetscErrorCode InitializeHandle(cupmSolverHandle_t &handle) noexcept
263: {
264: if (handle) return 0;
265: for (auto i = 0; i < 3; ++i) {
266: const auto cerr = cusolverDnCreate(&handle);
267: if (PetscLikely(cerr == CUSOLVER_STATUS_SUCCESS)) break;
268: if ((cerr != CUSOLVER_STATUS_NOT_INITIALIZED) && (cerr != CUSOLVER_STATUS_ALLOC_FAILED)) cerr;
269: if (i < 2) {
270: PetscSleep(3);
271: continue;
272: }
274: }
275: return 0;
276: }
278: PETSC_NODISCARD static PetscErrorCode SetHandleStream(const cupmSolverHandle_t &handle, const cupmStream_t &stream) noexcept
279: {
280: cupmStream_t cupmStream;
282: cusolverDnGetStream(handle, &cupmStream);
283: if (cupmStream != stream) cusolverDnSetStream(handle, stream);
284: return 0;
285: }
287: PETSC_NODISCARD static PetscErrorCode DestroyHandle(cupmSolverHandle_t &handle) noexcept
288: {
289: if (handle) {
290: cusolverDnDestroy(handle);
291: handle = nullptr;
292: }
293: return 0;
294: }
295: };
296: #undef PETSC_CUPMBLAS_PREFIX
297: #undef PETSC_CUPMBLAS_PREFIX_U
298: #undef PETSC_CUPMBLAS_FP_TYPE
299: #undef PETSC_CUPMBLAS_FP_INPUT_TYPE
300: #undef PETSC_CUPMBLAS_FP_RETURN_TYPE
301: #endif // PetscDefined(HAVE_CUDA)
303: #if PetscDefined(HAVE_HIP)
304: #define PETSC_CUPMBLAS_PREFIX hipblas
305: #define PETSC_CUPMBLAS_PREFIX_U HIPBLAS
306: #define PETSC_CUPMBLAS_FP_TYPE PETSC_CUPMBLAS_FP_TYPE_U
307: #define PETSC_CUPMBLAS_FP_INPUT_TYPE PETSC_CUPMBLAS_FP_INPUT_TYPE_U
308: #define PETSC_CUPMBLAS_FP_RETURN_TYPE PETSC_CUPMBLAS_FP_RETURN_TYPE_L
309: template <>
310: struct BlasInterfaceImpl<DeviceType::HIP> : BlasInterfaceBase<DeviceType::HIP> {
311: PETSC_CUPMBLAS_BASE_CLASS_HEADER(DeviceType::HIP);
313: // typedefs
314: using cupmBlasHandle_t = hipblasHandle_t;
315: using cupmBlasError_t = hipblasStatus_t;
316: using cupmBlasInt_t = int; // rocblas will have its own
317: using cupmSolverHandle_t = hipsolverHandle_t;
318: using cupmSolverError_t = hipsolverStatus_t;
319: using cupmBlasPointerMode_t = hipblasPointerMode_t;
321: // values
322: static const auto CUPMBLAS_STATUS_SUCCESS = HIPBLAS_STATUS_SUCCESS;
323: static const auto CUPMBLAS_STATUS_NOT_INITIALIZED = HIPBLAS_STATUS_NOT_INITIALIZED;
324: static const auto CUPMBLAS_STATUS_ALLOC_FAILED = HIPBLAS_STATUS_ALLOC_FAILED;
325: static const auto CUPMBLAS_POINTER_MODE_HOST = HIPBLAS_POINTER_MODE_HOST;
326: static const auto CUPMBLAS_POINTER_MODE_DEVICE = HIPBLAS_POINTER_MODE_DEVICE;
328: // utility functions
329: PETSC_CUPMBLAS_ALIAS_FUNCTION(Create)
330: PETSC_CUPMBLAS_ALIAS_FUNCTION(Destroy)
331: PETSC_CUPMBLAS_ALIAS_FUNCTION(GetStream)
332: PETSC_CUPMBLAS_ALIAS_FUNCTION(SetStream)
333: PETSC_CUPMBLAS_ALIAS_FUNCTION(GetPointerMode)
334: PETSC_CUPMBLAS_ALIAS_FUNCTION(SetPointerMode)
336: // level 1 BLAS
337: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, axpy)
338: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, scal)
339: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dot, PetscIfPetscDefined(USE_COMPLEX, dotc, dot))
340: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dotu, PetscIfPetscDefined(USE_COMPLEX, dotu, dot))
341: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, swap)
342: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, nrm2)
343: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(IFPTYPE, amax)
344: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, asum)
346: // level 2 BLAS
347: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemv)
349: // level 3 BLAS
350: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemm)
352: // BLAS extensions
353: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, geam)
355: PETSC_NODISCARD static PetscErrorCode InitializeHandle(cupmSolverHandle_t &handle) noexcept
356: {
357: if (!handle) hipsolverCreate(&handle);
358: return 0;
359: }
361: PETSC_NODISCARD static PetscErrorCode SetHandleStream(cupmSolverHandle_t handle, cupmStream_t stream) noexcept
362: {
363: hipsolverSetStream(handle, stream);
364: return 0;
365: }
367: PETSC_NODISCARD static PetscErrorCode DestroyHandle(cupmSolverHandle_t &handle) noexcept
368: {
369: if (handle) {
370: hipsolverDestroy(handle);
371: handle = nullptr;
372: }
373: return 0;
374: }
375: };
376: #undef PETSC_CUPMBLAS_PREFIX
377: #undef PETSC_CUPMBLAS_PREFIX_U
378: #undef PETSC_CUPMBLAS_FP_TYPE
379: #undef PETSC_CUPMBLAS_FP_INPUT_TYPE
380: #undef PETSC_CUPMBLAS_FP_RETURN_TYPE
381: #endif // PetscDefined(HAVE_HIP)
// the base-class boilerplate is only needed by the BlasInterfaceImpl<> specializations
#undef PETSC_CUPMBLAS_BASE_CLASS_HEADER

// PETSC_CUPMBLAS_IMPL_CLASS_HEADER() - Re-export everything BlasInterfaceImpl<T> provides.
//
// input params:
// base_name - name given to the BlasInterfaceImpl<T> alias inside the class
// T         - the DeviceType
//
// notes:
// Inside a class template deriving from BlasInterfaceImpl<T> (a dependent base), the base's
// members are not found by unqualified lookup; these using-declarations make them visible.
// The final declaration carries no semicolon; the caller supplies it.
#define PETSC_CUPMBLAS_IMPL_CLASS_HEADER(base_name, T) \
  PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(cupmInterface_t, T); \
  using base_name = ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>; \
  /* types */ \
  using cupmBlasHandle_t      = typename base_name::cupmBlasHandle_t; \
  using cupmBlasError_t       = typename base_name::cupmBlasError_t; \
  using cupmBlasInt_t         = typename base_name::cupmBlasInt_t; \
  using cupmSolverHandle_t    = typename base_name::cupmSolverHandle_t; \
  using cupmSolverError_t     = typename base_name::cupmSolverError_t; \
  using cupmBlasPointerMode_t = typename base_name::cupmBlasPointerMode_t; \
  /* values */ \
  using base_name::CUPMBLAS_STATUS_SUCCESS; \
  using base_name::CUPMBLAS_STATUS_NOT_INITIALIZED; \
  using base_name::CUPMBLAS_STATUS_ALLOC_FAILED; \
  using base_name::CUPMBLAS_POINTER_MODE_HOST; \
  using base_name::CUPMBLAS_POINTER_MODE_DEVICE; \
  /* introspection */ \
  using base_name::cupmBlasName; \
  using base_name::cupmBlasGetErrorName; \
  /* utility functions */ \
  using base_name::cupmBlasCreate; \
  using base_name::cupmBlasDestroy; \
  using base_name::cupmBlasGetStream; \
  using base_name::cupmBlasSetStream; \
  using base_name::cupmBlasGetPointerMode; \
  using base_name::cupmBlasSetPointerMode; \
  /* level 1 BLAS */ \
  using base_name::cupmBlasXaxpy; \
  using base_name::cupmBlasXscal; \
  using base_name::cupmBlasXdot; \
  using base_name::cupmBlasXdotu; \
  using base_name::cupmBlasXswap; \
  using base_name::cupmBlasXnrm2; \
  using base_name::cupmBlasXamax; \
  using base_name::cupmBlasXasum; \
  /* level 2 BLAS */ \
  using base_name::cupmBlasXgemv; \
  /* level 3 BLAS */ \
  using base_name::cupmBlasXgemm; \
  /* BLAS extensions */ \
  using base_name::cupmBlasXgeam
427: // The actual interface class
428: template <DeviceType T>
429: struct BlasInterface : BlasInterfaceImpl<T> {
430: PETSC_CUPMBLAS_IMPL_CLASS_HEADER(blasinterface_type, T);
432: PETSC_NODISCARD static PetscErrorCode PetscCUPMBlasSetPointerModeFromPointer(cupmBlasHandle_t handle, const void *ptr) noexcept
433: {
434: auto mtype = PETSC_MEMTYPE_HOST;
436: PetscCUPMGetMemType(ptr, &mtype);
437: cupmBlasSetPointerMode(handle, PetscMemTypeDevice(mtype) ? CUPMBLAS_POINTER_MODE_DEVICE : CUPMBLAS_POINTER_MODE_HOST);
438: return 0;
439: }
440: };
442: #define PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING(base_name, T) \
443: PETSC_CUPMBLAS_IMPL_CLASS_HEADER(PetscConcat(base_name, _impl), T); \
444: using base_name = ::Petsc::device::cupm::impl::BlasInterface<T>; \
445: using base_name::PetscCUPMBlasSetPointerModeFromPointer
447: #if PetscDefined(HAVE_CUDA)
448: extern template struct BlasInterface<DeviceType::CUDA>;
449: #endif
451: #if PetscDefined(HAVE_HIP)
452: extern template struct BlasInterface<DeviceType::HIP>;
453: #endif
455: } // namespace impl
457: } // namespace cupm
459: } // namespace device
461: } // namespace Petsc
463: #endif // defined(__cplusplus)
465: #endif // PETSCCUPMBLASINTERFACE_HPP