Actual source code: kinit.kokkos.cxx
1: #include <petsc/private/deviceimpl.h>
2: #include <petscpkg_version.h>
3: #include <Kokkos_Core.hpp>
/* Flag recording that Kokkos is known to be initialized for PETSc's use.
   Starts as PETSC_FALSE. In this file it is set to PETSC_TRUE only by
   PetscKokkosInitializeCheck(), regardless of whether PETSc or the caller
   initialized Kokkos. */
PetscBool PetscKokkosInitialized = PETSC_FALSE;
7: PetscErrorCode PetscKokkosFinalize_Private(void)
8: {
9: Kokkos::finalize();
10: return 0;
11: }
13: PetscErrorCode PetscKokkosIsInitialized_Private(PetscBool *isInitialized)
14: {
15: *isInitialized = Kokkos::is_initialized() ? PETSC_TRUE : PETSC_FALSE;
16: return 0;
17: }
19: /* Initialize Kokkos if not yet */
20: PetscErrorCode PetscKokkosInitializeCheck(void)
21: {
22: if (!Kokkos::is_initialized()) {
23: #if PETSC_PKG_KOKKOS_VERSION_GE(3, 6, 99)
24: auto args = Kokkos::InitializationSettings();
25: #else
26: auto args = Kokkos::InitArguments{}; /* use default constructor */
27: #endif
29: #if (defined(KOKKOS_ENABLE_CUDA) && PetscDefined(HAVE_CUDA)) || (defined(KOKKOS_ENABLE_HIP) && PetscDefined(HAVE_HIP)) || (defined(KOKKOS_ENABLE_SYCL) && PetscDefined(HAVE_SYCL))
30: /* Kokkos does not support CUDA and HIP at the same time (but we do :)) */
31: PetscDeviceContext dctx;
33: PetscDeviceContextGetCurrentContext(&dctx);
34: #if PETSC_PKG_KOKKOS_VERSION_GE(3, 6, 99)
35: args.set_device_id(static_cast<int>(dctx->device->deviceId));
36: #else
37: PetscMPIIntCast(dctx->device->deviceId, &args.device_id);
38: #endif
39: #endif
41: #if PETSC_PKG_KOKKOS_VERSION_GE(3, 6, 99)
42: args.set_disable_warnings(!PetscDefined(HAVE_KOKKOS_INIT_WARNINGS));
43: #else
44: args.disable_warnings = !PetscDefined(HAVE_KOKKOS_INIT_WARNINGS);
45: #endif
47: /* To use PetscNumOMPThreads, one has to configure petsc --with-openmp.
48: Otherwise, let's keep the default value (-1) of args.num_threads.
49: */
50: #if defined(KOKKOS_ENABLE_OPENMP) && PetscDefined(HAVE_OPENMP)
51: #if PETSC_PKG_KOKKOS_VERSION_GE(3, 6, 99)
52: args.set_num_threads(PetscNumOMPThreads);
53: #else
54: args.num_threads = PetscNumOMPThreads;
55: #endif
56: #endif
58: Kokkos::initialize(args);
59: PetscBeganKokkos = PETSC_TRUE;
60: }
61: PetscKokkosInitialized = PETSC_TRUE;
62: return 0;
63: }