Actual source code: cupmblasinterface.hpp

  1: #ifndef PETSCCUPMBLASINTERFACE_HPP
  2: #define PETSCCUPMBLASINTERFACE_HPP

  4: #if defined(__cplusplus)
  5: #include <petsc/private/cupminterface.hpp>
  6: #include <petsc/private/petscadvancedmacros.h>

  8: namespace Petsc
  9: {

 11: namespace device
 12: {

 14: namespace cupm
 15: {

 17: namespace impl
 18: {

 20:   #define PetscCallCUPMBLAS(...) \
 21:     do { \
 22:       const cupmBlasError_t cberr_p_ = __VA_ARGS__; \
 23:       if (PetscUnlikely(cberr_p_ != CUPMBLAS_STATUS_SUCCESS)) { \
 24:         if (((cberr_p_ == CUPMBLAS_STATUS_NOT_INITIALIZED) || (cberr_p_ == CUPMBLAS_STATUS_ALLOC_FAILED)) && PetscDeviceInitialized(PETSC_DEVICE_CUPM())) { \
 25:           SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, \
 26:                   "%s error %d (%s). Reports not initialized or alloc failed; " \
 27:                   "this indicates the GPU may have run out resources", \
 28:                   cupmBlasName(), static_cast<PetscErrorCode>(cberr_p_), cupmBlasGetErrorName(cberr_p_)); \
 29:         } \
 30:         SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU, "%s error %d (%s)", cupmBlasName(), static_cast<PetscErrorCode>(cberr_p_), cupmBlasGetErrorName(cberr_p_)); \
 31:       } \
 32:     } while (0)

 34:   // given cupmBlas<T>axpy() then
 35:   // T = PETSC_CUPBLAS_FP_TYPE
 36:   // given cupmBlas<T><u>nrm2() then
 37:   // T = PETSC_CUPMBLAS_FP_INPUT_TYPE
 38:   // u = PETSC_CUPMBLAS_FP_RETURN_TYPE
 39:   #if PetscDefined(USE_COMPLEX)
 40:     #if PetscDefined(USE_REAL_SINGLE)
 41:       #define PETSC_CUPMBLAS_FP_TYPE_U       C
 42:       #define PETSC_CUPMBLAS_FP_TYPE_L       c
 43:       #define PETSC_CUPMBLAS_FP_INPUT_TYPE_U S
 44:       #define PETSC_CUPMBLAS_FP_INPUT_TYPE_L s
 45:     #elif PetscDefined(USE_REAL_DOUBLE)
 46:       #define PETSC_CUPMBLAS_FP_TYPE_U       Z
 47:       #define PETSC_CUPMBLAS_FP_TYPE_L       z
 48:       #define PETSC_CUPMBLAS_FP_INPUT_TYPE_U D
 49:       #define PETSC_CUPMBLAS_FP_INPUT_TYPE_L d
 50:     #endif
 51:     #define PETSC_CUPMBLAS_FP_RETURN_TYPE_U PETSC_CUPMBLAS_FP_TYPE_U
 52:     #define PETSC_CUPMBLAS_FP_RETURN_TYPE_L PETSC_CUPMBLAS_FP_TYPE_L
 53:   #else
 54:     #if PetscDefined(USE_REAL_SINGLE)
 55:       #define PETSC_CUPMBLAS_FP_TYPE_U S
 56:       #define PETSC_CUPMBLAS_FP_TYPE_L s
 57:     #elif PetscDefined(USE_REAL_DOUBLE)
 58:       #define PETSC_CUPMBLAS_FP_TYPE_U D
 59:       #define PETSC_CUPMBLAS_FP_TYPE_L d
 60:     #endif
 61:     #define PETSC_CUPMBLAS_FP_INPUT_TYPE_U PETSC_CUPMBLAS_FP_TYPE_U
 62:     #define PETSC_CUPMBLAS_FP_INPUT_TYPE_L PETSC_CUPMBLAS_FP_TYPE_L
 63:     #define PETSC_CUPMBLAS_FP_RETURN_TYPE_U
 64:     #define PETSC_CUPMBLAS_FP_RETURN_TYPE_L
 65:   #endif // USE_COMPLEX

 67:   #if !defined(PETSC_CUPMBLAS_FP_TYPE_U) && !PetscDefined(USE_REAL___FLOAT128)
 68:     #error "Unsupported floating-point type for CUDA/HIP BLAS"
 69:   #endif

 71:   // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED() - Helper macro to build a "modified"
 72:   // blas function whose return type does not match the input type
 73:   //
 74:   // input param:
 75:   // func - base suffix of the blas function, e.g. nrm2
 76:   //
 77:   // notes:
 78:   // requires PETSC_CUPMBLAS_FP_INPUT_TYPE to be defined as the blas floating point input type
 79:   // letter ("S" for real/complex single, "D" for real/complex double).
 80:   //
 81:   // requires PETSC_CUPMBLAS_FP_RETURN_TYPE to be defined as the blas floating point output type
 82:   // letter ("c" for complex single, "z" for complex double and <absolutely nothing> for real
 83:   // single/double).
 84:   //
 85:   // In their infinite wisdom nvidia/amd have made the upper-case vs lower-case scheme
 86:   // infuriatingly inconsistent...
 87:   //
 88:   // example usage:
 89:   // #define PETSC_CUPMBLAS_FP_INPUT_TYPE  S
 90:   // #define PETSC_CUPMBLAS_FP_RETURN_TYPE
 91:   // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED(nrm2) -> Snrm2
 92:   //
 93:   // #define PETSC_CUPMBLAS_FP_INPUT_TYPE  D
 94:   // #define PETSC_CUPMBLAS_FP_RETURN_TYPE z
 95:   // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED(nrm2) -> Dznrm2
 96:   #define PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED(func) PetscConcat(PetscConcat(PETSC_CUPMBLAS_FP_INPUT_TYPE, PETSC_CUPMBLAS_FP_RETURN_TYPE), func)

 98:   // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE() - Helper macro to build Iamax and Iamin
 99:   // because they are both extra special
100:   //
101:   // input param:
102:   // func - base suffix of the blas function, either amax or amin
103:   //
104:   // notes:
105:   // The macro name literally stands for "I" ## "floating point type" because shockingly enough,
106:   // that's what it does.
107:   //
108:   // requires PETSC_CUPMBLAS_FP_TYPE_L to be defined as the lower-case blas floating point input type
109:   // letter ("s" for complex single, "z" for complex double, "s" for real single, and "d" for
110:   // real double).
111:   //
112:   // example usage:
113:   // #define PETSC_CUPMBLAS_FP_TYPE_L s
114:   // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE(amax) -> Isamax
115:   //
116:   // #define PETSC_CUPMBLAS_FP_TYPE_L z
117:   // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE(amin) -> Izamin
118:   #define PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE(func) PetscConcat(I, PetscConcat(PETSC_CUPMBLAS_FP_TYPE_L, func))

120:   // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD() - Helper macro to build a "standard"
121:   // blas function name
122:   //
123:   // input param:
124:   // func - base suffix of the blas function, e.g. axpy, scal
125:   //
126:   // notes:
127:   // requires PETSC_CUPMBLAS_FP_TYPE to be defined as the blas floating-point letter ("C" for
128:   // complex single, "Z" for complex double, "S" for real single, "D" for real double).
129:   //
130:   // example usage:
131:   // #define PETSC_CUPMBLAS_FP_TYPE S
132:   // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD(axpy) -> Saxpy
133:   //
134:   // #define PETSC_CUPMBLAS_FP_TYPE Z
135:   // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD(axpy) -> Zaxpy
136:   #define PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD(func) PetscConcat(PETSC_CUPMBLAS_FP_TYPE, func)

138:   // PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT() - In case CUDA/HIP don't agree with our suffix
139:   // one can provide both here
140:   //
141:   // input params:
142:   // MACRO_SUFFIX - suffix to one of the above blas function builder macros, e.g. STANDARD or
143:   // IFPTYPE
144:   // our_suffix   - the suffix of the alias function
145:   // their_suffix - the suffix of the function being aliased
146:   //
147:   // notes:
148:   // requires PETSC_CUPMBLAS_PREFIX to be defined as the specific CUDA/HIP blas function
149:   // prefix. requires any other specific definitions required by the specific builder macro to
150:   // also be defined. See PETSC_CUPM_ALIAS_FUNCTION_EXACT() for the exact expansion of the
151:   // function alias.
152:   //
153:   // example usage:
154:   // #define PETSC_CUPMBLAS_PREFIX  cublas
155:   // #define PETSC_CUPMBLAS_FP_TYPE C
156:   // PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD,dot,dotc) ->
157:   // template <typename... T>
158:   // static constexpr auto cupmBlasXdot(T&&... args) *noexcept and returntype detection*
159:   // {
160:   //   return cublasCdotc(std::forward<T>(args)...);
161:   // }
162:   #define PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(MACRO_SUFFIX, our_suffix, their_suffix) \
163:     PETSC_CUPM_ALIAS_FUNCTION(PetscConcat(cupmBlasX, our_suffix), PetscConcat(PETSC_CUPMBLAS_PREFIX, PetscConcat(PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_, MACRO_SUFFIX)(their_suffix)))

165:   // PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION() - Alias a CUDA/HIP blas function
166:   //
167:   // input params:
168:   // MACRO_SUFFIX - suffix to one of the above blas function builder macros, e.g. STANDARD or
169:   // IFPTYPE
170:   // suffix       - the common suffix between CUDA and HIP of the alias function
171:   //
172:   // notes:
173:   // see PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(), this macro just calls that one with "suffix" as
174:   // "our_prefix" and "their_prefix"
175:   #define PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MACRO_SUFFIX, suffix) PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(MACRO_SUFFIX, suffix, suffix)

177:   // PETSC_CUPMBLAS_ALIAS_FUNCTION() - Alias a CUDA/HIP library function
178:   //
179:   // input params:
180:   // suffix - the common suffix between CUDA and HIP of the alias function
181:   //
182:   // notes:
183:   // requires PETSC_CUPMBLAS_PREFIX to be defined as the specific CUDA/HIP blas library
184:   // prefix. see PETSC_CUPMM_ALIAS_FUNCTION_EXACT() for the precise expansion of this macro.
185:   //
186:   // example usage:
187:   // #define PETSC_CUPMBLAS_PREFIX hipblas
188:   // PETSC_CUPMBLAS_ALIAS_FUNCTION(Create) ->
189:   // template <typename... T>
190:   // static constexpr auto cupmBlasCreate(T&&... args) *noexcept and returntype detection*
191:   // {
192:   //   return hipblasCreate(std::forward<T>(args)...);
193:   // }
194:   #define PETSC_CUPMBLAS_ALIAS_FUNCTION(suffix) PETSC_CUPM_ALIAS_FUNCTION(PetscConcat(cupmBlas, suffix), PetscConcat(PETSC_CUPMBLAS_PREFIX, suffix))

196: template <DeviceType T>
197: struct BlasInterfaceBase : Interface<T> {
198:   PETSC_NODISCARD static constexpr const char *cupmBlasName() noexcept { return T == DeviceType::CUDA ? "cuBLAS" : "hipBLAS"; }
199: };

201:   #define PETSC_CUPMBLAS_BASE_CLASS_HEADER(DEV_TYPE) \
202:     using base_type = ::Petsc::device::cupm::impl::BlasInterfaceBase<DEV_TYPE>; \
203:     using base_type::cupmBlasName; \
204:     PETSC_CUPM_ALIAS_FUNCTION(cupmBlasGetErrorName, PetscConcat(PetscConcat(Petsc, PETSC_CUPMBLAS_PREFIX_U), GetErrorName)) \
205:     PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(interface_type, DEV_TYPE)

207: template <DeviceType>
208: struct BlasInterfaceImpl;

210:   #if PetscDefined(HAVE_CUDA)
211:     #define PETSC_CUPMBLAS_PREFIX         cublas
212:     #define PETSC_CUPMBLAS_PREFIX_U       CUBLAS
213:     #define PETSC_CUPMBLAS_FP_TYPE        PETSC_CUPMBLAS_FP_TYPE_U
214:     #define PETSC_CUPMBLAS_FP_INPUT_TYPE  PETSC_CUPMBLAS_FP_INPUT_TYPE_U
215:     #define PETSC_CUPMBLAS_FP_RETURN_TYPE PETSC_CUPMBLAS_FP_RETURN_TYPE_L
216: template <>
217: struct BlasInterfaceImpl<DeviceType::CUDA> : BlasInterfaceBase<DeviceType::CUDA> {
218:   PETSC_CUPMBLAS_BASE_CLASS_HEADER(DeviceType::CUDA);

220:   // typedefs
221:   using cupmBlasHandle_t      = cublasHandle_t;
222:   using cupmBlasError_t       = cublasStatus_t;
223:   using cupmBlasInt_t         = int;
224:   using cupmSolverHandle_t    = cusolverDnHandle_t;
225:   using cupmSolverError_t     = cusolverStatus_t;
226:   using cupmBlasPointerMode_t = cublasPointerMode_t;

228:   // values
229:   static const auto CUPMBLAS_STATUS_SUCCESS         = CUBLAS_STATUS_SUCCESS;
230:   static const auto CUPMBLAS_STATUS_NOT_INITIALIZED = CUBLAS_STATUS_NOT_INITIALIZED;
231:   static const auto CUPMBLAS_STATUS_ALLOC_FAILED    = CUBLAS_STATUS_ALLOC_FAILED;
232:   static const auto CUPMBLAS_POINTER_MODE_HOST      = CUBLAS_POINTER_MODE_HOST;
233:   static const auto CUPMBLAS_POINTER_MODE_DEVICE    = CUBLAS_POINTER_MODE_DEVICE;

235:   // utility functions
236:   PETSC_CUPMBLAS_ALIAS_FUNCTION(Create)
237:   PETSC_CUPMBLAS_ALIAS_FUNCTION(Destroy)
238:   PETSC_CUPMBLAS_ALIAS_FUNCTION(GetStream)
239:   PETSC_CUPMBLAS_ALIAS_FUNCTION(SetStream)
240:   PETSC_CUPMBLAS_ALIAS_FUNCTION(GetPointerMode)
241:   PETSC_CUPMBLAS_ALIAS_FUNCTION(SetPointerMode)

243:   // level 1 BLAS
244:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, axpy)
245:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, scal)
246:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dot, PetscIfPetscDefined(USE_COMPLEX, dotc, dot))
247:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dotu, PetscIfPetscDefined(USE_COMPLEX, dotu, dot))
248:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, swap)
249:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, nrm2)
250:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(IFPTYPE, amax)
251:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, asum)

253:   // level 2 BLAS
254:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemv)

256:   // level 3 BLAS
257:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemm)

259:   // BLAS extensions
260:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, geam)

262:   PETSC_NODISCARD static PetscErrorCode InitializeHandle(cupmSolverHandle_t &handle) noexcept
263:   {
264:     if (handle) return 0;
265:     for (auto i = 0; i < 3; ++i) {
266:       const auto cerr = cusolverDnCreate(&handle);
267:       if (PetscLikely(cerr == CUSOLVER_STATUS_SUCCESS)) break;
268:       if ((cerr != CUSOLVER_STATUS_NOT_INITIALIZED) && (cerr != CUSOLVER_STATUS_ALLOC_FAILED)) cerr;
269:       if (i < 2) {
270:         PetscSleep(3);
271:         continue;
272:       }
274:     }
275:     return 0;
276:   }

278:   PETSC_NODISCARD static PetscErrorCode SetHandleStream(const cupmSolverHandle_t &handle, const cupmStream_t &stream) noexcept
279:   {
280:     cupmStream_t cupmStream;

282:     cusolverDnGetStream(handle, &cupmStream);
283:     if (cupmStream != stream) cusolverDnSetStream(handle, stream);
284:     return 0;
285:   }

287:   PETSC_NODISCARD static PetscErrorCode DestroyHandle(cupmSolverHandle_t &handle) noexcept
288:   {
289:     if (handle) {
290:       cusolverDnDestroy(handle);
291:       handle = nullptr;
292:     }
293:     return 0;
294:   }
295: };
296:     #undef PETSC_CUPMBLAS_PREFIX
297:     #undef PETSC_CUPMBLAS_PREFIX_U
298:     #undef PETSC_CUPMBLAS_FP_TYPE
299:     #undef PETSC_CUPMBLAS_FP_INPUT_TYPE
300:     #undef PETSC_CUPMBLAS_FP_RETURN_TYPE
301:   #endif // PetscDefined(HAVE_CUDA)

303:   #if PetscDefined(HAVE_HIP)
304:     #define PETSC_CUPMBLAS_PREFIX         hipblas
305:     #define PETSC_CUPMBLAS_PREFIX_U       HIPBLAS
306:     #define PETSC_CUPMBLAS_FP_TYPE        PETSC_CUPMBLAS_FP_TYPE_U
307:     #define PETSC_CUPMBLAS_FP_INPUT_TYPE  PETSC_CUPMBLAS_FP_INPUT_TYPE_U
308:     #define PETSC_CUPMBLAS_FP_RETURN_TYPE PETSC_CUPMBLAS_FP_RETURN_TYPE_L
309: template <>
310: struct BlasInterfaceImpl<DeviceType::HIP> : BlasInterfaceBase<DeviceType::HIP> {
311:   PETSC_CUPMBLAS_BASE_CLASS_HEADER(DeviceType::HIP);

313:   // typedefs
314:   using cupmBlasHandle_t      = hipblasHandle_t;
315:   using cupmBlasError_t       = hipblasStatus_t;
316:   using cupmBlasInt_t         = int; // rocblas will have its own
317:   using cupmSolverHandle_t    = hipsolverHandle_t;
318:   using cupmSolverError_t     = hipsolverStatus_t;
319:   using cupmBlasPointerMode_t = hipblasPointerMode_t;

321:   // values
322:   static const auto CUPMBLAS_STATUS_SUCCESS         = HIPBLAS_STATUS_SUCCESS;
323:   static const auto CUPMBLAS_STATUS_NOT_INITIALIZED = HIPBLAS_STATUS_NOT_INITIALIZED;
324:   static const auto CUPMBLAS_STATUS_ALLOC_FAILED    = HIPBLAS_STATUS_ALLOC_FAILED;
325:   static const auto CUPMBLAS_POINTER_MODE_HOST      = HIPBLAS_POINTER_MODE_HOST;
326:   static const auto CUPMBLAS_POINTER_MODE_DEVICE    = HIPBLAS_POINTER_MODE_DEVICE;

328:   // utility functions
329:   PETSC_CUPMBLAS_ALIAS_FUNCTION(Create)
330:   PETSC_CUPMBLAS_ALIAS_FUNCTION(Destroy)
331:   PETSC_CUPMBLAS_ALIAS_FUNCTION(GetStream)
332:   PETSC_CUPMBLAS_ALIAS_FUNCTION(SetStream)
333:   PETSC_CUPMBLAS_ALIAS_FUNCTION(GetPointerMode)
334:   PETSC_CUPMBLAS_ALIAS_FUNCTION(SetPointerMode)

336:   // level 1 BLAS
337:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, axpy)
338:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, scal)
339:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dot, PetscIfPetscDefined(USE_COMPLEX, dotc, dot))
340:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dotu, PetscIfPetscDefined(USE_COMPLEX, dotu, dot))
341:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, swap)
342:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, nrm2)
343:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(IFPTYPE, amax)
344:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, asum)

346:   // level 2 BLAS
347:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemv)

349:   // level 3 BLAS
350:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemm)

352:   // BLAS extensions
353:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, geam)

355:   PETSC_NODISCARD static PetscErrorCode InitializeHandle(cupmSolverHandle_t &handle) noexcept
356:   {
357:     if (!handle) hipsolverCreate(&handle);
358:     return 0;
359:   }

361:   PETSC_NODISCARD static PetscErrorCode SetHandleStream(cupmSolverHandle_t handle, cupmStream_t stream) noexcept
362:   {
363:     hipsolverSetStream(handle, stream);
364:     return 0;
365:   }

367:   PETSC_NODISCARD static PetscErrorCode DestroyHandle(cupmSolverHandle_t &handle) noexcept
368:   {
369:     if (handle) {
370:       hipsolverDestroy(handle);
371:       handle = nullptr;
372:     }
373:     return 0;
374:   }
375: };
376:     #undef PETSC_CUPMBLAS_PREFIX
377:     #undef PETSC_CUPMBLAS_PREFIX_U
378:     #undef PETSC_CUPMBLAS_FP_TYPE
379:     #undef PETSC_CUPMBLAS_FP_INPUT_TYPE
380:     #undef PETSC_CUPMBLAS_FP_RETURN_TYPE
381:   #endif // PetscDefined(HAVE_HIP)

383:   #undef PETSC_CUPMBLAS_BASE_CLASS_HEADER

385:   #define PETSC_CUPMBLAS_IMPL_CLASS_HEADER(base_name, T) \
386:     PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(cupmInterface_t, T); \
387:     using base_name = ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>; \
388:     /* introspection */ \
389:     using base_name::cupmBlasName; \
390:     using base_name::cupmBlasGetErrorName; \
391:     /* types */ \
392:     using cupmBlasHandle_t      = typename base_name::cupmBlasHandle_t; \
393:     using cupmBlasError_t       = typename base_name::cupmBlasError_t; \
394:     using cupmBlasInt_t         = typename base_name::cupmBlasInt_t; \
395:     using cupmSolverHandle_t    = typename base_name::cupmSolverHandle_t; \
396:     using cupmSolverError_t     = typename base_name::cupmSolverError_t; \
397:     using cupmBlasPointerMode_t = typename base_name::cupmBlasPointerMode_t; \
398:     /* values */ \
399:     using base_name::CUPMBLAS_STATUS_SUCCESS; \
400:     using base_name::CUPMBLAS_STATUS_NOT_INITIALIZED; \
401:     using base_name::CUPMBLAS_STATUS_ALLOC_FAILED; \
402:     using base_name::CUPMBLAS_POINTER_MODE_HOST; \
403:     using base_name::CUPMBLAS_POINTER_MODE_DEVICE; \
404:     /* utility functions */ \
405:     using base_name::cupmBlasCreate; \
406:     using base_name::cupmBlasDestroy; \
407:     using base_name::cupmBlasGetStream; \
408:     using base_name::cupmBlasSetStream; \
409:     using base_name::cupmBlasGetPointerMode; \
410:     using base_name::cupmBlasSetPointerMode; \
411:     /* level 1 BLAS */ \
412:     using base_name::cupmBlasXaxpy; \
413:     using base_name::cupmBlasXscal; \
414:     using base_name::cupmBlasXdot; \
415:     using base_name::cupmBlasXdotu; \
416:     using base_name::cupmBlasXswap; \
417:     using base_name::cupmBlasXnrm2; \
418:     using base_name::cupmBlasXamax; \
419:     using base_name::cupmBlasXasum; \
420:     /* level 2 BLAS */ \
421:     using base_name::cupmBlasXgemv; \
422:     /* level 3 BLAS */ \
423:     using base_name::cupmBlasXgemm; \
424:     /* BLAS extensions */ \
425:     using base_name::cupmBlasXgeam

427: // The actual interface class
428: template <DeviceType T>
429: struct BlasInterface : BlasInterfaceImpl<T> {
430:   PETSC_CUPMBLAS_IMPL_CLASS_HEADER(blasinterface_type, T);

432:   PETSC_NODISCARD static PetscErrorCode PetscCUPMBlasSetPointerModeFromPointer(cupmBlasHandle_t handle, const void *ptr) noexcept
433:   {
434:     auto mtype = PETSC_MEMTYPE_HOST;

436:     PetscCUPMGetMemType(ptr, &mtype);
437:     cupmBlasSetPointerMode(handle, PetscMemTypeDevice(mtype) ? CUPMBLAS_POINTER_MODE_DEVICE : CUPMBLAS_POINTER_MODE_HOST);
438:     return 0;
439:   }
440: };

442:   #define PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING(base_name, T) \
443:     PETSC_CUPMBLAS_IMPL_CLASS_HEADER(PetscConcat(base_name, _impl), T); \
444:     using base_name = ::Petsc::device::cupm::impl::BlasInterface<T>; \
445:     using base_name::PetscCUPMBlasSetPointerModeFromPointer

447:   #if PetscDefined(HAVE_CUDA)
448: extern template struct BlasInterface<DeviceType::CUDA>;
449:   #endif

451:   #if PetscDefined(HAVE_HIP)
452: extern template struct BlasInterface<DeviceType::HIP>;
453:   #endif

455: } // namespace impl

457: } // namespace cupm

459: } // namespace device

461: } // namespace Petsc

463: #endif // defined(__cplusplus)

465: #endif // PETSCCUPMBLASINTERFACE_HPP