NAMD
AVXTilesKernel.h
Go to the documentation of this file.
1 #ifndef AVXTILESKERNEL_H
2 #define AVXTILESKERNEL_H
3 
4 #include "ComputeNonbondedUtil.h" // WMB: for KNL_TABLE_FACTOR
5 
6 #ifdef NAMD_AVXTILES
7 #include <immintrin.h>
8 
9 __forceinline void getOmpSimdTableI(const __m512 &r_1, __m512i &table_int,
10  __m512 &tableDiff, __m512 &rTableDiff) {
11  // table_r_1 = table_r_1 > KNL_TABLE_MAX_R_1 ? KNL_TABLE_MAX_R_1 : r_1
12  const __m512 maxv = _mm512_set1_ps(KNL_TABLE_MAX_R_1);
13  const __mmask16 tmask = _mm512_cmplt_ps_mask(maxv, r_1);
14  const __m512 table_r_1 = _mm512_mask_mov_ps(r_1, tmask, maxv);
15  // table_f = (KNL_TABLE_FACTOR-2) * knl_table_r_1;
16  const __m512 table_f = _mm512_mul_ps(_mm512_set1_ps(KNL_TABLE_FACTOR-2),
17  table_r_1);
18  // table_int = floor(table_f)
19  table_int = _mm512_cvttps_epi32(table_f);
20  // tableDiff = table_f - table_int
21  tableDiff = _mm512_sub_ps(table_f, _mm512_cvtepi32_ps(table_int));
22  // rTableDiff = 1. - tableDiff;
23  rTableDiff = _mm512_sub_ps(_mm512_set1_ps(1.f), tableDiff);
24 }
25 
26 // Full electrostatics portion
27 
28 template<bool doEnergy>
29 __forceinline void forceEnergySlow512(const __mmask16 r2mask,
30  const __m512 &kqq,
31  const float * __restrict__ slowTable,
32  const float * __restrict__ slowEtable,
33  const __m512i &table_int,
34  const __m512 &tableDiff,
35  const __m512 &rTableDiff,
36  __m512 &forceSlow, __m512 &energySlow) {
37  // Load interpolation values
38  const __m512 t0 = (__m512)_mm512_mask_i32logather_pd(_mm512_undefined_pd(),
39  r2mask, table_int, slowTable, _MM_SCALE_4);
40  const __m512i table_int2 = _mm512_shuffle_i32x4(table_int, table_int, 238);
41  const __mmask16 r2mask2 = r2mask >> 8;
42  const __m512 t1 = (__m512)_mm512_mask_i32logather_pd(_mm512_undefined_pd(),
43  r2mask2, table_int2, slowTable, _MM_SCALE_4);
44  const __m512i t4 = _mm512_set_epi32(31,29,27,25,23,21,19,17,
45  15,13,11,9,7,5,3,1);
46  const __m512 tabSlowP1 = _mm512_permutex2var_ps(t0, t4, t1);
47  const __m512i t6 = _mm512_set_epi32(30,28,26,24,22,20,18,16,
48  14,12,10,8,6,4,2,0);
49  const __m512 tabSlow = _mm512_permutex2var_ps(t0, t6, t1);
50 
51  // forceSlow = kqq * (-tabSlow * rTableDiff - tabSlowP1 * tableDiff)
52  forceSlow = _mm512_mul_ps(kqq, _mm512_fnmsub_ps(tabSlow, rTableDiff,
53  _mm512_mul_ps(tabSlowP1, tableDiff)));
54  if (doEnergy) {
55  // Load interpolation values for energy
56  const __m512 t10 = (__m512)_mm512_mask_i32logather_pd(
57  _mm512_undefined_pd(), r2mask, table_int, slowEtable, _MM_SCALE_4);
58  const __m512 t11 = (__m512)_mm512_mask_i32logather_pd(
59  _mm512_undefined_pd(), r2mask2, table_int2, slowEtable, _MM_SCALE_4);
60  const __m512i t4 = _mm512_set_epi32(31,29,27,25,23,21,19,17,
61  15,13,11,9,7,5,3,1);
62  const __m512 tabSlowEp1 = _mm512_permutex2var_ps(t10, t4, t11);
63  const __m512i t6 = _mm512_set_epi32(30,28,26,24,22,20,18,16,
64  14,12,10,8,6,4,2,0);
65  const __m512 tabSlowE = _mm512_permutex2var_ps(t10, t6, t11);
66 
67  // eSlow = tabSlowE * rTableDiff + tabSlowEp1 * tableDiff
68  const __m512 eSlow = _mm512_fmadd_ps(tabSlowE, rTableDiff,
69  _mm512_mul_ps(tabSlowEp1, tableDiff));
70  // energySlow = -kqq*eSlow + energySlow
71  energySlow = _mm512_mask_mov_ps(energySlow, r2mask,
72  _mm512_fnmadd_ps(kqq, eSlow, energySlow));
73  }
74 }
75 
76 // Interpolation mode 2 (c1 splitting) and 3 variants
77 
78 template<bool doEnergy, bool doSlow, int iMode>
79 __forceinline void forceEnergyInterp2(const __m512 &r2, const __m512 &kqq,
80  const __m512i &type_i,
81  const __m512i &type_j, __m512 &force,
82  __m512 &forceSlow, __m512 &energyVdw,
83  __m512 &energyElec, __m512 &energySlow,
84  const __mmask16 r2mask,
85  const float scaling, const float c1,
86  const float c3, const float switchOn2,
87  const float cutoff2,
88  const float mInvCut3,
89  const float cutUnder3,
90  const float * __restrict__ fastTable,
91  const float * __restrict__ energyTable,
92  const float * __restrict__ slowTable,
93  const float * __restrict__ slowEtable,
94  const float * __restrict__ ljTable,
95  const int ljWidth) {
96 
97  // lj_i = (type_i*ljWidth + type_j) * 2
98  const __m512i lj_i = _mm512_slli_epi32(_mm512_add_epi32(
99  _mm512_mullo_epi32(type_i,_mm512_set1_epi32(ljWidth)), type_j), 1);
100  // Load A and B values from ljTable
101  const __m512 t0 = (__m512)_mm512_mask_i32logather_pd(_mm512_undefined_pd(),
102  r2mask, lj_i, ljTable, _MM_SCALE_8);
103  const __m512i lj_i2 = _mm512_shuffle_i32x4(lj_i, lj_i, 238);
104  const __mmask16 r2mask2 = r2mask >> 8;
105  const __m512 t1 = (__m512)_mm512_mask_i32logather_pd(_mm512_undefined_pd(),
106  r2mask2, lj_i2, ljTable, _MM_SCALE_8);
107  const __m512i t4 = _mm512_set_epi32(31,29,27,25,23,21,19,17,
108  15,13,11,9,7,5,3,1);
109  const __m512 B = _mm512_permutex2var_ps(t0, t4, t1);
110  const __m512i t6 = _mm512_set_epi32(30,28,26,24,22,20,18,16,
111  14,12,10,8,6,4,2,0);
112  const __m512 A = _mm512_permutex2var_ps(t0, t6, t1);
113 
114  // r_1 = 1./sqrt(r2)
115  const __m512 r_1 = _mm512_invsqrt_ps(r2);
116  __m512 tableDiff, rTableDiff;
117  __m512i table_int;
118  if (iMode == 3 || doSlow)
119  getOmpSimdTableI(r_1, table_int, tableDiff, rTableDiff);
120 
121  // Get interpolation values for fast
122  __m512 tabFast, tabFastp1, tabEnergy, tabEnergyp1;
123  __m512 tabSlow, tabSlowp1, tabSlowE, tabSlowEp1;
124  if (iMode == 3) {
125  const __m512 t0 = (__m512)_mm512_mask_i32logather_pd(_mm512_undefined_pd(),
126  r2mask, table_int, fastTable, _MM_SCALE_4);
127  const __m512i table_int2 = _mm512_shuffle_i32x4(table_int, table_int, 238);
128  const __mmask16 r2mask2 = r2mask >> 8;
129  const __m512 t1 = (__m512)_mm512_mask_i32logather_pd(_mm512_undefined_pd(),
130  r2mask2, table_int2, fastTable, _MM_SCALE_4);
131  const __m512i t4 = _mm512_set_epi32(31,29,27,25,23,21,19,17,
132  15,13,11,9,7,5,3,1);
133  tabFastp1 = _mm512_permutex2var_ps(t0, t4, t1);
134  const __m512i t6 = _mm512_set_epi32(30,28,26,24,22,20,18,16,
135  14,12,10,8,6,4,2,0);
136  tabFast = _mm512_permutex2var_ps(t0, t6, t1);
137 
138  if (doEnergy) {
139  const __m512 t10 = (__m512)_mm512_mask_i32logather_pd(
140  _mm512_undefined_pd(), r2mask, table_int, energyTable, _MM_SCALE_4);
141  const __m512 t11 = (__m512)_mm512_mask_i32logather_pd(
142  _mm512_undefined_pd(), r2mask2, table_int2, energyTable, _MM_SCALE_4);
143  const __m512i t4 = _mm512_set_epi32(31,29,27,25,23,21,19,17,
144  15,13,11,9,7,5,3,1);
145  tabEnergyp1 = _mm512_permutex2var_ps(t10, t4, t11);
146  const __m512i t6 = _mm512_set_epi32(30,28,26,24,22,20,18,16,
147  14,12,10,8,6,4,2,0);
148  tabEnergy = _mm512_permutex2var_ps(t10, t6, t11);
149  }
150  }
151 
152  // r_2 = r_1 * r_1
153  const __m512 r_2 = _mm512_mul_ps(r_1, r_1);
154  // r_6 = r_2 * r_2 * r_2
155  const __m512 r_6 = _mm512_mul_ps(r_2, _mm512_mul_ps(r_2, r_2));
156  // r_12 = r_6 * r_6;
157  const __m512 r_12 = _mm512_mul_ps(r_6, r_6);
158  // c2 = cutoff2 - r2
159  const __m512 c2 = _mm512_sub_ps(_mm512_set1_ps(cutoff2), r2);
160  // c4 = (-2. * c2 + c3) * c2
161  const __m512 c4 = _mm512_mul_ps(_mm512_fnmadd_ps(_mm512_set1_ps(2.f), c2,
162  _mm512_set1_ps(c3)), c2);
163  // switchVal = r2 > switchOn2 ? c2*c4*c1 : 1.
164  const __mmask16 switchMask = _mm512_cmplt_ps_mask(_mm512_set1_ps(switchOn2),
165  r2);
166  const __m512 switchVal = _mm512_mask_mov_ps(_mm512_set1_ps(1.f), switchMask,
167  _mm512_mul_ps(c2, _mm512_mul_ps(c4, _mm512_set1_ps(c1))));
168  // dSwitchVal = r2 > switchOn2 ? 2.*c1*(c2*c2 - c4) : 0.
169  const __m512 dSwitchVal = _mm512_mask_mov_ps(_mm512_setzero_ps(),
170  switchMask, _mm512_mul_ps(_mm512_set1_ps(2.f),
171  _mm512_mul_ps(_mm512_set1_ps(c1), _mm512_fmsub_ps(c2,c2,c4))));
172  // r2SwitchVal = switchVal * r_2
173  const __m512 r2SwitchVal = _mm512_mul_ps(switchVal, r_2);
174  // vdwAgradient = (-6.*r2SwitchVal + dSwitchVal) * r_12
175  const __m512 vdwAgradient = _mm512_mul_ps(_mm512_fnmadd_ps(
176  _mm512_set1_ps(6.f), r2SwitchVal, dSwitchVal), r_12);
177  // vdwBgradient = (-3. * r2SwitchVal + dSwitchVal) * r_6
178  const __m512 vdwBgradient = _mm512_mul_ps(_mm512_fnmadd_ps(
179  _mm512_set1_ps(3.f), r2SwitchVal, dSwitchVal), r_6);
180  // vdwB = scaling * (A*vdwAgradient - B*vdwBgradient)
181  const __m512 vdwB = _mm512_mul_ps(_mm512_set1_ps(scaling),
182  _mm512_fmsub_ps(A,vdwAgradient, _mm512_mul_ps(B, vdwBgradient)));
183 
184  __m512 ffast;
185  // ffast = kqq * (r_2*r_1+mInvCut3);
186  if (iMode == 2) ffast = _mm512_mul_ps(kqq, _mm512_fmadd_ps(r_2,r_1,
187  _mm512_set1_ps(mInvCut3)));
188  // ffast = kqq * (tabFast*rTableDiff+tabFastp1*tableDiff)
189  else ffast = _mm512_mul_ps(kqq, _mm512_fmadd_ps(tabFast, rTableDiff,
190  _mm512_mul_ps(tabFastp1, tableDiff)));
191 
192  if (doEnergy) {
193  __m512 efast;
194  if (iMode == 2) {
195  // efast = r2*mInvCut3 + cutUnder3
196  efast = _mm512_fmadd_ps(r2,_mm512_set1_ps(mInvCut3),
197  _mm512_set1_ps(cutUnder3));
198  // efast = efast*0.5f - r_1
199  efast = _mm512_fmsub_ps(efast,_mm512_set1_ps(0.5f), r_1);
200  // efast = tabEnergy*rTableDiff + tabEnergyp1*tableDiff
201  } else
202  efast = _mm512_fmadd_ps(tabEnergy, rTableDiff,
203  _mm512_mul_ps(tabEnergyp1, tableDiff));
204 
205  // vdwTerm = A*r_12 - B*r_6
206  const __m512 vdwTerm = _mm512_fmsub_ps(A, r_12, _mm512_mul_ps(B, r_6));
207  // energyVdw = switchVal * vdwTerm + energyVdw
208  energyVdw = _mm512_mask_mov_ps(energyVdw, r2mask,
209  _mm512_fmadd_ps(switchVal, vdwTerm, energyVdw));
210  // energyElec = -kqq*efast + energyElec
211  energyElec = _mm512_mask_mov_ps(energyElec, r2mask,
212  _mm512_fnmadd_ps(kqq, efast, energyElec));
213  }
214 
215  // force = vdwB - ffast
216  force = _mm512_sub_ps(vdwB, ffast);
217  if (doSlow)
218  forceEnergySlow512<doEnergy>(r2mask, kqq, slowTable, slowEtable, table_int,
219  tableDiff, rTableDiff, forceSlow, energySlow);
220 }
221 
222 // Interpolation mode 1 variant of "fast" calculations
223 // -- Supports arithmetic VDW mixing and C1 splitting
224 
225 template<bool doEnergy, bool doSlow>
226 __forceinline void forceEnergyInterp1(const __m512 &r2, const __m512 &kqq,
227  __m512 &force, __m512 &forceSlow,
228  __m512 &energyVdw, __m512 &energyElec,
229  __m512 &energySlow,
230  const __mmask16 r2mask, const float c1,
231  const float c3, const float switchOn2,
232  const float cutoff2,
233  const float mInvCut3,
234  const float cutUnder3,
235  const float * __restrict__ slowTable,
236  const float * __restrict__ slowEtable,
237  const __m512 &eps4i, const __m512 &eps4j,
238  const __m512 &sigmaI,
239  const __m512 &sigmaJ) {
240 
241  // eps_ij = sqrt(eps4i * eps4j)
242  const __m512 eps_ij = _mm512_sqrt_ps(_mm512_mul_ps(eps4i, eps4j));
243  // sigma_ij = 0.5f * (sigmaI + sigmaJ)
244  __m512 sigma_ij = _mm512_mul_ps(_mm512_set1_ps(0.5f),
245  _mm512_add_ps(sigmaI, sigmaJ));
246  // sigma_ij *= sigma_ij * sigma_ij
247  sigma_ij = _mm512_mul_ps(sigma_ij, _mm512_mul_ps(sigma_ij, sigma_ij));
248  // sigma_ij *= sigma_ij
249  sigma_ij = _mm512_mul_ps(sigma_ij, sigma_ij);
250  // B = sigma_ij * eps_ij
251  const __m512 B(_mm512_mul_ps(sigma_ij, eps_ij));
252  // A = B * sigma_ij
253  const __m512 A(_mm512_mul_ps(B, sigma_ij));
254 
255  // r_1 = 1./sqrt(r2)
256  const __m512 r_1 = _mm512_invsqrt_ps(r2);
257  __m512 tableDiff, rTableDiff;;
258  __m512i table_int;
259  if (doSlow)
260  getOmpSimdTableI(r_1, table_int, tableDiff, rTableDiff);
261 
262  // r_2 = r_1 * r_1
263  const __m512 r_2 = _mm512_mul_ps(r_1, r_1);
264  // r_6 = r_2 * r_2 * r_2
265  const __m512 r_6 = _mm512_mul_ps(r_2, _mm512_mul_ps(r_2, r_2));
266  // r_12 = r_6 * r_6
267  const __m512 r_12 = _mm512_mul_ps(r_6, r_6);
268  // c2 = cutoff2 - r2
269  const __m512 c2 = _mm512_sub_ps(_mm512_set1_ps(cutoff2), r2);
270  // c4 = (-2. * c2 + c3) * c2
271  const __m512 c4 = _mm512_mul_ps(_mm512_fnmadd_ps(_mm512_set1_ps(2.f), c2,
272  _mm512_set1_ps(c3)), c2);
273  // switchVal = r2 > switchOn2 ? c2*c4*c1 : 1.
274  const __mmask16 switchMask = _mm512_cmplt_ps_mask(_mm512_set1_ps(switchOn2),
275  r2);
276  const __m512 switchVal = _mm512_mask_mov_ps(_mm512_set1_ps(1.f), switchMask,
277  _mm512_mul_ps(c2,_mm512_mul_ps(c4,_mm512_set1_ps(c1))));
278  // dSwitchVal = r2 > switchOn2 ? 2.*c1*(c2*c2 - c4) : 0.
279  const __m512 dSwitchVal = _mm512_mask_mov_ps(_mm512_setzero_ps(),
280  switchMask, _mm512_mul_ps(_mm512_set1_ps(2.f),
281  _mm512_mul_ps(_mm512_set1_ps(c1), _mm512_fmsub_ps(c2,c2,c4))));
282  // r2SwitchVal = switchVal * r_2
283  const __m512 r2SwitchVal = _mm512_mul_ps(switchVal, r_2);
284  // vdwAgradient = (-6.*r2SwitchVal + dSwitchVal) * r_12
285  const __m512 vdwAgradient = _mm512_mul_ps(_mm512_fnmadd_ps(
286  _mm512_set1_ps(6.f), r2SwitchVal, dSwitchVal), r_12);
287  // vdwBgradient = (-3. * r2SwitchVal + dSwitchVal) * r_6
288  const __m512 vdwBgradient = _mm512_mul_ps(_mm512_fnmadd_ps(
289  _mm512_set1_ps(3.f), r2SwitchVal, dSwitchVal), r_6);
290  // vdwB = scaling * (A*vdwAgradient - B*vdwBgradient)
291  const __m512 vdwB = _mm512_mul_ps(_mm512_set1_ps(2.f),
292  _mm512_fmsub_ps(A, vdwAgradient, _mm512_mul_ps(B, vdwBgradient)));
293  // ffast = kqq * (r_2*r_1+mInvCut3);
294  const __m512 ffast = _mm512_mul_ps(kqq, _mm512_fmadd_ps(r_2, r_1,
295  _mm512_set1_ps(mInvCut3)));
296 
297  if (doEnergy) {
298  // efast = r2*mInvCut3 + cutUnder3
299  __m512 efast = _mm512_fmadd_ps(r2,_mm512_set1_ps(mInvCut3),
300  _mm512_set1_ps(cutUnder3));
301  // efast = efast*0.5f - r_1
302  efast = _mm512_fmsub_ps(efast, _mm512_set1_ps(0.5f), r_1);
303  // vdwTerm = A*r_12 - B*r_6
304  const __m512 vdwTerm = _mm512_fmsub_ps(A, r_12, _mm512_mul_ps(B, r_6));
305  // energyVdw = switchVal * vdwTerm + energyVdw
306  energyVdw = _mm512_mask_mov_ps(energyVdw, r2mask,
307  _mm512_fmadd_ps(switchVal, vdwTerm, energyVdw));
308  // energyElec = -kqq*efast + energyElec
309  energyElec = _mm512_mask_mov_ps(energyElec, r2mask,
310  _mm512_fnmadd_ps(kqq, efast, energyElec));
311  }
312 
313  // force = vdwB - ffast
314  force = _mm512_sub_ps(vdwB, ffast);
315  if (doSlow)
316  forceEnergySlow512<doEnergy>(r2mask, kqq, slowTable, slowEtable, table_int,
317  tableDiff, rTableDiff, forceSlow, energySlow);
318 }
319 
320 #endif // NAMD_AVXTILES
321 #endif // AVXTILESKERNEL_H
const BigReal A
__global__ void const int const TileList *__restrict__ TileExcl *__restrict__ const int *__restrict__ const int const float2 *__restrict__ cudaTextureObject_t const int *__restrict__ const float3 const float3 const float3 const float4 *__restrict__ const float cutoff2
const BigReal B