Actual source code: bjkokkos.kokkos.cxx

  1: #include <petscvec_kokkos.hpp>
  2: #include <petsc/private/deviceimpl.h>
  3: #include <petsc/private/pcimpl.h>
  4: #include <petsc/private/kspimpl.h>
  5: #include <petscksp.h>
  6: #include "petscsection.h"
  7: #include <petscdmcomposite.h>
  8: #include "Kokkos_Core.hpp"

 10: #include <../src/mat/impls/aij/seq/aij.h>
 11: #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>

 13: #if defined(PETSC_HAVE_CUDA)
 14:   #include <nvToolsExt.h>
 15: #endif

 17: #include <petscdevice_cupm.h>

 19: #define PCBJKOKKOS_SHARED_LEVEL 0 // 0 is shared, 1 is global
 20: #define PCBJKOKKOS_VEC_SIZE     16
 21: #define PCBJKOKKOS_TEAM_SIZE    16
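// Note: these constants size the Kokkos::TeamPolicy used in PCApply_BJKOKKOS below
// (one team per diagonal block, PCBJKOKKOS_TEAM_SIZE threads by PCBJKOKKOS_VEC_SIZE
// vector lanes); on low-concurrency (host) builds the team size is forced to 1 there.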

 23: #define PCBJKOKKOS_VERBOSE_LEVEL 2

 25: typedef Kokkos::DefaultExecutionSpace exec_space;
 26: using layout           = Kokkos::LayoutRight;
 27: using IntView          = Kokkos::View<PetscInt **, layout, exec_space>;
 28: using AMatrixValueView = const Kokkos::View<PetscScalar **, layout, exec_space>;
 29: using XYType           = const Kokkos::View<PetscScalar **, layout, exec_space>;

 31: typedef enum {
 32:   BATCH_KSP_BICG_IDX,
 33:   BATCH_KSP_TFQMR_IDX,
 34:   BATCH_KSP_GMRES_IDX,
 35:   NUM_BATCH_TYPES
 36: } KSPIndex;
 37: typedef struct {
 38:   Vec                                               vec_diag;
 39:   PetscInt                                          nBlocks; /* total number of blocks */
 40:   PetscInt                                          n;       // cache host version of d_bid_eqOffset_k[nBlocks]
 41:   KSP                                               ksp;     // Used just for options. Should have one for each block
 42:   Kokkos::View<PetscInt *, Kokkos::LayoutRight>    *d_bid_eqOffset_k;
 43:   Kokkos::View<PetscScalar *, Kokkos::LayoutRight> *d_idiag_k;
 44:   Kokkos::View<PetscInt *>                         *d_isrow_k;
 45:   Kokkos::View<PetscInt *>                         *d_isicol_k;
 46:   KSPIndex                                          ksp_type_idx;
 47:   PetscInt                                          nwork;
 48:   PetscInt                                          const_block_size; // used to decide whether to use shared memory for work vectors
 49:   PetscInt                                         *dm_Nf;            // Number of fields in each DM
 50:   PetscInt                                          num_dms;
 51:   // diagnostics
 52:   PetscBool reason;
 53:   PetscBool monitor;
 54:   PetscInt  batch_target;
 55:   PetscInt  nsolves_team;
 56:   PetscInt  max_nits;
 57:   // caches
 58:   IntView          *rowOffsets;
 59:   IntView          *colIndices;
 60:   XYType           *batch_b;
 61:   XYType           *batch_x;
 62:   AMatrixValueView *batch_values;
 63: } PC_PCBJKOKKOS;
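// Note: the rowOffsets/colIndices/batch_* cache views above are used only by the
// Kokkos Kernels batched GMRES path; they are allocated on the first PCApply (see the
// "if (!jac->rowOffsets)" block below) and reused on subsequent solves.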

 65: #if defined(PETSC_HAVE_KOKKOS_KERNELS_GMRES)
 66:   #include <fstream>

 68:   #include "Kokkos_Timer.hpp"
 69:   #include "Kokkos_Random.hpp"
 70:   #include "Kokkos_UnorderedMap.hpp"
 71:   #include "Kokkos_Sort.hpp"

 73:   /// KokkosKernels headers
 74:   #include "KokkosBatched_Util.hpp"
 75:   #include "KokkosBatched_Vector.hpp"

 77:   #include <Kokkos_ArithTraits.hpp>
 78:   #include <KokkosBatched_Util.hpp>
 79:   #include <KokkosBatched_Vector.hpp>
 80:   #include <KokkosBatched_Copy_Decl.hpp>
 81:   #include <KokkosBatched_Copy_Impl.hpp>
 82:   #include <KokkosBatched_AddRadial_Decl.hpp>
 83:   #include <KokkosBatched_AddRadial_Impl.hpp>
 84:   #include <KokkosBatched_Gemm_Decl.hpp>
 85:   #include <KokkosBatched_Gemm_Serial_Impl.hpp>
 86:   #include <KokkosBatched_Gemm_Team_Impl.hpp>
 87:   #include <KokkosBatched_Gemv_Decl.hpp>
 88:   #include <KokkosBatched_Gemv_Serial_Impl.hpp>
 89:   #include <KokkosBatched_Gemv_Team_Impl.hpp>
 90:   #include <KokkosBatched_Trsm_Decl.hpp>
 91:   #include <KokkosBatched_Trsm_Serial_Impl.hpp>
 92:   #include <KokkosBatched_Trsm_Team_Impl.hpp>
 93:   #include <KokkosBatched_Trsv_Decl.hpp>
 94:   #include <KokkosBatched_Trsv_Serial_Impl.hpp>
 95:   #include <KokkosBatched_Trsv_Team_Impl.hpp>
 96:   #include <KokkosBatched_LU_Decl.hpp>
 97:   #include <KokkosBatched_LU_Serial_Impl.hpp>
 98:   #include <KokkosBatched_LU_Team_Impl.hpp>
 99:   #include <KokkosSparse_CrsMatrix.hpp>
100:   #include "KokkosBatched_Spmv.hpp"
101:   #include "KokkosBatched_CrsMatrix.hpp"
102:   #include "KokkosBatched_Krylov_Handle.hpp"
103:   #include "KokkosBatched_GMRES.hpp"
104:   #include "KokkosBatched_JacobiPrec.hpp"

106: template <typename DeviceType, typename ValuesViewType, typename IntView, typename VectorViewType, typename KrylovHandleType>
107: struct Functor_TestBatchedTeamVectorGMRES {
108:   const ValuesViewType _D;
109:   const ValuesViewType _diag;
110:   const IntView        _r;
111:   const IntView        _c;
112:   const VectorViewType _X;
113:   const VectorViewType _B;
114:   const int            _N_team, _team_size, _vector_length;
115:   const int            _N_iteration;
116:   const double         _tol;
117:   const int            _ortho_strategy;
118:   const int            _scratch_pad_level;
119:   KrylovHandleType     _handle;

121:   KOKKOS_INLINE_FUNCTION
122:   Functor_TestBatchedTeamVectorGMRES(const ValuesViewType &D, const IntView &r, const IntView &c, const VectorViewType &X, const VectorViewType &B, const int N_team, const int team_size, const int vector_length, const int N_iteration, const double tol, const int ortho_strategy, const int scratch_pad_level, KrylovHandleType &handle) :
123:     _D(D), _r(r), _c(c), _X(X), _B(B), _N_team(N_team), _team_size(team_size), _vector_length(vector_length), _N_iteration(N_iteration), _tol(tol), _ortho_strategy(ortho_strategy), _scratch_pad_level(scratch_pad_level), _handle(handle)
124:   {
125:   }

127:   KOKKOS_INLINE_FUNCTION
128:   Functor_TestBatchedTeamVectorGMRES(const ValuesViewType &D, const ValuesViewType &diag, const IntView &r, const IntView &c, const VectorViewType &X, const VectorViewType &B, const int N_team, const int team_size, const int vector_length, const int N_iteration, const double tol, int ortho_strategy, const int scratch_pad_level, KrylovHandleType &handle) :
129:     _D(D), _diag(diag), _r(r), _c(c), _X(X), _B(B), _N_team(N_team), _team_size(team_size), _vector_length(vector_length), _N_iteration(N_iteration), _tol(tol), _ortho_strategy(ortho_strategy), _scratch_pad_level(scratch_pad_level), _handle(handle)
130:   {
131:   }

133:   template <typename MemberType>
134:   KOKKOS_INLINE_FUNCTION void operator()(const MemberType &member) const
135:   {
136:     const int first_matrix = static_cast<int>(member.league_rank()) * _N_team;
137:     const int N            = _D.extent(0);
138:     const int last_matrix  = (static_cast<int>(member.league_rank() + 1) * _N_team < N ? static_cast<int>(member.league_rank() + 1) * _N_team : N);
139:     const int graphID      = static_cast<int>(member.league_rank());
140:     using TeamVectorCopy1D = KokkosBatched::TeamVectorCopy<MemberType, KokkosBatched::Trans::NoTranspose, 1>;

142:     auto d                         = Kokkos::subview(_D, Kokkos::make_pair(first_matrix, last_matrix), Kokkos::ALL);
143:     auto x                         = Kokkos::subview(_X, Kokkos::make_pair(first_matrix, last_matrix), Kokkos::ALL);
144:     auto b                         = Kokkos::subview(_B, Kokkos::make_pair(first_matrix, last_matrix), Kokkos::ALL);
145:     using ScratchPadIntViewType    = Kokkos::View<typename IntView::non_const_value_type *, typename IntView::array_layout, typename IntView::execution_space::scratch_memory_space>;
146:     using ScratchPadValuesViewType = Kokkos::View<typename ValuesViewType::non_const_value_type **, typename ValuesViewType::array_layout, typename ValuesViewType::execution_space::scratch_memory_space>;

148:     using Operator = KokkosBatched::CrsMatrix<ValuesViewType, ScratchPadIntViewType>;
149:     ScratchPadIntViewType r(member.team_scratch(1), _r.extent(1));
150:     ScratchPadIntViewType c(member.team_scratch(1), _c.extent(1));

152:     TeamVectorCopy1D::invoke(member, Kokkos::subview(_r, graphID, Kokkos::ALL), r);
153:     TeamVectorCopy1D::invoke(member, Kokkos::subview(_c, graphID, Kokkos::ALL), c);
154:     Operator A(d, r, c);

156:     ScratchPadValuesViewType diag(member.team_scratch(1), last_matrix - first_matrix, _diag.extent(1));
157:     using PrecOperator = KokkosBatched::JacobiPrec<ScratchPadValuesViewType>;

159:     KokkosBatched::TeamVectorCopy<MemberType>::invoke(member, Kokkos::subview(_diag, Kokkos::make_pair(first_matrix, last_matrix), Kokkos::ALL), diag);
160:     PrecOperator P(diag);
161:     P.setComputedInverse();

163:     KokkosBatched::TeamVectorGMRES<MemberType>::template invoke<Operator, VectorViewType, PrecOperator, KrylovHandleType>(member, A, b, x, P, _handle);
164:   }
165:   inline double run(PC pc)
166:   {
167:     typedef typename ValuesViewType::value_type value_type;
168:     std::string                                 name("KokkosBatched::Test::TeamVectorGMRES");
169:     Kokkos::Timer                               timer;
170:     Kokkos::Profiling::pushRegion(name.c_str());

172:     Kokkos::TeamPolicy<DeviceType> auto_policy(ceil(1. * _D.extent(0) / _N_team), Kokkos::AUTO(), Kokkos::AUTO());
173:     Kokkos::TeamPolicy<DeviceType> tuned_policy(ceil(1. * _D.extent(0) / _N_team), _team_size, _vector_length);
174:     Kokkos::TeamPolicy<DeviceType> policy;

176:     if (_team_size < 1) policy = auto_policy;
177:     else policy = tuned_policy;

179:     _handle.set_max_iteration(_N_iteration);
180:     _handle.set_tolerance(_tol);
181:     _handle.set_ortho_strategy(_ortho_strategy);
182:     _handle.set_scratch_pad_level(_scratch_pad_level);
183:     _handle.set_compute_last_residual(true);

185:     int maximum_iteration = _handle.get_max_iteration();

187:     using ScalarType = typename ValuesViewType::non_const_value_type;
188:     using Layout     = typename ValuesViewType::array_layout;
189:     using EXSP       = typename ValuesViewType::execution_space;

191:     using MagnitudeType = typename Kokkos::Details::ArithTraits<ScalarType>::mag_type;

193:     using ViewType1D    = Kokkos::View<MagnitudeType *, Layout, EXSP>;
194:     using ViewType2D    = Kokkos::View<ScalarType **, Layout, EXSP>;
195:     using ViewType3D    = Kokkos::View<ScalarType ***, Layout, EXSP>;
196:     using IntViewType1D = Kokkos::View<PetscInt *, Layout, EXSP>;

198:     size_t bytes_1D      = ViewType2D::shmem_size(_N_team, 1);
199:     size_t bytes_row_ptr = IntViewType1D::shmem_size(_r.extent(1));
200:     size_t bytes_col_idc = IntViewType1D::shmem_size(_c.extent(1));
201:     size_t bytes_2D_1    = ViewType2D::shmem_size(_N_team, _X.extent(1));
202:     size_t bytes_2D_2    = ViewType2D::shmem_size(_N_team, maximum_iteration + 1);

204:     size_t bytes_diag = bytes_2D_1;
205:     size_t bytes_tmp  = 2 * bytes_2D_1 + 2 * bytes_1D + bytes_2D_2;

207:     policy.set_scratch_size(0, Kokkos::PerTeam(bytes_tmp));
208:     policy.set_scratch_size(1, Kokkos::PerTeam(bytes_col_idc + bytes_row_ptr + bytes_diag));
209:     PetscInfo(pc, "%d scratch memory(0) = %d + %d + %d bytes_diag=%d; %d scratch memory(1); %d maximum_iterations\n", (int)(bytes_tmp), 2 * (int)bytes_2D_1, 2 * (int)bytes_1D, (int)bytes_2D_2, (int)bytes_diag, (int)(bytes_row_ptr + bytes_col_idc + bytes_diag), (int)maximum_iteration);
210:     exec_space().fence();
211:     timer.reset();
212:     Kokkos::parallel_for(name.c_str(), policy, *this);
213:     exec_space().fence();
214:     double sec = timer.seconds();

216:     return sec;
217:   }
218: };
219: #endif // KK GMRES

221: typedef Kokkos::TeamPolicy<>::member_type team_member;

223: static PetscErrorCode PCBJKOKKOSCreateKSP_BJKOKKOS(PC pc)
224: {
225:   const char    *prefix;
226:   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;
227:   DM             dm;

229:   KSPCreate(PetscObjectComm((PetscObject)pc), &jac->ksp);
230:   KSPSetErrorIfNotConverged(jac->ksp, pc->erroriffailure);
231:   PetscObjectIncrementTabLevel((PetscObject)jac->ksp, (PetscObject)pc, 1);
232:   PCGetOptionsPrefix(pc, &prefix);
233:   KSPSetOptionsPrefix(jac->ksp, prefix);
234:   KSPAppendOptionsPrefix(jac->ksp, "pc_bjkokkos_");
235:   PCGetDM(pc, &dm);
236:   if (dm) {
237:     KSPSetDM(jac->ksp, dm);
238:     KSPSetDMActive(jac->ksp, PETSC_FALSE);
239:   }
240:   jac->reason       = PETSC_FALSE;
241:   jac->monitor      = PETSC_FALSE;
242:   jac->batch_target = -1;
243:   jac->nsolves_team = 1;
 244:   jac->ksp->max_it  = 50; // this is really for GMRES w/o restarts
245:   return 0;
246: }
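// Typical command-line usage (a sketch; the inner KSP reads options through the
// "pc_bjkokkos_" prefix appended above, and PCSetUp_BJKOKKOS below accepts only
// bicg, tfqmr, and gmres as the inner KSP type):
//   -pc_type bjkokkos -pc_bjkokkos_ksp_type tfqmr -pc_bjkokkos_ksp_rtol 1e-6 \
//     -pc_bjkokkos_ksp_converged_reason -pc_bjkokkos_ksp_monitor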

248: // y <-- Ax
249: KOKKOS_INLINE_FUNCTION PetscErrorCode MatMult(const team_member team, const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, const PetscInt start, const PetscInt end, const PetscScalar *x_loc, PetscScalar *y_loc)
250: {
251:   Kokkos::parallel_for(Kokkos::TeamThreadRange(team, start, end), [=](const int rowb) {
252:     int                rowa = ic[rowb];
253:     int                n    = glb_Aai[rowa + 1] - glb_Aai[rowa];
254:     const PetscInt    *aj   = glb_Aaj + glb_Aai[rowa];
255:     const PetscScalar *aa   = glb_Aaa + glb_Aai[rowa];
256:     PetscScalar        sum;
257:     Kokkos::parallel_reduce(
258:       Kokkos::ThreadVectorRange(team, n), [=](const int i, PetscScalar &lsum) { lsum += aa[i] * x_loc[r[aj[i]] - start]; }, sum);
259:     Kokkos::single(Kokkos::PerThread(team), [=]() { y_loc[rowb - start] = sum; });
260:   });
261:   team.team_barrier();
262:   return 0;
263: }
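// Parallel layout of MatMult above: each team owns one diagonal block; rows are spread
// over the team with TeamThreadRange and each row's nonzeros are summed with a
// ThreadVectorRange reduction. ic[] maps a reordered (block) row to its global row,
// and r[] maps global column indices back into the block-local ordering of x_loc.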

265: // temp buffer per thread with reduction at end?
266: KOKKOS_INLINE_FUNCTION PetscErrorCode MatMultTranspose(const team_member team, const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, const PetscInt start, const PetscInt end, const PetscScalar *x_loc, PetscScalar *y_loc)
267: {
268:   Kokkos::parallel_for(Kokkos::TeamVectorRange(team, end - start), [=](int i) { y_loc[i] = 0; });
269:   team.team_barrier();
270:   Kokkos::parallel_for(Kokkos::TeamThreadRange(team, start, end), [=](const int rowb) {
271:     int                rowa = ic[rowb];
272:     int                n    = glb_Aai[rowa + 1] - glb_Aai[rowa];
273:     const PetscInt    *aj   = glb_Aaj + glb_Aai[rowa];
274:     const PetscScalar *aa   = glb_Aaa + glb_Aai[rowa];
275:     const PetscScalar  xx   = x_loc[rowb - start]; // rowb = ic[rowa] = ic[r[rowb]]
276:     Kokkos::parallel_for(Kokkos::ThreadVectorRange(team, n), [=](const int &i) {
277:       PetscScalar val = aa[i] * xx;
278:       Kokkos::atomic_fetch_add(&y_loc[r[aj[i]] - start], val);
279:     });
280:   });
281:   team.team_barrier();
282:   return 0;
283: }
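// The transpose product has no per-row reduction target, so y_loc is zeroed first and
// column contributions are scattered with atomic adds; the per-thread temporary buffer
// mentioned above would be one way to avoid those atomics.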

285: typedef struct Batch_MetaData_TAG {
286:   PetscInt           flops;
287:   PetscInt           its;
288:   KSPConvergedReason reason;
289: } Batch_MetaData;
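// Each block's device solver records its iteration count, flop estimate, and converged
// reason here; the array of Batch_MetaData is copied back to the host after the solve
// for convergence checks and diagnostics.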

291: // Solve A(BB^-1)x = y with TFQMR. Right preconditioned to get un-preconditioned residual
292: KOKKOS_INLINE_FUNCTION PetscErrorCode BJSolve_TFQMR(const team_member team, const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, PetscScalar *work_space_global, const int stride_global, const int nShareVec, PetscScalar *work_space_shared, const int stride_shared, PetscReal rtol, PetscReal atol, PetscReal dtol, PetscInt maxit, Batch_MetaData *metad, const PetscInt start, const PetscInt end, const PetscScalar glb_idiag[], const PetscScalar *glb_b, PetscScalar *glb_x, bool monitor)
293: {
294:   using Kokkos::parallel_for;
295:   using Kokkos::parallel_reduce;
296:   int                Nblk = end - start, i, m, stride = stride_shared, idx = 0;
297:   PetscReal          dp, dpold, w, dpest, tau, psi, cm, r0;
298:   const PetscScalar *Diag = &glb_idiag[start];
299:   PetscScalar       *ptr  = work_space_shared, rho, rhoold, a, s, b, eta, etaold, psiold, cf, dpi;

301:   if (idx++ == nShareVec) {
302:     ptr    = work_space_global;
303:     stride = stride_global;
304:   }
305:   PetscScalar *XX = ptr;
306:   ptr += stride;
307:   if (idx++ == nShareVec) {
308:     ptr    = work_space_global;
309:     stride = stride_global;
310:   }
311:   PetscScalar *R = ptr;
312:   ptr += stride;
313:   if (idx++ == nShareVec) {
314:     ptr    = work_space_global;
315:     stride = stride_global;
316:   }
317:   PetscScalar *RP = ptr;
318:   ptr += stride;
319:   if (idx++ == nShareVec) {
320:     ptr    = work_space_global;
321:     stride = stride_global;
322:   }
323:   PetscScalar *V = ptr;
324:   ptr += stride;
325:   if (idx++ == nShareVec) {
326:     ptr    = work_space_global;
327:     stride = stride_global;
328:   }
329:   PetscScalar *T = ptr;
330:   ptr += stride;
331:   if (idx++ == nShareVec) {
332:     ptr    = work_space_global;
333:     stride = stride_global;
334:   }
335:   PetscScalar *Q = ptr;
336:   ptr += stride;
337:   if (idx++ == nShareVec) {
338:     ptr    = work_space_global;
339:     stride = stride_global;
340:   }
341:   PetscScalar *P = ptr;
342:   ptr += stride;
343:   if (idx++ == nShareVec) {
344:     ptr    = work_space_global;
345:     stride = stride_global;
346:   }
347:   PetscScalar *U = ptr;
348:   ptr += stride;
349:   if (idx++ == nShareVec) {
350:     ptr    = work_space_global;
351:     stride = stride_global;
352:   }
353:   PetscScalar *D = ptr;
354:   ptr += stride;
355:   if (idx++ == nShareVec) {
356:     ptr    = work_space_global;
357:     stride = stride_global;
358:   }
359:   PetscScalar *AUQ = V;
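  // Work-vector layout: vectors are carved out of team shared memory first (stride_shared
  // apart); once nShareVec of them have been taken, ptr switches to the global buffer
  // (stride_global apart). AUQ aliases V to save one work vector.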

361:   // init: get b, zero x
362:   parallel_for(Kokkos::TeamVectorRange(team, start, end), [=](int rowb) {
363:     int rowa         = ic[rowb];
364:     R[rowb - start]  = glb_b[rowa];
365:     XX[rowb - start] = 0;
366:   });
367:   team.team_barrier();
368:   parallel_reduce(
369:     Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &lsum) { lsum += R[idx] * PetscConj(R[idx]); }, dpi);
370:   team.team_barrier();
371:   r0 = dp = PetscSqrtReal(PetscRealPart(dpi));
372:   // diagnostics
373: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
374:   if (monitor) Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("%3d KSP Residual norm %14.12e \n", 0, (double)dp); });
375: #endif
376:   if (dp < atol) {
377:     metad->reason = KSP_CONVERGED_ATOL_NORMAL;
378:     return 0;
379:   }
380:   if (0 == maxit) {
381:     metad->reason = KSP_DIVERGED_ITS;
382:     return 0;
383:   }

385:   /* Make the initial Rp = R */
386:   parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { RP[idx] = R[idx]; });
387:   team.team_barrier();
388:   /* Set the initial conditions */
389:   etaold = 0.0;
390:   psiold = 0.0;
391:   tau    = dp;
392:   dpold  = dp;

394:   /* rhoold = (r,rp)     */
395:   parallel_reduce(
396:     Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &dot) { dot += R[idx] * PetscConj(RP[idx]); }, rhoold);
397:   team.team_barrier();
398:   parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
399:     U[idx] = R[idx];
400:     P[idx] = R[idx];
401:     T[idx] = Diag[idx] * P[idx];
402:     D[idx] = 0;
403:   });
404:   team.team_barrier();
405:   MatMult(team, glb_Aai, glb_Aaj, glb_Aaa, r, ic, start, end, T, V);

407:   i = 0;
408:   do {
409:     /* s <- (v,rp)          */
410:     parallel_reduce(
411:       Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &dot) { dot += V[idx] * PetscConj(RP[idx]); }, s);
412:     team.team_barrier();
413:     a = rhoold / s; /* a <- rho / s         */
414:     /* q <- u - a v    VecWAXPY(w,alpha,x,y): w = alpha x + y.     */
415:     /* t <- u + q           */
416:     parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
417:       Q[idx] = U[idx] - a * V[idx];
418:       T[idx] = U[idx] + Q[idx];
419:     });
420:     team.team_barrier();
421:     // KSP_PCApplyBAorAB
422:     parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { T[idx] = Diag[idx] * T[idx]; });
423:     team.team_barrier();
424:     MatMult(team, glb_Aai, glb_Aaj, glb_Aaa, r, ic, start, end, T, AUQ);
425:     /* r <- r - a K (u + q) */
426:     parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { R[idx] = R[idx] - a * AUQ[idx]; });
427:     team.team_barrier();
428:     parallel_reduce(
429:       Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &lsum) { lsum += R[idx] * PetscConj(R[idx]); }, dpi);
430:     team.team_barrier();
431:     dp = PetscSqrtReal(PetscRealPart(dpi));
432:     for (m = 0; m < 2; m++) {
433:       if (!m) w = PetscSqrtReal(dp * dpold);
434:       else w = dp;
435:       psi = w / tau;
436:       cm  = 1.0 / PetscSqrtReal(1.0 + psi * psi);
437:       tau = tau * psi * cm;
438:       eta = cm * cm * a;
439:       cf  = psiold * psiold * etaold / a;
440:       if (!m) {
441:         /* D = U + cf D */
442:         parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { D[idx] = U[idx] + cf * D[idx]; });
443:       } else {
444:         /* D = Q + cf D */
445:         parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { D[idx] = Q[idx] + cf * D[idx]; });
446:       }
447:       team.team_barrier();
448:       parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { XX[idx] = XX[idx] + eta * D[idx]; });
449:       team.team_barrier();
450:       dpest = PetscSqrtReal(2 * i + m + 2.0) * tau;
451: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
452:       if (monitor && m == 1) Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("%3d KSP Residual norm %14.12e \n", i + 1, (double)dpest); });
453: #endif
454:       if (dpest < atol) {
455:         metad->reason = KSP_CONVERGED_ATOL_NORMAL;
456:         goto done;
457:       }
458:       if (dpest / r0 < rtol) {
459:         metad->reason = KSP_CONVERGED_RTOL_NORMAL;
460:         goto done;
461:       }
462: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
463:       if (dpest / r0 > dtol) {
464:         metad->reason = KSP_DIVERGED_DTOL;
465:         Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("ERROR block %d diverged: %d it, res=%e, r_0=%e\n", team.league_rank(), i, dpest, r0); });
466:         goto done;
467:       }
468: #else
469:       if (dpest / r0 > dtol) {
470:         metad->reason = KSP_DIVERGED_DTOL;
471:         goto done;
472:       }
473: #endif
474:       if (i + 1 == maxit) {
475:         metad->reason = KSP_DIVERGED_ITS;
476:         goto done;
477:       }

479:       etaold = eta;
480:       psiold = psi;
481:     }

483:     /* rho <- (r,rp)       */
484:     parallel_reduce(
485:       Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &dot) { dot += R[idx] * PetscConj(RP[idx]); }, rho);
486:     team.team_barrier();
487:     b = rho / rhoold; /* b <- rho / rhoold   */
488:     /* u <- r + b q        */
489:     /* p <- u + b(q + b p) */
490:     parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
491:       U[idx] = R[idx] + b * Q[idx];
492:       Q[idx] = Q[idx] + b * P[idx];
493:       P[idx] = U[idx] + b * Q[idx];
494:     });
495:     /* v <- K p  */
496:     team.team_barrier();
497:     parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { T[idx] = Diag[idx] * P[idx]; });
498:     team.team_barrier();
499:     MatMult(team, glb_Aai, glb_Aaj, glb_Aaa, r, ic, start, end, T, V);

501:     rhoold = rho;
502:     dpold  = dp;

504:     i++;
505:   } while (i < maxit);
506: done:
507:   // KSPUnwindPreconditioner
508:   parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { XX[idx] = Diag[idx] * XX[idx]; });
509:   team.team_barrier();
510:   // put x into Plex order
511:   parallel_for(Kokkos::TeamVectorRange(team, start, end), [=](int rowb) {
512:     int rowa    = ic[rowb];
513:     glb_x[rowa] = XX[rowb - start];
514:   });
515:   metad->its = i + 1;
516:   if (1) {
517:     int nnz;
518:     parallel_reduce(
519:       Kokkos::TeamVectorRange(team, start, end), [=](const int idx, int &lsum) { lsum += (glb_Aai[idx + 1] - glb_Aai[idx]); }, nnz);
520:     metad->flops = 2 * (metad->its * (10 * Nblk + 2 * nnz) + 5 * Nblk);
521:   } else {
522:     metad->flops = 2 * (metad->its * (10 * Nblk + 2 * 50 * Nblk) + 5 * Nblk); // guess
523:   }
524:   return 0;
525: }

 527: // Solve Ax = y with BiCG
528: KOKKOS_INLINE_FUNCTION PetscErrorCode BJSolve_BICG(const team_member team, const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, PetscScalar *work_space_global, const int stride_global, const int nShareVec, PetscScalar *work_space_shared, const int stride_shared, PetscReal rtol, PetscReal atol, PetscReal dtol, PetscInt maxit, Batch_MetaData *metad, const PetscInt start, const PetscInt end, const PetscScalar glb_idiag[], const PetscScalar *glb_b, PetscScalar *glb_x, bool monitor)
529: {
530:   using Kokkos::parallel_for;
531:   using Kokkos::parallel_reduce;
532:   int                Nblk = end - start, i, stride = stride_shared, idx = 0; // start in shared mem
533:   PetscReal          dp, r0;
534:   const PetscScalar *Di  = &glb_idiag[start];
535:   PetscScalar       *ptr = work_space_shared, dpi, a = 1.0, beta, betaold = 1.0, b, b2, ma, mac;

537:   if (idx++ == nShareVec) {
538:     ptr    = work_space_global;
539:     stride = stride_global;
540:   }
541:   PetscScalar *XX = ptr;
542:   ptr += stride;
543:   if (idx++ == nShareVec) {
544:     ptr    = work_space_global;
545:     stride = stride_global;
546:   }
547:   PetscScalar *Rl = ptr;
548:   ptr += stride;
549:   if (idx++ == nShareVec) {
550:     ptr    = work_space_global;
551:     stride = stride_global;
552:   }
553:   PetscScalar *Zl = ptr;
554:   ptr += stride;
555:   if (idx++ == nShareVec) {
556:     ptr    = work_space_global;
557:     stride = stride_global;
558:   }
559:   PetscScalar *Pl = ptr;
560:   ptr += stride;
561:   if (idx++ == nShareVec) {
562:     ptr    = work_space_global;
563:     stride = stride_global;
564:   }
565:   PetscScalar *Rr = ptr;
566:   ptr += stride;
567:   if (idx++ == nShareVec) {
568:     ptr    = work_space_global;
569:     stride = stride_global;
570:   }
571:   PetscScalar *Zr = ptr;
572:   ptr += stride;
573:   if (idx++ == nShareVec) {
574:     ptr    = work_space_global;
575:     stride = stride_global;
576:   }
577:   PetscScalar *Pr = ptr;
578:   ptr += stride;
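  // Same shared-then-global work-vector carving as in BJSolve_TFQMR above; BiCG uses
  // seven work vectors (jac->nwork is set to 7 in PCSetUp_BJKOKKOS).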

580:   /*     r <- b (x is 0) */
581:   parallel_for(Kokkos::TeamVectorRange(team, start, end), [=](int rowb) {
582:     int rowa = ic[rowb];
583:     //VecCopy(Rr,Rl);
584:     Rl[rowb - start] = Rr[rowb - start] = glb_b[rowa];
585:     XX[rowb - start]                    = 0;
586:   });
587:   team.team_barrier();
588:   /*     z <- Br         */
589:   parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
590:     Zr[idx] = Di[idx] * Rr[idx];
591:     Zl[idx] = Di[idx] * Rl[idx];
592:   });
593:   team.team_barrier();
594:   /*    dp <- r'*r       */
595:   parallel_reduce(
596:     Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &lsum) { lsum += Rr[idx] * PetscConj(Rr[idx]); }, dpi);
597:   team.team_barrier();
598:   r0 = dp = PetscSqrtReal(PetscRealPart(dpi));
599: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
600:   if (monitor) Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("%3d KSP Residual norm %14.12e \n", 0, (double)dp); });
601: #endif
602:   if (dp < atol) {
603:     metad->reason = KSP_CONVERGED_ATOL_NORMAL;
604:     return 0;
605:   }
606:   if (0 == maxit) {
607:     metad->reason = KSP_DIVERGED_ITS;
608:     return 0;
609:   }
610:   i = 0;
611:   do {
612:     /*     beta <- r'z     */
613:     parallel_reduce(
614:       Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &dot) { dot += Zr[idx] * PetscConj(Rl[idx]); }, beta);
615:     team.team_barrier();
616: #if PCBJKOKKOS_VERBOSE_LEVEL >= 6
617:   #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
618:     Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("%7d beta = Z.R = %22.14e \n", i, (double)beta); });
619:   #endif
620: #endif
621:     if (!i) {
622:       if (beta == 0.0) {
623:         metad->reason = KSP_DIVERGED_BREAKDOWN_BICG;
624:         goto done;
625:       }
626:       /*     p <- z          */
627:       parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
628:         Pr[idx] = Zr[idx];
629:         Pl[idx] = Zl[idx];
630:       });
631:     } else {
632:       b = beta / betaold;
633:       /*     p <- z + b* p   */
634:       b2 = PetscConj(b);
635:       parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
636:         Pr[idx] = b * Pr[idx] + Zr[idx];
637:         Pl[idx] = b2 * Pl[idx] + Zl[idx];
638:       });
639:     }
640:     team.team_barrier();
641:     betaold = beta;
642:     /*     z <- Kp         */
643:     MatMult(team, glb_Aai, glb_Aaj, glb_Aaa, r, ic, start, end, Pr, Zr);
644:     MatMultTranspose(team, glb_Aai, glb_Aaj, glb_Aaa, r, ic, start, end, Pl, Zl);
645:     /*     dpi <- z'p      */
646:     parallel_reduce(
647:       Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &lsum) { lsum += Zr[idx] * PetscConj(Pl[idx]); }, dpi);
648:     team.team_barrier();
649:     //
650:     a   = beta / dpi; /*     a = beta/p'z    */
651:     ma  = -a;
652:     mac = PetscConj(ma);
653:     /*     x <- x + ap     */
654:     parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
655:       XX[idx] = XX[idx] + a * Pr[idx];
656:       Rr[idx] = Rr[idx] + ma * Zr[idx];
657:       Rl[idx] = Rl[idx] + mac * Zl[idx];
658:     });
659:     team.team_barrier();
660:     team.team_barrier();
661:     /*    dp <- r'*r       */
662:     parallel_reduce(
663:       Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &lsum) { lsum += Rr[idx] * PetscConj(Rr[idx]); }, dpi);
664:     team.team_barrier();
665:     dp = PetscSqrtReal(PetscRealPart(dpi));
666: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
667:     if (monitor) Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("%3d KSP Residual norm %14.12e \n", i + 1, (double)dp); });
668: #endif
669:     if (dp < atol) {
670:       metad->reason = KSP_CONVERGED_ATOL_NORMAL;
671:       goto done;
672:     }
673:     if (dp / r0 < rtol) {
674:       metad->reason = KSP_CONVERGED_RTOL_NORMAL;
675:       goto done;
676:     }
677: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
678:     if (dp / r0 > dtol) {
679:       metad->reason = KSP_DIVERGED_DTOL;
680:       Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("ERROR block %d diverged: %d it, res=%e, r_0=%e\n", team.league_rank(), i, dp, r0); });
681:       goto done;
682:     }
683: #else
684:     if (dp / r0 > dtol) {
685:       metad->reason = KSP_DIVERGED_DTOL;
686:       goto done;
687:     }
688: #endif
689:     if (i + 1 == maxit) {
690:       metad->reason = KSP_DIVERGED_ITS;
691:       goto done;
692:     }
693:     /* z <- Br  */
694:     parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
695:       Zr[idx] = Di[idx] * Rr[idx];
696:       Zl[idx] = Di[idx] * Rl[idx];
697:     });
698:     i++;
699:     team.team_barrier();
700:   } while (i < maxit);
701: done:
702:   // put x back into Plex order
703:   parallel_for(Kokkos::TeamVectorRange(team, start, end), [=](int rowb) {
704:     int rowa    = ic[rowb];
705:     glb_x[rowa] = XX[rowb - start];
706:   });
707:   metad->its = i + 1;
708:   if (1) {
709:     int nnz;
710:     parallel_reduce(
711:       Kokkos::TeamVectorRange(team, start, end), [=](const int idx, int &lsum) { lsum += (glb_Aai[idx + 1] - glb_Aai[idx]); }, nnz);
712:     metad->flops = 2 * (metad->its * (10 * Nblk + 2 * nnz) + 5 * Nblk);
713:   } else {
714:     metad->flops = 2 * (metad->its * (10 * Nblk + 2 * 50 * Nblk) + 5 * Nblk); // guess
715:   }
716:   return 0;
717: }

 719: // KSP solver: solves Ax = b; xout is the output, bin is the input
720: static PetscErrorCode PCApply_BJKOKKOS(PC pc, Vec bin, Vec xout)
721: {
722:   PC_PCBJKOKKOS    *jac = (PC_PCBJKOKKOS *)pc->data;
723:   Mat               A   = pc->pmat;
724:   Mat_SeqAIJKokkos *aijkok;

727:   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
728:   if (!aijkok) {
729:     SETERRQ(PetscObjectComm((PetscObject)pc), PETSC_ERR_USER, "No aijkok");
730:   } else {
731:     PetscInt           maxit = jac->ksp->max_it;
732:     const PetscInt     conc = Kokkos::DefaultExecutionSpace().concurrency(), openmp = !!(conc < 1000), team_size = (openmp == 0 && PCBJKOKKOS_VEC_SIZE != 1) ? PCBJKOKKOS_TEAM_SIZE : 1;
733:     const PetscInt     nwork = jac->nwork, nBlk = jac->nBlocks;
734:     PetscScalar       *glb_xdata = NULL;
735:     PetscReal          rtol = jac->ksp->rtol, atol = jac->ksp->abstol, dtol = jac->ksp->divtol;
736:     const PetscScalar *glb_idiag = jac->d_idiag_k->data(), *glb_bdata = NULL;
737:     const PetscInt    *glb_Aai = aijkok->i_device_data(), *glb_Aaj = aijkok->j_device_data(), *d_bid_eqOffset = jac->d_bid_eqOffset_k->data();
738:     const PetscScalar *glb_Aaa  = aijkok->a_device_data();
739:     const PetscInt    *d_isicol = jac->d_isicol_k->data(), *d_isrow = jac->d_isrow_k->data();
740:     PCFailedReason     pcreason;
741:     KSPIndex           ksp_type_idx = jac->ksp_type_idx;
742:     PetscMemType       mtype;
743:     PetscContainer     container;
744:     PetscInt           batch_sz;
745:     VecScatter         plex_batch = NULL;       // not used
746:     Vec                bvec;                    // a copy of b for scatter (just alias to bin now)
747:     PetscBool          monitor  = jac->monitor; // captured
748:     PetscInt           view_bid = jac->batch_target;
749:     MatInfo            info;
750:     jac->max_nits = 0;
751:     if (view_bid < 0) view_bid = 0;
752:     MatGetInfo(A, MAT_LOCAL, &info);
753:     // query the field-major IS ("plex_batch_is") used to map Plex I/O to/from block/field-major ordering
754:     PetscObjectQuery((PetscObject)A, "plex_batch_is", (PetscObject *)&container);
755:     if (container) {
756:       VecDuplicate(bin, &bvec);
757:       PetscContainerGetPointer(container, (void **)&plex_batch);
758:       VecScatterBegin(plex_batch, bin, bvec, INSERT_VALUES, SCATTER_FORWARD);
759:       VecScatterEnd(plex_batch, bin, bvec, INSERT_VALUES, SCATTER_FORWARD);
760:       SETERRQ(PetscObjectComm((PetscObject)A), PETSC_ERR_USER, "No plex_batch_is -- require NO field major ordering for now");
761:     } else {
762:       bvec = bin;
763:     }
764:     // get x
765:     VecGetArrayAndMemType(xout, &glb_xdata, &mtype);
766: #if defined(PETSC_HAVE_CUDA)
768: #endif
769:     VecGetArrayReadAndMemType(bvec, &glb_bdata, &mtype);
770: #if defined(PETSC_HAVE_CUDA)
772: #endif
773:     // get batch size
774:     PetscObjectQuery((PetscObject)A, "batch size", (PetscObject *)&container);
775:     if (container) {
776:       PetscInt *pNf = NULL;
777:       PetscContainerGetPointer(container, (void **)&pNf);
778:       batch_sz = *pNf;
779:     } else batch_sz = 1;
781:     if (ksp_type_idx == BATCH_KSP_GMRES_IDX) { // KK solver - move PETSc data into Kokkos Views, setup solver, solve, move data out of Kokkos, process metadata (convergence tests, etc.)
782: #if defined(PETSC_HAVE_KOKKOS_KERNELS_GMRES)
783:       int       Nsolves_team = jac->nsolves_team, fill_idx = 0;
784:       int       Nloc    = jac->const_block_size; // same grids
785:       const int Nsolves = nBlk;
786:       const int nnz     = (int)info.nz_used / Nsolves;      // fix for variable grid size
787:       if (Nsolves_team > batch_sz) Nsolves_team = batch_sz; // silently fix this
791:   #if defined(PETSC_HAVE_CUDA)
792:       nvtxRangePushA("gmres-kk");
793:   #endif
794:       Kokkos::View<PetscScalar **, layout, exec_space, Kokkos::MemoryTraits<Kokkos::Unmanaged>> inv_diag((PetscScalar *)glb_idiag, Nsolves, Nloc); // in correct order
795:       if (!jac->rowOffsets) {
796:         jac->rowOffsets   = new IntView("rowOffsets", Nsolves / Nsolves_team, Nloc + 1); // same grids
797:         jac->colIndices   = new IntView("colIndices", Nsolves / Nsolves_team, nnz);
798:         jac->batch_b      = new XYType("batch rhs", Nsolves, Nloc);
799:         jac->batch_x      = new XYType("batch sol", Nsolves, Nloc);
800:         jac->batch_values = new AMatrixValueView("batch values", Nsolves, nnz);
801:         fill_idx          = 1;
802:         PetscInfo(pc, "Setup indices Nloc=%d, nnz=%d\n", Nloc, nnz);
803:       }
804:       IntView          &rowOffsets   = *jac->rowOffsets;
805:       IntView          &colIndices   = *jac->colIndices;
806:       XYType           &batch_b      = *jac->batch_b;
807:       XYType           &batch_x      = *jac->batch_x;
808:       AMatrixValueView &batch_values = *jac->batch_values;

810:       Kokkos::deep_copy(batch_x, 0.);
811:       PetscInfo(pc, "\tjac->n = %" PetscInt_FMT ", Nloc = %d, Nsolves = %d, nnz = %d, Nsolves_team = %d, league size = %d, maxit = %" PetscInt_FMT "\n", jac->n, Nloc, Nsolves, nnz, Nsolves_team, Nsolves / Nsolves_team, maxit);
812:       Kokkos::parallel_for(
813:         "rowOffsets+map", Kokkos::TeamPolicy<>(Nsolves, team_size, PCBJKOKKOS_VEC_SIZE), KOKKOS_LAMBDA(const team_member team) {
814:           const int blkID = team.league_rank(), start = d_bid_eqOffset[blkID], end = d_bid_eqOffset[blkID + 1];
815:           if (fill_idx) {
816:             if (blkID % Nsolves_team == 0) {                                                        // first matrix on this member
817:               Kokkos::parallel_for(Kokkos::TeamVectorRange(team, start, end), [=](const int rowb) { // Nloc
818:                 int rowa                                           = d_isicol[rowb];
819:                 int n                                              = glb_Aai[rowa + 1] - glb_Aai[rowa];
820:                 rowOffsets(blkID / Nsolves_team, rowb + 1 - start) = n; // save sizes
821:               });
822:             }
823:           }
824:           // map b into field major space
825:           Kokkos::parallel_for(Kokkos::TeamVectorRange(team, start, end), [=](int rowb) {
826:             int rowa                     = d_isicol[rowb];
827:             batch_b(blkID, rowb - start) = glb_bdata[rowa];
828:           });
829:         });
830:       Kokkos::fence();
831:       if (fill_idx) {
832:         Kokkos::parallel_for(
833:           "prefix sum", Kokkos::TeamPolicy<>(Nsolves / Nsolves_team, 1, 1), KOKKOS_LAMBDA(const team_member team) {
834:             const int graphID      = team.league_rank();
835:             rowOffsets(graphID, 0) = 0;
836:             for (size_t i = 0; i < Nloc; ++i) rowOffsets(graphID, i + 1) += rowOffsets(graphID, i);
837:           });
838:         Kokkos::fence();
839:       }
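      // rowOffsets was filled with per-row nonzero counts above; this serial loop per
      // graph converts the counts into CSR row offsets (a prefix sum), so that
      // rowOffsets(g, i) .. rowOffsets(g, i + 1) brackets row i's entries in
      // colIndices and batch_values.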
840:       Kokkos::parallel_for(
841:         "copy matrix", Kokkos::TeamPolicy<>(Nsolves /* /batch_sz */, team_size, PCBJKOKKOS_VEC_SIZE), KOKKOS_LAMBDA(const team_member team) {
842:           const int blkID = team.league_rank(), start = d_bid_eqOffset[blkID], end = d_bid_eqOffset[blkID + 1], graphID = blkID / Nsolves_team;
843:           Kokkos::parallel_for(Kokkos::TeamThreadRange(team, start, end), [=](const int rowb) {
844:             int                rowa = d_isicol[rowb]; // global index
845:             int                n    = glb_Aai[rowa + 1] - glb_Aai[rowa];
846:             const PetscInt    *aj   = glb_Aaj + glb_Aai[rowa];
847:             const PetscScalar *aa   = glb_Aaa + glb_Aai[rowa];
848:             Kokkos::parallel_for(Kokkos::ThreadVectorRange(team, n), [=](const int &i) {
849:               PetscScalar val = aa[i];
850:               if (fill_idx && blkID % Nsolves_team == 0) colIndices(graphID, rowOffsets(graphID, rowb - start) + i) = d_isrow[aj[i]] - blkID * Nloc; // "local" index = global index - block start
851:               batch_values(blkID, rowOffsets(graphID, rowb - start) + i) = val;
852:             });
853:           });
854:         });
855:       Kokkos::fence();
856:       // setup solver
857:       using ScalarType                = typename AMatrixValueView::non_const_value_type;
858:       using MagnitudeType             = typename Kokkos::Details::ArithTraits<ScalarType>::mag_type;
859:       using NormViewType              = Kokkos::View<MagnitudeType *, layout, exec_space>;
860:       using Norm2DViewType            = Kokkos::View<MagnitudeType **, layout, exec_space>;
861:       using Scalar3DViewType          = Kokkos::View<ScalarType ***, layout, exec_space>;
862:       using IntViewType               = Kokkos::View<int *, layout, exec_space>;
863:       using KrylovHandleType          = KokkosBatched::KrylovHandle<Norm2DViewType, IntViewType, Scalar3DViewType>;
864:       const int        n_iterations   = maxit;
865:       const int        team_size      = -1;
866:       const int        vector_length  = -1;
867:       const double     tol            = rtol;
868:       const int        ortho_strategy = 0;
869:       KrylovHandleType handle(Nsolves, Nsolves_team, n_iterations, true);
870:       handle.Arnoldi_view = Scalar3DViewType("", Nsolves, n_iterations, Nloc + n_iterations + 3);
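      // The KrylovHandle records per-system convergence flags and iteration counts; they
      // are queried on the host after the solve via is_converged_host() and
      // get_iteration_host() below.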
871:       // solve
872:       double time = Functor_TestBatchedTeamVectorGMRES<exec_space, AMatrixValueView, IntView, XYType, KrylovHandleType>(batch_values, inv_diag, rowOffsets, colIndices, batch_x, batch_b, Nsolves_team, team_size, vector_length, n_iterations, tol, ortho_strategy, 0, handle)
873:                       .run(pc);
874:       Kokkos::fence();
875:       // get data back
876:       Kokkos::parallel_for(
877:         "map", Kokkos::TeamPolicy<>(Nsolves /* /batch_sz */, team_size, PCBJKOKKOS_VEC_SIZE), KOKKOS_LAMBDA(const team_member team) {
878:           const int blkID = team.league_rank(), start = d_bid_eqOffset[blkID], end = d_bid_eqOffset[blkID + 1]; // 0
879:           // map x into Plex/PETSc
880:           Kokkos::parallel_for(Kokkos::TeamVectorRange(team, start, end), [=](int rowb) {
881:             int rowa        = d_isicol[rowb];
882:             glb_xdata[rowa] = batch_x(blkID, rowb - start);
883:           });
884:         });
885:       // output assumes species-major order - cloned from the Kokkos solvers path
886:   #if PCBJKOKKOS_VERBOSE_LEVEL >= 3
887:     #if PCBJKOKKOS_VERBOSE_LEVEL >= 4
888:       PetscPrintf(PetscObjectComm((PetscObject)A), "Iterations\n");
889:     #else
890:       PetscPrintf(PetscObjectComm((PetscObject)A), "max iterations per species (gmres) :");
891:     #endif
892:       for (PetscInt dmIdx = 0, s = 0, head = 0; dmIdx < jac->num_dms; dmIdx += batch_sz) {
893:         for (PetscInt f = 0, idx = head; f < jac->dm_Nf[dmIdx]; f++, s++, idx++) {
894:     #if PCBJKOKKOS_VERBOSE_LEVEL >= 4
895:           PetscPrintf(PetscObjectComm((PetscObject)A), "%2D:", s);
896:           for (int bid = 0; bid < batch_sz; bid++) PetscPrintf(PetscObjectComm((PetscObject)A), "%3D ", handle.get_iteration_host(idx + bid * jac->dm_Nf[dmIdx]));
897:           PetscPrintf(PetscObjectComm((PetscObject)A), "\n");
898:     #else
899:           int count = 0, ii;
900:           for (int bid = 0; bid < batch_sz; bid++) {
901:             if ((ii = handle.get_iteration_host(idx + bid * jac->dm_Nf[dmIdx])) > count) count = ii;
902:           }
903:           PetscPrintf(PetscObjectComm((PetscObject)A), "%3d", count);
904:     #endif
905:         }
906:         head += batch_sz * jac->dm_Nf[dmIdx];
907:       }
908:     #if PCBJKOKKOS_VERBOSE_LEVEL == 3
909:       PetscPrintf(PetscObjectComm((PetscObject)A), "\n");
910:     #endif
911:   #endif
912:       // return error code, get max it
913:       PetscInt count = 0, mbid = 0;
914:       if (handle.is_converged_host()) {
915:         pcreason = PC_NOERROR;
916:         if (!jac->max_nits) {
917:           for (int blkID = 0; blkID < nBlk; blkID++) {
918:             if (handle.get_iteration_host(blkID) > jac->max_nits) {
919:               jac->max_nits = handle.get_iteration_host(blkID);
920:               mbid          = blkID;
921:             }
922:           }
923:         }
924:       } else {
925:         PetscPrintf(PETSC_COMM_SELF, "There is at least one system that did not converge.");
926:         pcreason = PC_SUBPC_ERROR;
927:       }
928:       // output - assume species major order
929:       for (int blkID = 0; blkID < nBlk; blkID++) {
930:         if (jac->reason) { // -pc_bjkokkos_ksp_converged_reason
931:           if (jac->batch_target == blkID) {
932:             PetscPrintf(PetscObjectComm((PetscObject)A), "    Linear solve %s in %d iterations, batch %" PetscInt_FMT ", species %" PetscInt_FMT "\n", handle.is_converged_host(blkID) ? "converged" : "diverged", handle.get_iteration_host(blkID), blkID % batch_sz, blkID / batch_sz);
933:           } else if (jac->batch_target == -1 && handle.get_iteration_host(blkID) > count) {
934:             jac->max_nits = count = handle.get_iteration_host(blkID);
935:             mbid                  = blkID;
936:           }
937:           if (!handle.is_converged_host(blkID)) PetscPrintf(PETSC_COMM_SELF, "ERROR species %d, batch %d did not converge with %d iterations\n", (int)(blkID / batch_sz), (int)blkID % batch_sz, handle.get_iteration_host(blkID));
938:         }
939:       }
940:       if (jac->batch_target == -1 && jac->reason) {
941:         PetscPrintf(PetscObjectComm((PetscObject)A), "    Linear solve %s in %d iterations, batch %" PetscInt_FMT ", species %" PetscInt_FMT "\n", handle.is_converged_host(mbid) ? "converged" : "diverged", jac->max_nits, mbid % batch_sz, mbid / batch_sz);
942:       }
943:   #if defined(PETSC_HAVE_CUDA)
944:       nvtxRangePop();
945:   #endif
946: #else
947:       SETERRQ(PetscObjectComm((PetscObject)A), PETSC_ERR_USER, "batch GMRES not supported");
948: #endif
949:     } else { // Kokkos Krylov
950:       using scr_mem_t    = Kokkos::DefaultExecutionSpace::scratch_memory_space;
951:       using vect2D_scr_t = Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, scr_mem_t>;
952:       Kokkos::View<Batch_MetaData *, Kokkos::DefaultExecutionSpace> d_metadata("solver meta data", nBlk);
953:       int                                                           stride_shared, stride_global, global_buff_words;
954:       d_bid_eqOffset = jac->d_bid_eqOffset_k->data();
955:       // solve each block independently
956:       int scr_bytes_team_shared = 0, nShareVec = 0, nGlobBVec = 0;
957:       if (jac->const_block_size) { // use shared memory for work vectors only if constant block size - todo: test efficiency loss
958:         int         maximum_shared_mem_size = 64000;
959:         PetscDevice device;
960:         PetscDeviceGetDefault_Internal(&device);
961:         PetscDeviceGetAttribute(device, PETSC_DEVICE_ATTR_SIZE_T_SHARED_MEM_PER_BLOCK, &maximum_shared_mem_size);
962:         stride_shared = jac->const_block_size;                                                   // captured
963:         nShareVec     = maximum_shared_mem_size / (jac->const_block_size * sizeof(PetscScalar)); // integer floor, number of vectors that fit in shared
964:         if (nShareVec > nwork) nShareVec = nwork;
965:         else nGlobBVec = nwork - nShareVec;
966:         global_buff_words     = jac->n * nGlobBVec;
967:         scr_bytes_team_shared = jac->const_block_size * nShareVec * sizeof(PetscScalar);
968:         //PetscPrintf(PETSC_COMM_WORLD,"maximum_shared_mem_size=%d scr_bytes_shared=%d nShareVec=%d, nGlobBVec=%d vec size=%d jac->const_block_size=%d\n",maximum_shared_mem_size,scr_bytes_team_shared,nShareVec,nGlobBVec,jac->const_block_size*sizeof(PetscScalar),jac->const_block_size);
969:       } else {
970:         scr_bytes_team_shared = 0;
971:         stride_shared         = 0;
972:         global_buff_words     = jac->n * nwork;
973:         nGlobBVec             = nwork; // not needed == fix
974:       }
975:       stride_global = jac->n; // captured
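      // With a constant block size, as many of the nwork vectors as fit are placed in team
      // shared memory (nShareVec) and the remaining nGlobBVec live in the global buffer of
      // jac->n * nGlobBVec scalars allocated below; otherwise all work vectors go to
      // global memory.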
976: #if defined(PETSC_HAVE_CUDA)
977:       nvtxRangePushA("batch-kokkos-solve");
978: #endif
979:       Kokkos::View<PetscScalar *, Kokkos::DefaultExecutionSpace> d_work_vecs_k("workvectors", global_buff_words); // global work vectors
980:       PetscInfo(pc, "\tn = %d. %d shared bytes/team, %d global mem bytes, rtol=%e, num blocks %d, team_size=%d, %d vector threads, %d shared vectors, %d global vectors\n", (int)jac->n, scr_bytes_team_shared, global_buff_words, rtol, (int)nBlk, (int)team_size, PCBJKOKKOS_VEC_SIZE, nShareVec, nGlobBVec);
981:       PetscScalar *d_work_vecs = d_work_vecs_k.data();
982:       Kokkos::parallel_for(
983:         "Solve", Kokkos::TeamPolicy<Kokkos::LaunchBounds<256, 4>>(nBlk, team_size, PCBJKOKKOS_VEC_SIZE).set_scratch_size(PCBJKOKKOS_SHARED_LEVEL, Kokkos::PerTeam(scr_bytes_team_shared)), KOKKOS_LAMBDA(const team_member team) {
984:           const int    blkID = team.league_rank(), start = d_bid_eqOffset[blkID], end = d_bid_eqOffset[blkID + 1];
985:           vect2D_scr_t work_vecs_shared(team.team_scratch(PCBJKOKKOS_SHARED_LEVEL), end - start, nShareVec);
986:           PetscScalar *work_buff_shared = work_vecs_shared.data();
987:           PetscScalar *work_buff_global = &d_work_vecs[start]; // block start offset already folded into the pointer
988:           bool         print            = monitor && (blkID == view_bid);
989:           switch (ksp_type_idx) {
990:           case BATCH_KSP_BICG_IDX:
991:             BJSolve_BICG(team, glb_Aai, glb_Aaj, glb_Aaa, d_isrow, d_isicol, work_buff_global, stride_global, nShareVec, work_buff_shared, stride_shared, rtol, atol, dtol, maxit, &d_metadata[blkID], start, end, glb_idiag, glb_bdata, glb_xdata, print);
992:             break;
993:           case BATCH_KSP_TFQMR_IDX:
994:             BJSolve_TFQMR(team, glb_Aai, glb_Aaj, glb_Aaa, d_isrow, d_isicol, work_buff_global, stride_global, nShareVec, work_buff_shared, stride_shared, rtol, atol, dtol, maxit, &d_metadata[blkID], start, end, glb_idiag, glb_bdata, glb_xdata, print);
995:             break;
996:           case BATCH_KSP_GMRES_IDX:
997:             //BJSolve_GMRES();
998:             break;
999:           default:
1000: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
1001:             printf("Unknown KSP type %d\n", ksp_type_idx);
1002: #else
1003:             /* void */;
1004: #endif
1005:           }
1006:         });
1007:       Kokkos::fence();
1008: #if defined(PETSC_HAVE_CUDA)
1009:       nvtxRangePop();
1010:       nvtxRangePushA("Post-solve-metadata");
1011: #endif
1012:       auto h_metadata = Kokkos::create_mirror(Kokkos::HostSpace::memory_space(), d_metadata);
1013:       Kokkos::deep_copy(h_metadata, d_metadata);
1014: #if PCBJKOKKOS_VERBOSE_LEVEL >= 3
1015:   #if PCBJKOKKOS_VERBOSE_LEVEL >= 4
1016:       PetscPrintf(PETSC_COMM_WORLD, "Iterations\n");
1017:   #endif
1018:       // assume species major
1019:   #if PCBJKOKKOS_VERBOSE_LEVEL < 4
1020:       PetscPrintf(PetscObjectComm((PetscObject)A), "max iterations per species (%s) :", ksp_type_idx == BATCH_KSP_BICG_IDX ? "bicg" : "tfqmr");
1021:   #endif
1022:       for (PetscInt dmIdx = 0, s = 0, head = 0; dmIdx < jac->num_dms; dmIdx += batch_sz) {
1023:         for (PetscInt f = 0, idx = head; f < jac->dm_Nf[dmIdx]; f++, s++, idx++) {
1024:   #if PCBJKOKKOS_VERBOSE_LEVEL >= 4
1025:           PetscPrintf(PetscObjectComm((PetscObject)A), "%2" PetscInt_FMT ":", s);
1026:           for (int bid = 0; bid < batch_sz; bid++) PetscPrintf(PetscObjectComm((PetscObject)A), "%3" PetscInt_FMT " ", h_metadata[idx + bid * jac->dm_Nf[dmIdx]].its);
1027:           PetscPrintf(PetscObjectComm((PetscObject)A), "\n");
1028:   #else
1029:           PetscInt count = 0;
1030:           for (int bid = 0; bid < batch_sz; bid++) {
1031:             if (h_metadata[idx + bid * jac->dm_Nf[dmIdx]].its > count) count = h_metadata[idx + bid * jac->dm_Nf[dmIdx]].its;
1032:           }
1033:           PetscPrintf(PetscObjectComm((PetscObject)A), "%3" PetscInt_FMT " ", count);
1034:   #endif
1035:         }
1036:         head += batch_sz * jac->dm_Nf[dmIdx];
1037:       }
1038:   #if PCBJKOKKOS_VERBOSE_LEVEL == 3
1039:       PetscPrintf(PetscObjectComm((PetscObject)A), "\n");
1040:   #endif
1041: #endif
1042:       PetscInt count = 0, mbid = 0;
1043:       for (int blkID = 0; blkID < nBlk; blkID++) {
1044:         PetscLogGpuFlops((PetscLogDouble)h_metadata[blkID].flops);
1045:         if (jac->reason) { // -pc_bjkokkos_ksp_converged_reason
1046:           if (jac->batch_target == blkID) {
1047:             PetscPrintf(PetscObjectComm((PetscObject)A), "    Linear solve converged due to %s iterations %d, batch %" PetscInt_FMT ", species %" PetscInt_FMT "\n", KSPConvergedReasons[h_metadata[blkID].reason], h_metadata[blkID].its, blkID % batch_sz, blkID / batch_sz);
1048:           } else if (jac->batch_target == -1 && h_metadata[blkID].its > count) {
1049:             jac->max_nits = count = h_metadata[blkID].its;
1050:             mbid                  = blkID;
1051:           }
1052:           if (h_metadata[blkID].reason < 0) {
1053:             PetscPrintf(PETSC_COMM_SELF, "ERROR reason=%s, its=%" PetscInt_FMT ". species %" PetscInt_FMT ", batch %" PetscInt_FMT "\n", KSPConvergedReasons[h_metadata[blkID].reason], h_metadata[blkID].its, blkID / batch_sz, blkID % batch_sz);
1054:           }
1055:         }
1056:       }
1057:       if (jac->batch_target == -1 && jac->reason) {
1058:         PetscPrintf(PetscObjectComm((PetscObject)A), "    Linear solve converged due to %s iterations %d, batch %" PetscInt_FMT ", species %" PetscInt_FMT "\n", KSPConvergedReasons[h_metadata[mbid].reason], h_metadata[mbid].its, mbid % batch_sz, mbid / batch_sz);
1059:       }
1060:       {
1061:         int errsum;
1062:         Kokkos::parallel_reduce(
1063:           nBlk,
1064:           KOKKOS_LAMBDA(const int idx, int &lsum) {
1065:             if (d_metadata[idx].reason < 0) ++lsum;
1066:           },
1067:           errsum);
1068:         pcreason = errsum ? PC_SUBPC_ERROR : PC_NOERROR;
1069:         if (!errsum && !jac->max_nits) { // set max its to give back to top KSP
1070:           for (int blkID = 0; blkID < nBlk; blkID++) {
1071:             if (h_metadata[blkID].its > jac->max_nits) jac->max_nits = h_metadata[blkID].its;
1072:           }
1073:         } else if (errsum) {
1074:           PetscPrintf(PETSC_COMM_SELF, "ERROR Kokkos batch solver did not converge in all solves\n");
1075:         }
1076:       }
1077: #if defined(PETSC_HAVE_CUDA)
1078:       nvtxRangePop();
1079: #endif
1080:     } // end of Kokkos (not Kernels) solvers block
1081:     VecRestoreArrayAndMemType(xout, &glb_xdata);
1082:     VecRestoreArrayReadAndMemType(bvec, &glb_bdata);
1083:     PCSetFailedReason(pc, pcreason);
1084:     // map back to Plex space - not used
1085:     if (plex_batch) {
1086:       VecCopy(xout, bvec);
1087:       VecScatterBegin(plex_batch, bvec, xout, INSERT_VALUES, SCATTER_REVERSE);
1088:       VecScatterEnd(plex_batch, bvec, xout, INSERT_VALUES, SCATTER_REVERSE);
1089:       VecDestroy(&bvec);
1090:     }
1091:   } // whole 'have aijkok' block
1092:   return 0;
1093: }

1095: static PetscErrorCode PCSetUp_BJKOKKOS(PC pc)
1096: {
1097:   PC_PCBJKOKKOS    *jac = (PC_PCBJKOKKOS *)pc->data;
1098:   Mat               A   = pc->pmat;
1099:   Mat_SeqAIJKokkos *aijkok;
1100:   PetscBool         flg;

1104:   PetscObjectTypeCompareAny((PetscObject)A, &flg, MATSEQAIJKOKKOS, MATMPIAIJKOKKOS, MATAIJKOKKOS, "");
1106:   if (!(aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr))) {
1107:     SETERRQ(PetscObjectComm((PetscObject)A), PETSC_ERR_USER, "No aijkok");
1108:   } else {
1109:     if (!jac->vec_diag) {
1110:       Vec           *subX;
1111:       DM             pack, *subDM;
1112:       PetscInt       nDMs, n;
1113:       PetscContainer container;
1114:       PetscObjectQuery((PetscObject)A, "plex_batch_is", (PetscObject *)&container);
1115:       { // Permute the matrix to get a block diagonal system: d_isrow_k, d_isicol_k
1116:         MatOrderingType rtype;
1117:         IS              isrow, isicol;
1118:         const PetscInt *rowindices, *icolindices;
1119:         rtype = MATORDERINGRCM;
1120:         // get permutation; it is not in the expected form, so it is inverted below
1121:         MatGetOrdering(A, rtype, &isrow, &isicol);
1122:         ISDestroy(&isrow);
1123:         ISInvertPermutation(isicol, PETSC_DECIDE, &isrow); // THIS IS BACKWARD -- isrow is inverse -- FIX!!!!!

1125:         Mat mat_block_order;
1126:         MatCreateSubMatrix(A, isicol, isicol, MAT_INITIAL_MATRIX, &mat_block_order);
1127:         MatViewFromOptions(mat_block_order, NULL, "-ksp_batch_reorder_view");
1128:         MatDestroy(&mat_block_order);

1130:         ISGetIndices(isrow, &rowindices);
1131:         ISGetIndices(isicol, &icolindices);
1132:         const Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_isrow_k((PetscInt *)rowindices, A->rmap->n);
1133:         const Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_isicol_k((PetscInt *)icolindices, A->rmap->n);
1134:         jac->d_isrow_k  = new Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_isrow_k));
1135:         jac->d_isicol_k = new Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_isicol_k));
1136:         Kokkos::deep_copy(*jac->d_isrow_k, h_isrow_k);
1137:         Kokkos::deep_copy(*jac->d_isicol_k, h_isicol_k);
1138:         ISRestoreIndices(isrow, &rowindices);
1139:         ISRestoreIndices(isicol, &icolindices);
1140:         ISDestroy(&isrow);
1141:         ISDestroy(&isicol);
1142:       }
1143:       // get block sizes
1144:       PCGetDM(pc, &pack);
1146:       PetscObjectTypeCompare((PetscObject)pack, DMCOMPOSITE, &flg);
1148:       DMCompositeGetNumberDM(pack, &nDMs);
1149:       jac->num_dms = nDMs;
1150:       DMCreateGlobalVector(pack, &jac->vec_diag);
1151:       VecGetLocalSize(jac->vec_diag, &n);
1152:       jac->n         = n;
1153:       jac->d_idiag_k = new Kokkos::View<PetscScalar *, Kokkos::LayoutRight>("idiag", n);
1154:       // options
1155:       PCBJKOKKOSCreateKSP_BJKOKKOS(pc);
1156:       KSPSetFromOptions(jac->ksp);
1157:       PetscObjectTypeCompareAny((PetscObject)jac->ksp, &flg, KSPBICG, "");
1158:       if (flg) {
1159:         jac->ksp_type_idx = BATCH_KSP_BICG_IDX;
1160:         jac->nwork        = 7;
1161:       } else {
1162:         PetscObjectTypeCompareAny((PetscObject)jac->ksp, &flg, KSPTFQMR, "");
1163:         if (flg) {
1164:           jac->ksp_type_idx = BATCH_KSP_TFQMR_IDX;
1165:           jac->nwork        = 10;
1166:         } else {
1167:           PetscObjectTypeCompareAny((PetscObject)jac->ksp, &flg, KSPGMRES, "");
1168:           if (flg) {
1169:             jac->ksp_type_idx = BATCH_KSP_GMRES_IDX;
1170:             jac->nwork        = 0;
1171:           } else {
1172:             SETERRQ(PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONG, "Unsupported batch ksp type");
1173:           }
1174:         }
1175:       }
1176:       PetscOptionsBegin(PetscObjectComm((PetscObject)jac->ksp), ((PetscObject)jac->ksp)->prefix, "Options for Kokkos batch solver", "none");
1177:       PetscOptionsBool("-ksp_converged_reason", "", "bjkokkos.kokkos.cxx.c", jac->reason, &jac->reason, NULL);
1178:       PetscOptionsBool("-ksp_monitor", "", "bjkokkos.kokkos.cxx.c", jac->monitor, &jac->monitor, NULL);
1179:       PetscOptionsInt("-ksp_batch_target", "", "bjkokkos.kokkos.cxx.c", jac->batch_target, &jac->batch_target, NULL);
1180:       PetscOptionsInt("-ksp_batch_nsolves_team", "", "bjkokkos.kokkos.cxx.c", jac->nsolves_team, &jac->nsolves_team, NULL);
1182:       PetscOptionsEnd();
1183:       // get blocks - jac->d_bid_eqOffset_k
1184:       PetscMalloc(sizeof(*subX) * nDMs, &subX);
1185:       PetscMalloc(sizeof(*subDM) * nDMs, &subDM);
1186:       PetscMalloc(sizeof(*jac->dm_Nf) * nDMs, &jac->dm_Nf);
1187:       PetscInfo(pc, "Have %" PetscInt_FMT " DMs, n=%" PetscInt_FMT " rtol=%g type = %s\n", nDMs, n, (double)jac->ksp->rtol, ((PetscObject)jac->ksp)->type_name);
1188:       DMCompositeGetEntriesArray(pack, subDM);
1189:       jac->nBlocks = 0;
1190:       for (PetscInt ii = 0; ii < nDMs; ii++) {
1191:         PetscSection section;
1192:         PetscInt     Nf;
1193:         DM           dm = subDM[ii];
1194:         DMGetLocalSection(dm, &section);
1195:         PetscSectionGetNumFields(section, &Nf);
1196:         jac->nBlocks += Nf;
1197: #if PCBJKOKKOS_VERBOSE_LEVEL <= 2
1198:         if (ii == 0) PetscInfo(pc, "%" PetscInt_FMT ") %" PetscInt_FMT " blocks (%" PetscInt_FMT " total)\n", ii, Nf, jac->nBlocks);
1199: #else
1200:         PetscInfo(pc, "%" PetscInt_FMT ") %" PetscInt_FMT " blocks (%" PetscInt_FMT " total)\n", ii, Nf, jac->nBlocks);
1201: #endif
1202:         jac->dm_Nf[ii] = Nf;
1203:       }
1204:       { // d_bid_eqOffset_k
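              // build a host prefix-sum of block sizes: block b owns equations [offset[b], offset[b+1]); copied to the device below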
1205:         Kokkos::View<PetscInt *, Kokkos::LayoutRight, Kokkos::HostSpace> h_block_offsets("block_offsets", jac->nBlocks + 1);
1206:         DMCompositeGetAccessArray(pack, jac->vec_diag, nDMs, NULL, subX);
1207:         h_block_offsets[0]    = 0;
1208:         jac->const_block_size = -1;
1209:         for (PetscInt ii = 0, idx = 0; ii < nDMs; ii++) {
1210:           PetscInt nloc, nblk;
1211:           VecGetSize(subX[ii], &nloc);
1212:           nblk = nloc / jac->dm_Nf[ii];
1214:           for (PetscInt jj = 0; jj < jac->dm_Nf[ii]; jj++, idx++) {
1215:             h_block_offsets[idx + 1] = h_block_offsets[idx] + nblk;
1216: #if PCBJKOKKOS_VERBOSE_LEVEL <= 2
1217:             if (idx == 0) PetscInfo(pc, "\t%" PetscInt_FMT ") Add block with %" PetscInt_FMT " equations of %" PetscInt_FMT "\n", idx + 1, nblk, jac->nBlocks);
1218: #else
1219:             PetscInfo(pc, "\t%" PetscInt_FMT ") Add block with %" PetscInt_FMT " equations of %" PetscInt_FMT "\n", idx + 1, nblk, jac->nBlocks);
1220: #endif
1221:             if (jac->const_block_size == -1) jac->const_block_size = nblk;
1222:             else if (jac->const_block_size > 0 && jac->const_block_size != nblk) jac->const_block_size = 0;
1223:           }
1224:         }
1225:         DMCompositeRestoreAccessArray(pack, jac->vec_diag, nDMs, NULL, subX); // must match the nDMs used in DMCompositeGetAccessArray() above
1226:         PetscFree(subX);
1227:         PetscFree(subDM);
1228:         jac->d_bid_eqOffset_k = new Kokkos::View<PetscInt *, Kokkos::LayoutRight>(Kokkos::create_mirror(Kokkos::DefaultExecutionSpace::memory_space(), h_block_offsets));
1229:         Kokkos::deep_copy(*jac->d_bid_eqOffset_k, h_block_offsets);
1230:       }
1231:     }
1232:     { // get jac->d_idiag_k (PC setup),
1233:       const PetscInt    *d_ai = aijkok->i_device_data(), *d_aj = aijkok->j_device_data();
1234:       const PetscScalar *d_aa = aijkok->a_device_data();
1235:       const PetscInt     conc = Kokkos::DefaultExecutionSpace().concurrency(), openmp = !!(conc < 1000), team_size = (openmp == 0 && PCBJKOKKOS_VEC_SIZE != 1) ? PCBJKOKKOS_TEAM_SIZE : 1;
1236:       PetscInt          *d_bid_eqOffset = jac->d_bid_eqOffset_k->data(), *r = jac->d_isrow_k->data(), *ic = jac->d_isicol_k->data();
1237:       PetscScalar       *d_idiag = jac->d_idiag_k->data();
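            // one team per block: each (permuted) row scans its nonzeros for the diagonal entry and stores its reciprocal in d_idiag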
1238:       Kokkos::parallel_for(
1239:         "Diag", Kokkos::TeamPolicy<>(jac->nBlocks, team_size, PCBJKOKKOS_VEC_SIZE), KOKKOS_LAMBDA(const team_member team) {
1240:           const PetscInt blkID = team.league_rank();
1241:           Kokkos::parallel_for(Kokkos::TeamThreadRange(team, d_bid_eqOffset[blkID], d_bid_eqOffset[blkID + 1]), [=](const int rowb) {
1242:             const PetscInt     rowa = ic[rowb], ai = d_ai[rowa], *aj = d_aj + ai; // grab original data
1243:             const PetscScalar *aa   = d_aa + ai;
1244:             const PetscInt     nrow = d_ai[rowa + 1] - ai;
1245:             int                found;
1246:             Kokkos::parallel_reduce(
1247:               Kokkos::ThreadVectorRange(team, nrow),
1248:               [=](const int &j, int &count) {
1249:                 const PetscInt colb = r[aj[j]];
1250:                 if (colb == rowb) {
1251:                   d_idiag[rowb] = 1. / aa[j];
1252:                   count++;
1253:                 }
1254:               },
1255:               found);
1256: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
1257:             if (found != 1) Kokkos::single(Kokkos::PerThread(team), [=]() { printf("ERROR: row %d, found = %d\n", rowb, found); });
1258: #endif
1259:           });
1260:         });
1261:     }
1262:   }
1263:   return 0;
1264: }

1266: /* Reset: free everything created during setup (safe to call even if the PC was never set up) */
1267: static PetscErrorCode PCReset_BJKOKKOS(PC pc)
1268: {
1269:   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;

1271:   KSPDestroy(&jac->ksp);
1272:   VecDestroy(&jac->vec_diag);
1273:   if (jac->d_bid_eqOffset_k) delete jac->d_bid_eqOffset_k;
1274:   if (jac->d_idiag_k) delete jac->d_idiag_k;
1275:   if (jac->d_isrow_k) delete jac->d_isrow_k;
1276:   if (jac->d_isicol_k) delete jac->d_isicol_k;
1277:   jac->d_bid_eqOffset_k = NULL;
1278:   jac->d_idiag_k        = NULL;
1279:   jac->d_isrow_k        = NULL;
1280:   jac->d_isicol_k       = NULL;
1281:   PetscObjectComposeFunction((PetscObject)pc, "PCBJKOKKOSGetKSP_C", NULL); // not published now (causes configure errors)
1282:   PetscObjectComposeFunction((PetscObject)pc, "PCBJKOKKOSSetKSP_C", NULL);
1283:   PetscFree(jac->dm_Nf);
1284:   jac->dm_Nf = NULL;
1285:   if (jac->rowOffsets) delete jac->rowOffsets;
1286:   if (jac->colIndices) delete jac->colIndices;
1287:   if (jac->batch_b) delete jac->batch_b;
1288:   if (jac->batch_x) delete jac->batch_x;
1289:   if (jac->batch_values) delete jac->batch_values;
1290:   jac->rowOffsets   = NULL;
1291:   jac->colIndices   = NULL;
1292:   jac->batch_b      = NULL;
1293:   jac->batch_x      = NULL;
1294:   jac->batch_values = NULL;

1296:   return 0;
1297: }

1299: static PetscErrorCode PCDestroy_BJKOKKOS(PC pc)
1300: {
1301:   PCReset_BJKOKKOS(pc);
1302:   PetscFree(pc->data);
1303:   return 0;
1304: }

1306: static PetscErrorCode PCView_BJKOKKOS(PC pc, PetscViewer viewer)
1307: {
1308:   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;
1309:   PetscBool      iascii;

1311:   if (!jac->ksp) PCBJKOKKOSCreateKSP_BJKOKKOS(pc);
1312:   PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &iascii);
1313:   if (iascii) {
1314:     PetscViewerASCIIPrintf(viewer, "  Batched device linear solver: Krylov (KSP) method with Jacobi preconditioning\n");
1315:     PetscViewerASCIIPrintf(viewer, "\t\tnwork = %" PetscInt_FMT ", rel tol = %e, abs tol = %e, div tol = %e, max it = %" PetscInt_FMT ", type = %s\n", jac->nwork, (double)jac->ksp->rtol, (double)jac->ksp->abstol, (double)jac->ksp->divtol, jac->ksp->max_it,
1316:                            ((PetscObject)jac->ksp)->type_name);
1317:   }
1318:   return 0;
1319: }

1321: static PetscErrorCode PCSetFromOptions_BJKOKKOS(PC pc, PetscOptionItems *PetscOptionsObject)
1322: {
1323:   PetscOptionsHeadBegin(PetscOptionsObject, "PC BJKOKKOS options");
1324:   PetscOptionsHeadEnd();
1325:   return 0;
1326: }

1328: static PetscErrorCode PCBJKOKKOSSetKSP_BJKOKKOS(PC pc, KSP ksp)
1329: {
1330:   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;

1332:   PetscObjectReference((PetscObject)ksp);
1333:   KSPDestroy(&jac->ksp);
1334:   jac->ksp = ksp;
1335:   return 0;
1336: }

1338: /*@C
1339:    PCBJKOKKOSSetKSP - Sets the `KSP` context for `PCBJKOKKOS`

1341:    Collective

1343:    Input Parameters:
1344: +  pc - the `PCBJKOKKOS` preconditioner context
1345: -  ksp - the `KSP` solver

1347:    Notes:
1348:    The `PC` and the `KSP` must have the same communicator

1350:    If the `PC` is not `PCBJKOKKOS` this function returns without doing anything

1352:    Level: advanced

1354: .seealso: `PCBJKOKKOSGetKSP()`, `PCBJKOKKOS`
1355: @*/
1356: PetscErrorCode PCBJKOKKOSSetKSP(PC pc, KSP ksp)
1357: {
1361:   PetscTryMethod(pc, "PCBJKOKKOSSetKSP_C", (PC, KSP), (pc, ksp));
1362:   return 0;
1363: }
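/*
  Example (illustrative sketch, not part of this source): attaching a custom inner KSP to a
  PCBJKOKKOS preconditioner. The outer `ksp` and its MATAIJKOKKOS operator are assumed to
  exist already; the solver choice and tolerances below are arbitrary.

    PC  pc;
    KSP inner;
    KSPGetPC(ksp, &pc);
    PCSetType(pc, PCBJKOKKOS);
    KSPCreate(PetscObjectComm((PetscObject)pc), &inner); // same communicator as the PC
    KSPSetType(inner, KSPTFQMR);                         // one of the supported batch solvers
    KSPSetTolerances(inner, 1e-6, PETSC_DEFAULT, PETSC_DEFAULT, 50);
    PCBJKOKKOSSetKSP(pc, inner);                         // the PC takes its own reference
    KSPDestroy(&inner);
*/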

1365: static PetscErrorCode PCBJKOKKOSGetKSP_BJKOKKOS(PC pc, KSP *ksp)
1366: {
1367:   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;

1369:   if (!jac->ksp) PCBJKOKKOSCreateKSP_BJKOKKOS(pc);
1370:   *ksp = jac->ksp;
1371:   return 0;
1372: }

1374: /*@C
1375:    PCBJKOKKOSGetKSP - Gets the `KSP` context for the `PCBJKOKKOS` preconditioner

1377:    Not Collective but `KSP` returned is parallel if `PC` was parallel

1379:    Input Parameter:
1380: .  pc - the preconditioner context

1382:    Output Parameter:
1383: .  ksp - the `KSP` solver

1385:    Notes:
1386:    You must call `KSPSetUp()` before calling `PCBJKOKKOSGetKSP()`.

1388:    If the `PC` is not a `PCBJKOKKOS` object it raises an error

1390:    Level: advanced

1392: .seealso: `PCBJKOKKOS`, `PCBJKOKKOSSetKSP()`
1393: @*/
1394: PetscErrorCode PCBJKOKKOSGetKSP(PC pc, KSP *ksp)
1395: {
1398:   PetscUseMethod(pc, "PCBJKOKKOSGetKSP_C", (PC, KSP *), (pc, ksp));
1399:   return 0;
1400: }
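/*
  Example (illustrative sketch): retrieving the inner KSP to adjust it programmatically;
  `pc` is assumed to already be of type PCBJKOKKOS and set up.

    KSP inner;
    PCBJKOKKOSGetKSP(pc, &inner);
    KSPSetTolerances(inner, 1e-8, PETSC_DEFAULT, PETSC_DEFAULT, 100);
*/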

1402: static PetscErrorCode PCPostSolve_BJKOKKOS(PC pc, KSP ksp, Vec b, Vec x)
1403: {
1404:   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;

1406:   ksp->its = jac->max_nits;
1407:   return 0;
1408: }

1410: /*MC
1411:      PCBJKOKKOS -  Defines a preconditioner that applies a Krylov solver and preconditioner to the blocks in a `MATSEQAIJ` matrix on the GPU using Kokkos

1413:    Options Database Key:
1414: .     -pc_bjkokkos_ - options prefix for its `KSP` options

1416:    Level: intermediate

1418:    Note:
1419:     For use with -ksp_type preonly to bypass any computation on the CPU

1421:    Developer Notes:
1422:    The documentation is incomplete. Is this a block Jacobi preconditioner?

1424:    Why does it have its own `KSP`? Where is the `KSP` run if used with -ksp_type preonly?

1426: .seealso: `PCCreate()`, `PCSetType()`, `PCType`, `PC`, `PCBJACOBI`,
1427:           `PCSHELL`, `PCCOMPOSITE`, `PCSetUseAmat()`, `PCBJKOKKOSGetKSP()`
1428: M*/
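/*
  Example command line (illustrative; the executable name and the solver/tolerance choices
  are assumptions, not taken from this source):

    ./app -ksp_type preonly -pc_type bjkokkos \
          -pc_bjkokkos_ksp_type tfqmr -pc_bjkokkos_ksp_rtol 1e-8 \
          -pc_bjkokkos_ksp_converged_reason
*/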

1430: PETSC_EXTERN PetscErrorCode PCCreate_BJKOKKOS(PC pc)
1431: {
1432:   PC_PCBJKOKKOS *jac;

1434:   PetscNew(&jac);
1435:   pc->data = (void *)jac;

1437:   jac->ksp              = NULL;
1438:   jac->vec_diag         = NULL;
1439:   jac->d_bid_eqOffset_k = NULL;
1440:   jac->d_idiag_k        = NULL;
1441:   jac->d_isrow_k        = NULL;
1442:   jac->d_isicol_k       = NULL;
1443:   jac->nBlocks          = 1;
1444:   jac->max_nits         = 0;

1446:   PetscMemzero(pc->ops, sizeof(struct _PCOps));
1447:   pc->ops->apply          = PCApply_BJKOKKOS;
1448:   pc->ops->applytranspose = NULL;
1449:   pc->ops->setup          = PCSetUp_BJKOKKOS;
1450:   pc->ops->reset          = PCReset_BJKOKKOS;
1451:   pc->ops->destroy        = PCDestroy_BJKOKKOS;
1452:   pc->ops->setfromoptions = PCSetFromOptions_BJKOKKOS;
1453:   pc->ops->view           = PCView_BJKOKKOS;
1454:   pc->ops->postsolve      = PCPostSolve_BJKOKKOS;

1456:   jac->rowOffsets   = NULL;
1457:   jac->colIndices   = NULL;
1458:   jac->batch_b      = NULL;
1459:   jac->batch_x      = NULL;
1460:   jac->batch_values = NULL;

1462:   PetscObjectComposeFunction((PetscObject)pc, "PCBJKOKKOSGetKSP_C", PCBJKOKKOSGetKSP_BJKOKKOS);
1463:   PetscObjectComposeFunction((PetscObject)pc, "PCBJKOKKOSSetKSP_C", PCBJKOKKOSSetKSP_BJKOKKOS);
1464:   return 0;
1465: }