Actual source code: baijfact81.c
/*
  Factorization code for BAIJ format.
*/
#include <../src/mat/impls/baij/seq/baij.h>
#include <petsc/private/kernels/blockinvert.h>
#if defined(PETSC_HAVE_IMMINTRIN_H) && defined(__AVX2__) && defined(__FMA__) && defined(PETSC_USE_REAL_DOUBLE) && !defined(PETSC_USE_COMPLEX) && !defined(PETSC_USE_64BIT_INDICES)
#include <immintrin.h>
#endif

/*
  Version for when blocks are 9 by 9
*/
#if defined(PETSC_HAVE_IMMINTRIN_H) && defined(__AVX2__) && defined(__FMA__) && defined(PETSC_USE_REAL_DOUBLE) && !defined(PETSC_USE_COMPLEX) && !defined(PETSC_USE_64BIT_INDICES)
PetscErrorCode MatLUFactorNumeric_SeqBAIJ_9_NaturalOrdering(Mat B, Mat A, const MatFactorInfo *info)
{
  Mat             C = B;
  Mat_SeqBAIJ    *a = (Mat_SeqBAIJ *)A->data, *b = (Mat_SeqBAIJ *)C->data;
  PetscInt        i, j, k, nz, nzL, row;
  const PetscInt  n = a->mbs, *ai = a->i, *aj = a->j, *bi = b->i, *bj = b->j;
  const PetscInt *ajtmp, *bjtmp, *bdiag = b->diag, *pj, bs2 = a->bs2;
  MatScalar      *rtmp, *pc, *mwork, *v, *pv, *aa = a->a;
  PetscInt        flg;
  PetscReal       shift = info->shiftamount;
  PetscBool       allowzeropivot, zeropivotdetected;

  allowzeropivot = PetscNot(A->erroriffailure);

  /* generate work space needed by the factorization */
  PetscMalloc2(bs2 * n, &rtmp, bs2, &mwork);
  PetscArrayzero(rtmp, bs2 * n);
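  /* rtmp holds the current block row in scattered form, one bs2 = 81 entry block
     per block column; mwork is scratch space for the dense 9x9 kernels */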

  for (i = 0; i < n; i++) {
    /* zero rtmp */
    /* L part */
    nz    = bi[i + 1] - bi[i];
    bjtmp = bj + bi[i];
    for (j = 0; j < nz; j++) PetscArrayzero(rtmp + bs2 * bjtmp[j], bs2);

    /* U part */
    nz    = bdiag[i] - bdiag[i + 1];
    bjtmp = bj + bdiag[i + 1] + 1;
    for (j = 0; j < nz; j++) PetscArrayzero(rtmp + bs2 * bjtmp[j], bs2);

    /* load in initial (unfactored) row */
    nz    = ai[i + 1] - ai[i];
    ajtmp = aj + ai[i];
    v     = aa + bs2 * ai[i];
    for (j = 0; j < nz; j++) PetscArraycpy(rtmp + bs2 * ajtmp[j], v + bs2 * j, bs2);

    /* elimination */
    bjtmp = bj + bi[i];
    nzL   = bi[i + 1] - bi[i];
    for (k = 0; k < nzL; k++) {
      row = bjtmp[k];
      pc  = rtmp + bs2 * row;
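      /* scan the 9x9 block; if it is entirely zero there is nothing to eliminate
         for this block column */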
      for (flg = 0, j = 0; j < bs2; j++) {
        if (pc[j] != 0.0) {
          flg = 1;
          break;
        }
      }
      if (flg) {
        pv = b->a + bs2 * bdiag[row];
        /* PetscKernel_A_gets_A_times_B(bs,pc,pv,mwork); *pc = *pc * (*pv); */
        PetscKernel_A_gets_A_times_B_9(pc, pv, mwork);
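        /* pc now holds the multiplier L(i,row) = A(i,row)*inv(U(row,row)); the
           inverted diagonal block was stored at bdiag[row] when that row was
           factored */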

        pj = b->j + bdiag[row + 1] + 1; /* beginning of U(row,:) */
        pv = b->a + bs2 * (bdiag[row + 1] + 1);
        nz = bdiag[row] - bdiag[row + 1] - 1; /* number of entries in U(row,:), excluding the diagonal */
        for (j = 0; j < nz; j++) {
          /* PetscKernel_A_gets_A_minus_B_times_C(bs,rtmp+bs2*pj[j],pc,pv+bs2*j); */
          /* rtmp+bs2*pj[j] = rtmp+bs2*pj[j] - (*pc)*(pv+bs2*j) */
          v = rtmp + bs2 * pj[j];
          PetscKernel_A_gets_A_minus_B_times_C_9(v, pc, pv + 81 * j);
          /* the offset 81*j steps pv through the blocks of U(row,:) */
        }
        PetscLogFlops(1458 * nz + 1377); /* flops = 2*bs^3*nz + 2*bs^3 - bs^2 */
      }
    }

    /* finished row so stick it into b->a */
    /* L part */
    pv = b->a + bs2 * bi[i];
    pj = b->j + bi[i];
    nz = bi[i + 1] - bi[i];
    for (j = 0; j < nz; j++) PetscArraycpy(pv + bs2 * j, rtmp + bs2 * pj[j], bs2);

    /* Mark diagonal and invert diagonal for simpler triangular solves */
    pv = b->a + bs2 * bdiag[i];
    pj = b->j + bdiag[i];
    PetscArraycpy(pv, rtmp + bs2 * pj[0], bs2);
    PetscKernel_A_gets_inverse_A_9(pv, shift, allowzeropivot, &zeropivotdetected);
    if (zeropivotdetected) C->factorerrortype = MAT_FACTOR_NUMERIC_ZEROPIVOT;
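    /* the diagonal block is inverted in place, shifted by info->shiftamount if
       requested; when zero pivots are allowed the failure is only recorded in
       C->factorerrortype rather than raising an error */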

    /* U part */
    pv = b->a + bs2 * (bdiag[i + 1] + 1);
    pj = b->j + bdiag[i + 1] + 1;
    nz = bdiag[i] - bdiag[i + 1] - 1;
    for (j = 0; j < nz; j++) PetscArraycpy(pv + bs2 * j, rtmp + bs2 * pj[j], bs2);
  }
  PetscFree2(rtmp, mwork);

  C->ops->solve          = MatSolve_SeqBAIJ_9_NaturalOrdering;
  C->ops->solvetranspose = MatSolveTranspose_SeqBAIJ_N;
  C->assembled           = PETSC_TRUE;

  PetscLogFlops(1.333333333333 * 9 * 9 * 9 * n); /* from inverting diagonal blocks */
  return 0;
}
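
/*
  A minimal usage sketch (illustrative, not part of this file): this kernel is
  reached through the standard factorization path once a SeqBAIJ matrix with
  block size 9 is LU-factored with natural ordering, e.g.

     Mat           A, F;
     IS            rowperm, colperm;
     MatFactorInfo info;

     MatCreateSeqBAIJ(PETSC_COMM_SELF, 9, m, m, 0, nnz, &A);
     // ... assemble A ...
     MatGetOrdering(A, MATORDERINGNATURAL, &rowperm, &colperm);
     MatFactorInfoInitialize(&info);
     MatGetFactor(A, MATSOLVERPETSC, MAT_FACTOR_LU, &F);
     MatLUFactorSymbolic(F, A, rowperm, colperm, &info);
     MatLUFactorNumeric(F, A, &info); // dispatches to the routine above for bs = 9

  All routines named here are standard PETSc API; the exact dispatch conditions
  live in the BAIJ symbolic factorization code.
*/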

PetscErrorCode MatSolve_SeqBAIJ_9_NaturalOrdering(Mat A, Vec bb, Vec xx)
{
  Mat_SeqBAIJ       *a = (Mat_SeqBAIJ *)A->data;
  const PetscInt    *ai = a->i, *aj = a->j, *adiag = a->diag, *vi;
  PetscInt           i, k, n = a->mbs;
  PetscInt           nz, bs = A->rmap->bs, bs2 = a->bs2;
  const MatScalar   *aa = a->a, *v;
  PetscScalar       *x, *s, *t, *ls;
  const PetscScalar *b;
  __m256d            a0, a1, a2, a3, a4, a5, w0, w1, w2, w3, s0, s1, s2, v0, v1, v2, v3;
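  /* register usage: a0-a5 hold 4-wide slices of the current 9x9 block; in the
     triangular sweeps w0-w3 broadcast entries of already-computed blocks of t
     while s0/s1/s2 accumulate the current block (only lane 0 of s2 matters);
     in the final diagonal multiply v0-v3 hold the broadcasts and w0-w2 the result */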

  VecGetArrayRead(bb, &b);
  VecGetArray(xx, &x);
  t = a->solve_work;

  /* forward solve the lower triangular */
  PetscArraycpy(t, b, bs); /* copy 1st block of b to t */

  for (i = 1; i < n; i++) {
    v  = aa + bs2 * ai[i];
    vi = aj + ai[i];
    nz = ai[i + 1] - ai[i];
    s  = t + bs * i;
    PetscArraycpy(s, b + bs * i, bs); /* copy i_th block of b to t */

    s0 = _mm256_loadu_pd(s + 0);
    s1 = _mm256_loadu_pd(s + 4);
    s2 = _mm256_maskload_pd(s + 8, _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL << 63));
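    /* a 9-vector needs two full 4-wide loads plus one masked load: the mask sets
       the sign bit only in element 0, so just s[8] is read and the upper three
       lanes of s2 are unused */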

    for (k = 0; k < nz; k++) {
      w0 = _mm256_set1_pd((t + bs * vi[k])[0]);
      a0 = _mm256_loadu_pd(&v[0]);
      s0 = _mm256_fnmadd_pd(a0, w0, s0);
      a1 = _mm256_loadu_pd(&v[4]);
      s1 = _mm256_fnmadd_pd(a1, w0, s1);
      a2 = _mm256_loadu_pd(&v[8]);
      s2 = _mm256_fnmadd_pd(a2, w0, s2);

      w1 = _mm256_set1_pd((t + bs * vi[k])[1]);
      a3 = _mm256_loadu_pd(&v[9]);
      s0 = _mm256_fnmadd_pd(a3, w1, s0);
      a4 = _mm256_loadu_pd(&v[13]);
      s1 = _mm256_fnmadd_pd(a4, w1, s1);
      a5 = _mm256_loadu_pd(&v[17]);
      s2 = _mm256_fnmadd_pd(a5, w1, s2);

      w2 = _mm256_set1_pd((t + bs * vi[k])[2]);
      a0 = _mm256_loadu_pd(&v[18]);
      s0 = _mm256_fnmadd_pd(a0, w2, s0);
      a1 = _mm256_loadu_pd(&v[22]);
      s1 = _mm256_fnmadd_pd(a1, w2, s1);
      a2 = _mm256_loadu_pd(&v[26]);
      s2 = _mm256_fnmadd_pd(a2, w2, s2);

      w3 = _mm256_set1_pd((t + bs * vi[k])[3]);
      a3 = _mm256_loadu_pd(&v[27]);
      s0 = _mm256_fnmadd_pd(a3, w3, s0);
      a4 = _mm256_loadu_pd(&v[31]);
      s1 = _mm256_fnmadd_pd(a4, w3, s1);
      a5 = _mm256_loadu_pd(&v[35]);
      s2 = _mm256_fnmadd_pd(a5, w3, s2);

      w0 = _mm256_set1_pd((t + bs * vi[k])[4]);
      a0 = _mm256_loadu_pd(&v[36]);
      s0 = _mm256_fnmadd_pd(a0, w0, s0);
      a1 = _mm256_loadu_pd(&v[40]);
      s1 = _mm256_fnmadd_pd(a1, w0, s1);
      a2 = _mm256_loadu_pd(&v[44]);
      s2 = _mm256_fnmadd_pd(a2, w0, s2);

      w1 = _mm256_set1_pd((t + bs * vi[k])[5]);
      a3 = _mm256_loadu_pd(&v[45]);
      s0 = _mm256_fnmadd_pd(a3, w1, s0);
      a4 = _mm256_loadu_pd(&v[49]);
      s1 = _mm256_fnmadd_pd(a4, w1, s1);
      a5 = _mm256_loadu_pd(&v[53]);
      s2 = _mm256_fnmadd_pd(a5, w1, s2);

      w2 = _mm256_set1_pd((t + bs * vi[k])[6]);
      a0 = _mm256_loadu_pd(&v[54]);
      s0 = _mm256_fnmadd_pd(a0, w2, s0);
      a1 = _mm256_loadu_pd(&v[58]);
      s1 = _mm256_fnmadd_pd(a1, w2, s1);
      a2 = _mm256_loadu_pd(&v[62]);
      s2 = _mm256_fnmadd_pd(a2, w2, s2);

      w3 = _mm256_set1_pd((t + bs * vi[k])[7]);
      a3 = _mm256_loadu_pd(&v[63]);
      s0 = _mm256_fnmadd_pd(a3, w3, s0);
      a4 = _mm256_loadu_pd(&v[67]);
      s1 = _mm256_fnmadd_pd(a4, w3, s1);
      a5 = _mm256_loadu_pd(&v[71]);
      s2 = _mm256_fnmadd_pd(a5, w3, s2);

      w0 = _mm256_set1_pd((t + bs * vi[k])[8]);
      a0 = _mm256_loadu_pd(&v[72]);
      s0 = _mm256_fnmadd_pd(a0, w0, s0);
      a1 = _mm256_loadu_pd(&v[76]);
      s1 = _mm256_fnmadd_pd(a1, w0, s1);
      a2 = _mm256_maskload_pd(v + 80, _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL << 63));
      s2 = _mm256_fnmadd_pd(a2, w0, s2);
      v += bs2;
    }
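    /* write the block of t back; the masked store keeps only lane 0 of s2,
       discarding whatever accumulated in its unused upper lanes */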
    _mm256_storeu_pd(&s[0], s0);
    _mm256_storeu_pd(&s[4], s1);
    _mm256_maskstore_pd(&s[8], _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL << 63), s2);
  }

  /* backward solve the upper triangular */
  ls = a->solve_work + A->cmap->n;
  for (i = n - 1; i >= 0; i--) {
    v  = aa + bs2 * (adiag[i + 1] + 1);
    vi = aj + adiag[i + 1] + 1;
    nz = adiag[i] - adiag[i + 1] - 1;
    PetscArraycpy(ls, t + i * bs, bs);
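    /* accumulate ls = t_i - sum_j U(i,j)*x_j over the off-diagonal blocks of U,
       then multiply by the inverted diagonal block to obtain x_i */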

    s0 = _mm256_loadu_pd(ls + 0);
    s1 = _mm256_loadu_pd(ls + 4);
    s2 = _mm256_maskload_pd(ls + 8, _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL << 63));

    for (k = 0; k < nz; k++) {
      w0 = _mm256_set1_pd((t + bs * vi[k])[0]);
      a0 = _mm256_loadu_pd(&v[0]);
      s0 = _mm256_fnmadd_pd(a0, w0, s0);
      a1 = _mm256_loadu_pd(&v[4]);
      s1 = _mm256_fnmadd_pd(a1, w0, s1);
      a2 = _mm256_loadu_pd(&v[8]);
      s2 = _mm256_fnmadd_pd(a2, w0, s2);

      /* v += 9; */
      w1 = _mm256_set1_pd((t + bs * vi[k])[1]);
      a3 = _mm256_loadu_pd(&v[9]);
      s0 = _mm256_fnmadd_pd(a3, w1, s0);
      a4 = _mm256_loadu_pd(&v[13]);
      s1 = _mm256_fnmadd_pd(a4, w1, s1);
      a5 = _mm256_loadu_pd(&v[17]);
      s2 = _mm256_fnmadd_pd(a5, w1, s2);

      /* v += 9; */
      w2 = _mm256_set1_pd((t + bs * vi[k])[2]);
      a0 = _mm256_loadu_pd(&v[18]);
      s0 = _mm256_fnmadd_pd(a0, w2, s0);
      a1 = _mm256_loadu_pd(&v[22]);
      s1 = _mm256_fnmadd_pd(a1, w2, s1);
      a2 = _mm256_loadu_pd(&v[26]);
      s2 = _mm256_fnmadd_pd(a2, w2, s2);

      /* v += 9; */
      w3 = _mm256_set1_pd((t + bs * vi[k])[3]);
      a3 = _mm256_loadu_pd(&v[27]);
      s0 = _mm256_fnmadd_pd(a3, w3, s0);
      a4 = _mm256_loadu_pd(&v[31]);
      s1 = _mm256_fnmadd_pd(a4, w3, s1);
      a5 = _mm256_loadu_pd(&v[35]);
      s2 = _mm256_fnmadd_pd(a5, w3, s2);

      /* v += 9; */
      w0 = _mm256_set1_pd((t + bs * vi[k])[4]);
      a0 = _mm256_loadu_pd(&v[36]);
      s0 = _mm256_fnmadd_pd(a0, w0, s0);
      a1 = _mm256_loadu_pd(&v[40]);
      s1 = _mm256_fnmadd_pd(a1, w0, s1);
      a2 = _mm256_loadu_pd(&v[44]);
      s2 = _mm256_fnmadd_pd(a2, w0, s2);

      /* v += 9; */
      w1 = _mm256_set1_pd((t + bs * vi[k])[5]);
      a3 = _mm256_loadu_pd(&v[45]);
      s0 = _mm256_fnmadd_pd(a3, w1, s0);
      a4 = _mm256_loadu_pd(&v[49]);
      s1 = _mm256_fnmadd_pd(a4, w1, s1);
      a5 = _mm256_loadu_pd(&v[53]);
      s2 = _mm256_fnmadd_pd(a5, w1, s2);

      /* v += 9; */
      w2 = _mm256_set1_pd((t + bs * vi[k])[6]);
      a0 = _mm256_loadu_pd(&v[54]);
      s0 = _mm256_fnmadd_pd(a0, w2, s0);
      a1 = _mm256_loadu_pd(&v[58]);
      s1 = _mm256_fnmadd_pd(a1, w2, s1);
      a2 = _mm256_loadu_pd(&v[62]);
      s2 = _mm256_fnmadd_pd(a2, w2, s2);

      /* v += 9; */
      w3 = _mm256_set1_pd((t + bs * vi[k])[7]);
      a3 = _mm256_loadu_pd(&v[63]);
      s0 = _mm256_fnmadd_pd(a3, w3, s0);
      a4 = _mm256_loadu_pd(&v[67]);
      s1 = _mm256_fnmadd_pd(a4, w3, s1);
      a5 = _mm256_loadu_pd(&v[71]);
      s2 = _mm256_fnmadd_pd(a5, w3, s2);

      /* v += 9; */
      w0 = _mm256_set1_pd((t + bs * vi[k])[8]);
      a0 = _mm256_loadu_pd(&v[72]);
      s0 = _mm256_fnmadd_pd(a0, w0, s0);
      a1 = _mm256_loadu_pd(&v[76]);
      s1 = _mm256_fnmadd_pd(a1, w0, s1);
      a2 = _mm256_maskload_pd(v + 80, _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL << 63));
      s2 = _mm256_fnmadd_pd(a2, w0, s2);
      v += bs2;
    }

    _mm256_storeu_pd(&ls[0], s0);
    _mm256_storeu_pd(&ls[4], s1);
    _mm256_maskstore_pd(&ls[8], _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL << 63), s2);
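
    /* ls now holds the right-hand side of the diagonal solve; multiply it by the
       inverted 9x9 diagonal block stored at adiag[i], one broadcast entry of ls
       at a time */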
    w0 = _mm256_setzero_pd();
    w1 = _mm256_setzero_pd();
    w2 = _mm256_setzero_pd();

    /* first row */
    v0 = _mm256_set1_pd(ls[0]);
    a0 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[0]);
    w0 = _mm256_fmadd_pd(a0, v0, w0);
    a1 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[4]);
    w1 = _mm256_fmadd_pd(a1, v0, w1);
    a2 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[8]);
    w2 = _mm256_fmadd_pd(a2, v0, w2);

    /* second row */
    v1 = _mm256_set1_pd(ls[1]);
    a3 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[9]);
    w0 = _mm256_fmadd_pd(a3, v1, w0);
    a4 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[13]);
    w1 = _mm256_fmadd_pd(a4, v1, w1);
    a5 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[17]);
    w2 = _mm256_fmadd_pd(a5, v1, w2);

    /* third row */
    v2 = _mm256_set1_pd(ls[2]);
    a0 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[18]);
    w0 = _mm256_fmadd_pd(a0, v2, w0);
    a1 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[22]);
    w1 = _mm256_fmadd_pd(a1, v2, w1);
    a2 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[26]);
    w2 = _mm256_fmadd_pd(a2, v2, w2);

    /* fourth row */
    v3 = _mm256_set1_pd(ls[3]);
    a3 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[27]);
    w0 = _mm256_fmadd_pd(a3, v3, w0);
    a4 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[31]);
    w1 = _mm256_fmadd_pd(a4, v3, w1);
    a5 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[35]);
    w2 = _mm256_fmadd_pd(a5, v3, w2);

    /* fifth row */
    v0 = _mm256_set1_pd(ls[4]);
    a0 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[36]);
    w0 = _mm256_fmadd_pd(a0, v0, w0);
    a1 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[40]);
    w1 = _mm256_fmadd_pd(a1, v0, w1);
    a2 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[44]);
    w2 = _mm256_fmadd_pd(a2, v0, w2);

    /* sixth row */
    v1 = _mm256_set1_pd(ls[5]);
    a3 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[45]);
    w0 = _mm256_fmadd_pd(a3, v1, w0);
    a4 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[49]);
    w1 = _mm256_fmadd_pd(a4, v1, w1);
    a5 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[53]);
    w2 = _mm256_fmadd_pd(a5, v1, w2);

    /* seventh row */
    v2 = _mm256_set1_pd(ls[6]);
    a0 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[54]);
    w0 = _mm256_fmadd_pd(a0, v2, w0);
    a1 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[58]);
    w1 = _mm256_fmadd_pd(a1, v2, w1);
    a2 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[62]);
    w2 = _mm256_fmadd_pd(a2, v2, w2);

    /* eighth row */
    v3 = _mm256_set1_pd(ls[7]);
    a3 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[63]);
    w0 = _mm256_fmadd_pd(a3, v3, w0);
    a4 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[67]);
    w1 = _mm256_fmadd_pd(a4, v3, w1);
    a5 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[71]);
    w2 = _mm256_fmadd_pd(a5, v3, w2);

    /* ninth row */
    v0 = _mm256_set1_pd(ls[8]);
    a3 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[72]);
    w0 = _mm256_fmadd_pd(a3, v0, w0);
    a4 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[76]);
    w1 = _mm256_fmadd_pd(a4, v0, w1);
    a2 = _mm256_maskload_pd((&(aa + bs2 * adiag[i])[80]), _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL << 63));
    w2 = _mm256_fmadd_pd(a2, v0, w2);
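
    /* w0/w1/w2 now hold x_i = inv(D_i)*ls; store the block into t so the
       remaining rows of the backward sweep can reference it, then copy it to x */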
    _mm256_storeu_pd(&(t + i * bs)[0], w0);
    _mm256_storeu_pd(&(t + i * bs)[4], w1);
    _mm256_maskstore_pd(&(t + i * bs)[8], _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL << 63), w2);

    PetscArraycpy(x + i * bs, t + i * bs, bs);
  }

  VecRestoreArrayRead(bb, &b);
  VecRestoreArray(xx, &x);
  PetscLogFlops(2.0 * (a->bs2) * (a->nz) - A->rmap->bs * A->cmap->n);
  return 0;
}
#endif