NAMD
AVXTilesKernel.h
Go to the documentation of this file.
1 #ifndef AVXTILESKERNEL_H
2 #define AVXTILESKERNEL_H
3 
4 #include "ComputeNonbondedUtil.h" // WMB: for KNL_TABLE_FACTOR
5 #include "AVXTileLists.h"
6 
7 #ifdef NAMD_AVXTILES
8 #include <immintrin.h>
9 
10 
11 FORCEINLINE void getOmpSimdTableI(const __m512 &r_1, __m512i &table_int,
12  __m512 &tableDiff, __m512 &rTableDiff) {
13  // table_r_1 = table_r_1 > KNL_TABLE_MAX_R_1 ? KNL_TABLE_MAX_R_1 : r_1
14  const __m512 maxv = _mm512_set1_ps(KNL_TABLE_MAX_R_1);
15  const __mmask16 tmask = _mm512_cmplt_ps_mask(maxv, r_1);
16  const __m512 table_r_1 = _mm512_mask_mov_ps(r_1, tmask, maxv);
17  // table_f = (KNL_TABLE_FACTOR-2) * knl_table_r_1;
18  const __m512 table_f = _mm512_mul_ps(_mm512_set1_ps(KNL_TABLE_FACTOR-2),
19  table_r_1);
20  // table_int = floor(table_f)
21  table_int = _mm512_cvttps_epi32(table_f);
22  // tableDiff = table_f - table_int
23  tableDiff = _mm512_sub_ps(table_f, _mm512_cvtepi32_ps(table_int));
24  // rTableDiff = 1. - tableDiff;
25  rTableDiff = _mm512_sub_ps(_mm512_set1_ps(1.f), tableDiff);
26 }
27 
28 // Full electrostatics portion
29 
30 template<bool doEnergy>
31 FORCEINLINE void forceEnergySlow512(const __mmask16 r2mask,
32  const __m512 &kqq,
33  const float * __restrict__ slowTable,
34  const float * __restrict__ slowEtable,
35  const __m512i &table_int,
36  const __m512 &tableDiff,
37  const __m512 &rTableDiff,
38  __m512 &forceSlow, __m512 &energySlow) {
39  // Load interpolation values
40  const __m512 t0 = (__m512)_mm512_mask_i32logather_pd(_mm512_undefined_pd(),
41  r2mask, table_int, slowTable, _MM_SCALE_4);
42  const __m512i table_int2 = _mm512_shuffle_i32x4(table_int, table_int, 238);
43  const __mmask16 r2mask2 = r2mask >> 8;
44  const __m512 t1 = (__m512)_mm512_mask_i32logather_pd(_mm512_undefined_pd(),
45  r2mask2, table_int2, slowTable, _MM_SCALE_4);
46  const __m512i t4 = _mm512_set_epi32(31,29,27,25,23,21,19,17,
47  15,13,11,9,7,5,3,1);
48  const __m512 tabSlowP1 = _mm512_permutex2var_ps(t0, t4, t1);
49  const __m512i t6 = _mm512_set_epi32(30,28,26,24,22,20,18,16,
50  14,12,10,8,6,4,2,0);
51  const __m512 tabSlow = _mm512_permutex2var_ps(t0, t6, t1);
52 
53  // forceSlow = kqq * (-tabSlow * rTableDiff - tabSlowP1 * tableDiff)
54  forceSlow = _mm512_mul_ps(kqq, _mm512_fnmsub_ps(tabSlow, rTableDiff,
55  _mm512_mul_ps(tabSlowP1, tableDiff)));
56  if (doEnergy) {
57  // Load interpolation values for energy
58  const __m512 t10 = (__m512)_mm512_mask_i32logather_pd(
59  _mm512_undefined_pd(), r2mask, table_int, slowEtable, _MM_SCALE_4);
60  const __m512 t11 = (__m512)_mm512_mask_i32logather_pd(
61  _mm512_undefined_pd(), r2mask2, table_int2, slowEtable, _MM_SCALE_4);
62  const __m512i t4 = _mm512_set_epi32(31,29,27,25,23,21,19,17,
63  15,13,11,9,7,5,3,1);
64  const __m512 tabSlowEp1 = _mm512_permutex2var_ps(t10, t4, t11);
65  const __m512i t6 = _mm512_set_epi32(30,28,26,24,22,20,18,16,
66  14,12,10,8,6,4,2,0);
67  const __m512 tabSlowE = _mm512_permutex2var_ps(t10, t6, t11);
68 
69  // eSlow = tabSlowE * rTableDiff + tabSlowEp1 * tableDiff
70  const __m512 eSlow = _mm512_fmadd_ps(tabSlowE, rTableDiff,
71  _mm512_mul_ps(tabSlowEp1, tableDiff));
72  // energySlow = -kqq*eSlow + energySlow
73  energySlow = _mm512_mask_mov_ps(energySlow, r2mask,
74  _mm512_fnmadd_ps(kqq, eSlow, energySlow));
75  }
76 }
77 
78 // Interpolation mode 2 (c1 splitting) and 3 variants
79 
80 template<bool doEnergy, bool doSlow, int iMode>
81 FORCEINLINE void forceEnergyInterp2(const __m512 &r2, const __m512 &kqq,
82  const __m512i &type_i,
83  const __m512i &type_j, __m512 &force,
84  __m512 &forceSlow, __m512 &energyVdw,
85  __m512 &energyElec, __m512 &energySlow,
86  const __mmask16 r2mask,
87  const float scaling, const float c1,
88  const float c3, const float switchOn2,
89  const float cutoff2,
90  const float mInvCut3,
91  const float cutUnder3,
92  const float * __restrict__ fastTable,
93  const float * __restrict__ energyTable,
94  const float * __restrict__ slowTable,
95  const float * __restrict__ slowEtable,
96  const float * __restrict__ ljTable,
97  const int ljWidth) {
98 
99  // lj_i = (type_i*ljWidth + type_j) * 2
100  const __m512i lj_i = _mm512_slli_epi32(_mm512_add_epi32(
101  _mm512_mullo_epi32(type_i,_mm512_set1_epi32(ljWidth)), type_j), 1);
102  // Load A and B values from ljTable
103  const __m512 t0 = (__m512)_mm512_mask_i32logather_pd(_mm512_undefined_pd(),
104  r2mask, lj_i, ljTable, _MM_SCALE_8);
105  const __m512i lj_i2 = _mm512_shuffle_i32x4(lj_i, lj_i, 238);
106  const __mmask16 r2mask2 = r2mask >> 8;
107  const __m512 t1 = (__m512)_mm512_mask_i32logather_pd(_mm512_undefined_pd(),
108  r2mask2, lj_i2, ljTable, _MM_SCALE_8);
109  const __m512i t4 = _mm512_set_epi32(31,29,27,25,23,21,19,17,
110  15,13,11,9,7,5,3,1);
111  const __m512 B = _mm512_permutex2var_ps(t0, t4, t1);
112  const __m512i t6 = _mm512_set_epi32(30,28,26,24,22,20,18,16,
113  14,12,10,8,6,4,2,0);
114  const __m512 A = _mm512_permutex2var_ps(t0, t6, t1);
115 
116  // r_1 = 1./sqrt(r2)
117  const __m512 r_1 = _mm512_invsqrt_ps(r2);
118  __m512 tableDiff, rTableDiff;
119  __m512i table_int;
120  if (iMode == 3 || doSlow)
121  getOmpSimdTableI(r_1, table_int, tableDiff, rTableDiff);
122 
123  // Get interpolation values for fast
124  __m512 tabFast, tabFastp1, tabEnergy, tabEnergyp1;
125  __m512 tabSlow, tabSlowp1, tabSlowE, tabSlowEp1;
126  if (iMode == 3) {
127  const __m512 t0 = (__m512)_mm512_mask_i32logather_pd(_mm512_undefined_pd(),
128  r2mask, table_int, fastTable, _MM_SCALE_4);
129  const __m512i table_int2 = _mm512_shuffle_i32x4(table_int, table_int, 238);
130  const __mmask16 r2mask2 = r2mask >> 8;
131  const __m512 t1 = (__m512)_mm512_mask_i32logather_pd(_mm512_undefined_pd(),
132  r2mask2, table_int2, fastTable, _MM_SCALE_4);
133  const __m512i t4 = _mm512_set_epi32(31,29,27,25,23,21,19,17,
134  15,13,11,9,7,5,3,1);
135  tabFastp1 = _mm512_permutex2var_ps(t0, t4, t1);
136  const __m512i t6 = _mm512_set_epi32(30,28,26,24,22,20,18,16,
137  14,12,10,8,6,4,2,0);
138  tabFast = _mm512_permutex2var_ps(t0, t6, t1);
139 
140  if (doEnergy) {
141  const __m512 t10 = (__m512)_mm512_mask_i32logather_pd(
142  _mm512_undefined_pd(), r2mask, table_int, energyTable, _MM_SCALE_4);
143  const __m512 t11 = (__m512)_mm512_mask_i32logather_pd(
144  _mm512_undefined_pd(), r2mask2, table_int2, energyTable, _MM_SCALE_4);
145  const __m512i t4 = _mm512_set_epi32(31,29,27,25,23,21,19,17,
146  15,13,11,9,7,5,3,1);
147  tabEnergyp1 = _mm512_permutex2var_ps(t10, t4, t11);
148  const __m512i t6 = _mm512_set_epi32(30,28,26,24,22,20,18,16,
149  14,12,10,8,6,4,2,0);
150  tabEnergy = _mm512_permutex2var_ps(t10, t6, t11);
151  }
152  }
153 
154  // r_2 = r_1 * r_1
155  const __m512 r_2 = _mm512_mul_ps(r_1, r_1);
156  // r_6 = r_2 * r_2 * r_2
157  const __m512 r_6 = _mm512_mul_ps(r_2, _mm512_mul_ps(r_2, r_2));
158  // r_12 = r_6 * r_6;
159  const __m512 r_12 = _mm512_mul_ps(r_6, r_6);
160  // c2 = cutoff2 - r2
161  const __m512 c2 = _mm512_sub_ps(_mm512_set1_ps(cutoff2), r2);
162  // c4 = (-2. * c2 + c3) * c2
163  const __m512 c4 = _mm512_mul_ps(_mm512_fnmadd_ps(_mm512_set1_ps(2.f), c2,
164  _mm512_set1_ps(c3)), c2);
165  // switchVal = r2 > switchOn2 ? c2*c4*c1 : 1.
166  const __mmask16 switchMask = _mm512_cmplt_ps_mask(_mm512_set1_ps(switchOn2),
167  r2);
168  const __m512 switchVal = _mm512_mask_mov_ps(_mm512_set1_ps(1.f), switchMask,
169  _mm512_mul_ps(c2, _mm512_mul_ps(c4, _mm512_set1_ps(c1))));
170  // dSwitchVal = r2 > switchOn2 ? 2.*c1*(c2*c2 - c4) : 0.
171  const __m512 dSwitchVal = _mm512_mask_mov_ps(_mm512_setzero_ps(),
172  switchMask, _mm512_mul_ps(_mm512_set1_ps(2.f),
173  _mm512_mul_ps(_mm512_set1_ps(c1), _mm512_fmsub_ps(c2,c2,c4))));
174  // r2SwitchVal = switchVal * r_2
175  const __m512 r2SwitchVal = _mm512_mul_ps(switchVal, r_2);
176  // vdwAgradient = (-6.*r2SwitchVal + dSwitchVal) * r_12
177  const __m512 vdwAgradient = _mm512_mul_ps(_mm512_fnmadd_ps(
178  _mm512_set1_ps(6.f), r2SwitchVal, dSwitchVal), r_12);
179  // vdwBgradient = (-3. * r2SwitchVal + dSwitchVal) * r_6
180  const __m512 vdwBgradient = _mm512_mul_ps(_mm512_fnmadd_ps(
181  _mm512_set1_ps(3.f), r2SwitchVal, dSwitchVal), r_6);
182  // vdwB = scaling * (A*vdwAgradient - B*vdwBgradient)
183  const __m512 vdwB = _mm512_mul_ps(_mm512_set1_ps(scaling),
184  _mm512_fmsub_ps(A,vdwAgradient, _mm512_mul_ps(B, vdwBgradient)));
185 
186  __m512 ffast;
187  // ffast = kqq * (r_2*r_1+mInvCut3);
188  if (iMode == 2) ffast = _mm512_mul_ps(kqq, _mm512_fmadd_ps(r_2,r_1,
189  _mm512_set1_ps(mInvCut3)));
190  // ffast = kqq * (tabFast*rTableDiff+tabFastp1*tableDiff)
191  else ffast = _mm512_mul_ps(kqq, _mm512_fmadd_ps(tabFast, rTableDiff,
192  _mm512_mul_ps(tabFastp1, tableDiff)));
193 
194  if (doEnergy) {
195  __m512 efast;
196  if (iMode == 2) {
197  // efast = r2*mInvCut3 + cutUnder3
198  efast = _mm512_fmadd_ps(r2,_mm512_set1_ps(mInvCut3),
199  _mm512_set1_ps(cutUnder3));
200  // efast = efast*0.5f - r_1
201  efast = _mm512_fmsub_ps(efast,_mm512_set1_ps(0.5f), r_1);
202  // efast = tabEnergy*rTableDiff + tabEnergyp1*tableDiff
203  } else
204  efast = _mm512_fmadd_ps(tabEnergy, rTableDiff,
205  _mm512_mul_ps(tabEnergyp1, tableDiff));
206 
207  // vdwTerm = A*r_12 - B*r_6
208  const __m512 vdwTerm = _mm512_fmsub_ps(A, r_12, _mm512_mul_ps(B, r_6));
209  // energyVdw = switchVal * vdwTerm + energyVdw
210  energyVdw = _mm512_mask_mov_ps(energyVdw, r2mask,
211  _mm512_fmadd_ps(switchVal, vdwTerm, energyVdw));
212  // energyElec = -kqq*efast + energyElec
213  energyElec = _mm512_mask_mov_ps(energyElec, r2mask,
214  _mm512_fnmadd_ps(kqq, efast, energyElec));
215  }
216 
217  // force = vdwB - ffast
218  force = _mm512_sub_ps(vdwB, ffast);
219  if (doSlow)
220  forceEnergySlow512<doEnergy>(r2mask, kqq, slowTable, slowEtable, table_int,
221  tableDiff, rTableDiff, forceSlow, energySlow);
222 }
223 
224 // Interpolation mode 1 variant of "fast" calculations
225 // -- Supports arithmetic VDW mixing and C1 splitting
226 
227 template<bool doEnergy, bool doSlow>
228 FORCEINLINE void forceEnergyInterp1(const __m512 &r2, const __m512 &kqq,
229  __m512 &force, __m512 &forceSlow,
230  __m512 &energyVdw, __m512 &energyElec,
231  __m512 &energySlow,
232  const __mmask16 r2mask, const float c1,
233  const float c3, const float switchOn2,
234  const float cutoff2,
235  const float mInvCut3,
236  const float cutUnder3,
237  const float * __restrict__ slowTable,
238  const float * __restrict__ slowEtable,
239  const __m512 &eps4i, const __m512 &eps4j,
240  const __m512 &sigmaI,
241  const __m512 &sigmaJ) {
242 
243  // eps_ij = sqrt(eps4i * eps4j)
244  const __m512 eps_ij = _mm512_sqrt_ps(_mm512_mul_ps(eps4i, eps4j));
245  // sigma_ij = 0.5f * (sigmaI + sigmaJ)
246  __m512 sigma_ij = _mm512_mul_ps(_mm512_set1_ps(0.5f),
247  _mm512_add_ps(sigmaI, sigmaJ));
248  // sigma_ij *= sigma_ij * sigma_ij
249  sigma_ij = _mm512_mul_ps(sigma_ij, _mm512_mul_ps(sigma_ij, sigma_ij));
250  // sigma_ij *= sigma_ij
251  sigma_ij = _mm512_mul_ps(sigma_ij, sigma_ij);
252  // B = sigma_ij * eps_ij
253  const __m512 B(_mm512_mul_ps(sigma_ij, eps_ij));
254  // A = B * sigma_ij
255  const __m512 A(_mm512_mul_ps(B, sigma_ij));
256 
257  // r_1 = 1./sqrt(r2)
258  const __m512 r_1 = _mm512_invsqrt_ps(r2);
259  __m512 tableDiff, rTableDiff;;
260  __m512i table_int;
261  if (doSlow)
262  getOmpSimdTableI(r_1, table_int, tableDiff, rTableDiff);
263 
264  // r_2 = r_1 * r_1
265  const __m512 r_2 = _mm512_mul_ps(r_1, r_1);
266  // r_6 = r_2 * r_2 * r_2
267  const __m512 r_6 = _mm512_mul_ps(r_2, _mm512_mul_ps(r_2, r_2));
268  // r_12 = r_6 * r_6
269  const __m512 r_12 = _mm512_mul_ps(r_6, r_6);
270  // c2 = cutoff2 - r2
271  const __m512 c2 = _mm512_sub_ps(_mm512_set1_ps(cutoff2), r2);
272  // c4 = (-2. * c2 + c3) * c2
273  const __m512 c4 = _mm512_mul_ps(_mm512_fnmadd_ps(_mm512_set1_ps(2.f), c2,
274  _mm512_set1_ps(c3)), c2);
275  // switchVal = r2 > switchOn2 ? c2*c4*c1 : 1.
276  const __mmask16 switchMask = _mm512_cmplt_ps_mask(_mm512_set1_ps(switchOn2),
277  r2);
278  const __m512 switchVal = _mm512_mask_mov_ps(_mm512_set1_ps(1.f), switchMask,
279  _mm512_mul_ps(c2,_mm512_mul_ps(c4,_mm512_set1_ps(c1))));
280  // dSwitchVal = r2 > switchOn2 ? 2.*c1*(c2*c2 - c4) : 0.
281  const __m512 dSwitchVal = _mm512_mask_mov_ps(_mm512_setzero_ps(),
282  switchMask, _mm512_mul_ps(_mm512_set1_ps(2.f),
283  _mm512_mul_ps(_mm512_set1_ps(c1), _mm512_fmsub_ps(c2,c2,c4))));
284  // r2SwitchVal = switchVal * r_2
285  const __m512 r2SwitchVal = _mm512_mul_ps(switchVal, r_2);
286  // vdwAgradient = (-6.*r2SwitchVal + dSwitchVal) * r_12
287  const __m512 vdwAgradient = _mm512_mul_ps(_mm512_fnmadd_ps(
288  _mm512_set1_ps(6.f), r2SwitchVal, dSwitchVal), r_12);
289  // vdwBgradient = (-3. * r2SwitchVal + dSwitchVal) * r_6
290  const __m512 vdwBgradient = _mm512_mul_ps(_mm512_fnmadd_ps(
291  _mm512_set1_ps(3.f), r2SwitchVal, dSwitchVal), r_6);
292  // vdwB = scaling * (A*vdwAgradient - B*vdwBgradient)
293  const __m512 vdwB = _mm512_mul_ps(_mm512_set1_ps(2.f),
294  _mm512_fmsub_ps(A, vdwAgradient, _mm512_mul_ps(B, vdwBgradient)));
295  // ffast = kqq * (r_2*r_1+mInvCut3);
296  const __m512 ffast = _mm512_mul_ps(kqq, _mm512_fmadd_ps(r_2, r_1,
297  _mm512_set1_ps(mInvCut3)));
298 
299  if (doEnergy) {
300  // efast = r2*mInvCut3 + cutUnder3
301  __m512 efast = _mm512_fmadd_ps(r2,_mm512_set1_ps(mInvCut3),
302  _mm512_set1_ps(cutUnder3));
303  // efast = efast*0.5f - r_1
304  efast = _mm512_fmsub_ps(efast, _mm512_set1_ps(0.5f), r_1);
305  // vdwTerm = A*r_12 - B*r_6
306  const __m512 vdwTerm = _mm512_fmsub_ps(A, r_12, _mm512_mul_ps(B, r_6));
307  // energyVdw = switchVal * vdwTerm + energyVdw
308  energyVdw = _mm512_mask_mov_ps(energyVdw, r2mask,
309  _mm512_fmadd_ps(switchVal, vdwTerm, energyVdw));
310  // energyElec = -kqq*efast + energyElec
311  energyElec = _mm512_mask_mov_ps(energyElec, r2mask,
312  _mm512_fnmadd_ps(kqq, efast, energyElec));
313  }
314 
315  // force = vdwB - ffast
316  force = _mm512_sub_ps(vdwB, ffast);
317  if (doSlow)
318  forceEnergySlow512<doEnergy>(r2mask, kqq, slowTable, slowEtable, table_int,
319  tableDiff, rTableDiff, forceSlow, energySlow);
320 }
321 
322 #endif // NAMD_AVXTILES
323 #endif // AVXTILESKERNEL_H