AVXTilesKernel.C
// -- WMB: For exclusions
#include "Node.h"
#include "Molecule.h"
// --

#include "AVXTilesKernel.h"

#ifdef NAMD_AVXTILES

#define MAX(A,B) ((A) > (B) ? (A) : (B))

#ifndef MEM_OPT_VERSION
const char * AVXTileLists::buildExclFlyList(const int itileList, const int z,
                                            const __m512i &atomIndex_i,
                                            const int n, void *molIn) {
  if (itileList == _lastFlyListTile) return _exclFlyLists[z];

  Molecule *mol = (Molecule *)molIn;

  if (!_exclFlyListBuffer) {
    _exclFlyListBuffer = new char[mol->numAtoms * 16];
    for (int i = 0; i < 16; i++)
      _exclFlyLists[i] = _exclFlyListBuffer + mol->numAtoms * i;
  }
  if (_lastFlyListTile == -1)
    memset( (void*) _exclFlyListBuffer, 0, mol->numAtoms * 16);

  for (int i = 0; i < 16; i++) {
    if (i >= n) break;
    char *exclFlyList = _exclFlyLists[i];
    const int32 *& fullExcl = _fullExcl[i];
    const int32 *& modExcl = _modExcl[i];

    if (_lastFlyListTile != -1) {
      int nl;
      nl = fullExcl[0] + 1;
      for (int l=1; l<nl; ++l ) exclFlyList[fullExcl[l]] = 0;
      nl = modExcl[0] + 1;
      for (int l=1; l<nl; ++l ) exclFlyList[modExcl[l]] = 0;
    }

    int nl;
    const int id = *((int*)&(atomIndex_i) + i);
    fullExcl = mol->get_full_exclusions_for_atom(id);
    nl = fullExcl[0] + 1;
    for (int l=1; l<nl; ++l )
      exclFlyList[fullExcl[l]] = EXCHCK_FULL;
    modExcl = mol->get_mod_exclusions_for_atom(id);
    nl = modExcl[0] + 1;
    for (int l=1; l<nl; ++l )
      exclFlyList[modExcl[l]] = EXCHCK_MOD;
  } // for i

  _lastFlyListTile = itileList;
  return _exclFlyLists[z];
}
#endif
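
// Note: buildExclFlyList() above lazily expands an atom's full and modified
// exclusion lists into one byte array per SIMD lane, so the vectorized list
// build can classify a neighbor with a single indexed byte load.  A scalar
// sketch of the lookup it enables (lane/index names illustrative only):
//
//   const char *flags = _exclFlyLists[lane];      // built on the fly above
//   switch (flags[atomIndex_j]) {
//     case EXCHCK_FULL: /* fully excluded pair   */ break;
//     case EXCHCK_MOD:  /* modified (1-4) pair   */ break;
//     default:          /* normal nonbonded pair */ break;
//   }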
#endif

//---------------------------------------------------------------------------
// Calculations for unmodified/unexcluded from tile lists
//---------------------------------------------------------------------------

template <bool doEnergy, bool doVirial, bool doSlow, bool doList,
          int interpMode>
void AVXTileLists::nbAVX512Tiles(__m512 &energyVdw, __m512 &energyElec,
                                 __m512 &energySlow, __m512 &fNet_x,
                                 __m512 &fNet_y, __m512 &fNet_z,
                                 __m512 &fNetSlow_x, __m512 &fNetSlow_y,
                                 __m512 &fNetSlow_z) {

  #ifndef MEM_OPT_VERSION
  _lastFlyListTile = -1;
  #endif

  Molecule* mol = Node::Object()->molecule;

  const AVXTiles::AVXTilesAtom* __restrict__ xyzq_i = tiles_p0->atoms;
  const AVXTiles::AVXTilesAtom* __restrict__ xyzq_j = tiles_p1->atoms;
  AVXTiles::AVXTilesForce* __restrict__ forces_i = tiles_p0->forces;
  AVXTiles::AVXTilesForce* __restrict__ forces_j = tiles_p1->forces;
  AVXTiles::AVXTilesForce* __restrict__ forcesSlow_i;
  AVXTiles::AVXTilesForce* __restrict__ forcesSlow_j;
  if (doSlow) {
    forcesSlow_i = tiles_p0->forcesSlow;
    forcesSlow_j = tiles_p1->forcesSlow;
  }

  const float* __restrict__ fastTable;
  const float* __restrict__ energyTable;
  const float * __restrict__ ljTable;
  const float * __restrict__ eps4sigma;
  // Interpolation for long-range splitting function
  if (interpMode == 3) {
    fastTable = _paramFastTable;
    energyTable = _paramFastEnergyTable;
  }
  // LJ mixing not performed within kernel
  if (interpMode > 1) ljTable = _paramLjTable;
  // LJ mixing performed within kernel
  if (interpMode == 1) eps4sigma = _paramEps4Sigma;

  const float* __restrict__ slowTable = _paramSlowTable;
  const float* __restrict__ slowEnergyTable = _paramSlowEnergyTable;
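
  // Note: as the comments above indicate, interpMode selects how pair
  // parameters are obtained -- mode 1 mixes epsilon/sigma per pair inside
  // the kernel from _paramEps4Sigma, modes 2 and 3 read a pre-mixed LJ
  // table, and mode 3 additionally interpolates the long-range splitting
  // function from the fast tables instead of computing it analytically.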

  #ifdef TILE_LIST_STAT_DEBUG
  int num_jtiles = 0;
  int num_jtiles_empty = 0;
  int num_rotates = 0;
  int num_rotates_empty = 0;
  int num_rotates_excl_empty = 0;
  #endif

  int numModified, numExcluded;
  if (doList) {
    numModified = 0;
    // WMB: Only needed for slow electrostatics
    numExcluded = 0;
  }

  const bool zeroShift = ! (_shx*_shx + _shy*_shy + _shz*_shz > 0) &&
    (tiles_p0 == tiles_p1);

  const int numTileLists = numLists();
  for (int itileList = 0; itileList < numTileLists; itileList++) {

    bool jTileActive;
    unsigned int itileListLen;

    const int atomStart_i = lists[itileList].atomStart_i;
    const int jtileStart = lists[itileList].jtileStart;
    const int jtileEnd = lists[itileList + 1].jtileStart;

    int atomSize_i, atomFreeSize_i, atomSize_j, atomFreeSize_j;
    int nFree_i;
    __mmask16 freeMask_i, sizeMask_i;
    bool doIpairList;
    float bbox_x, bbox_y, bbox_z, bbox_wx, bbox_wy, bbox_wz;
    __m512i atomIndex_i;
#ifdef MEM_OPT_VERSION
    __m512i atomExclIndex_i;
#endif
    __m512i exclIndex_i, exclMaxDiff_i;
    if (doList) {
      atomSize_i = tiles_p0->numAtoms();
      atomFreeSize_i = tiles_p0->numFreeAtoms();
      atomSize_j = tiles_p1->numAtoms();
      atomFreeSize_j = tiles_p1->numFreeAtoms();

      freeMask_i = 0xFFFF;
      nFree_i = MAX(atomFreeSize_i - atomStart_i, 0);
      if (nFree_i < 16)
        freeMask_i >>= 16 - nFree_i;
      sizeMask_i = 0xFFFF;
      const int check_i = atomSize_i - atomStart_i;
      if (check_i < 16) {
        sizeMask_i >>= 16 - check_i;
        if (NAMD_AVXTILES_PAIR_THRESHOLD > 0) {
          if (check_i <= NAMD_AVXTILES_PAIR_THRESHOLD) doIpairList = true;
          else doIpairList = false;
        }
      } else if (NAMD_AVXTILES_PAIR_THRESHOLD > 0)
        doIpairList = false;

      const int tileNum = atomStart_i >> 4;
      bbox_x = tiles_p0->bbox_x[tileNum] + _shx;
      bbox_y = tiles_p0->bbox_y[tileNum] + _shy;
      bbox_z = tiles_p0->bbox_z[tileNum] + _shz;
      bbox_wx = tiles_p0->bbox_wx[tileNum];
      bbox_wy = tiles_p0->bbox_wy[tileNum];
      bbox_wz = tiles_p0->bbox_wz[tileNum];

      itileListLen = 0;
      atomIndex_i = _mm512_loadu_epi32(tiles_p0->atomIndex + atomStart_i);
#ifdef MEM_OPT_VERSION
      atomExclIndex_i = _mm512_loadu_epi32(tiles_p0->atomExclIndex +
                                           atomStart_i);
#endif
      exclIndex_i = _mm512_loadu_epi32(tiles_p0->exclIndexStart + atomStart_i);
      exclMaxDiff_i = _mm512_loadu_epi32(tiles_p0->exclIndexMaxDiff +
                                         atomStart_i);
    }

    // Load i-atom data (and shift coordinates)
    const float *iptr = (float *)(xyzq_i + atomStart_i);
    const __m512 t0 = _mm512_loadu_ps(iptr);
    const __m512 t1 = _mm512_loadu_ps(iptr+16);
    const __m512 t2 = _mm512_loadu_ps(iptr+32);
    const __m512 t3 = _mm512_loadu_ps(iptr+48);
    const __m512i t4 = _mm512_set_epi32(29,25,21,17,13,9,5,1,
                                        28,24,20,16,12,8,4,0);
    const __m512 t5 = _mm512_permutex2var_ps(t0, t4, t1);
    const __m512i t6 = _mm512_set_epi32(28,24,20,16,12,8,4,0,
                                        29,25,21,17,13,9,5,1);
    const __m512 t7 = _mm512_permutex2var_ps(t2, t6, t3);
    const __mmask16 t9 = 0xFF00;
    const __m512i t5i = _mm512_castps_si512(t5);
    const __m512i t7i = _mm512_castps_si512(t7);
    const __m512 x_i = _mm512_add_ps(_mm512_castsi512_ps(
      _mm512_mask_blend_epi32(t9, t5i, t7i)),_mm512_set1_ps(_shx));
    const __m512 y_i = _mm512_add_ps(_mm512_shuffle_f32x4(t5, t7, 0x4E),
                                     _mm512_set1_ps(_shy));
    const __m512i t12 = _mm512_set_epi32(31,27,23,19,15,11,7,3,
                                         30,26,22,18,14,10,6,2);
    const __m512 t13 = _mm512_permutex2var_ps(t0, t12, t1);
    const __m512i t14 = _mm512_set_epi32(30,26,22,18,14,10,6,2,
                                         31,27,23,19,15,11,7,3);
    const __m512 t15 = _mm512_permutex2var_ps(t2, t14, t3);
    const __m512i t13i = _mm512_castps_si512(t13);
    const __m512i t15i = _mm512_castps_si512(t15);
    const __m512 z_i = _mm512_add_ps(_mm512_castsi512_ps(
      _mm512_mask_blend_epi32(t9, t13i, t15i)), _mm512_set1_ps(_shz));
    const __m512 q_i = _mm512_castsi512_ps(_mm512_shuffle_i32x4(t13i, t15i,
                                                                0x4E));
    const __m512i type_i = _mm512_loadu_epi32(tiles_p0->vdwTypes+atomStart_i);
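
    // Note: the permute/blend sequence above is a 16x4 AoS->SoA transpose of
    // 64 contiguous floats.  xyzq_i stores {x, y, z, q} per atom, so the
    // scalar equivalent of the loads above is simply (lane notation
    // illustrative):
    //
    //   for (int k = 0; k < 16; k++) {
    //     x_i[k] = xyzq_i[atomStart_i + k].x + _shx;
    //     y_i[k] = xyzq_i[atomStart_i + k].y + _shy;
    //     z_i[k] = xyzq_i[atomStart_i + k].z + _shz;
    //     q_i[k] = xyzq_i[atomStart_i + k].q;
    //   }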

    // WMB: This could be masked with sizeMask_i; currently only gathered
    // for doList
    __m512 eps4_i, sigma_i;
    if (interpMode == 1) {
      const __m512 t0 = (__m512)_mm512_i32logather_pd(type_i, eps4sigma,
                                                      _MM_SCALE_8);
      const __m512i type_i2 = _mm512_shuffle_i32x4(type_i, type_i, 238);
      const __m512 t1 = (__m512)_mm512_i32logather_pd(type_i2, eps4sigma,
                                                      _MM_SCALE_8);
      const __m512i t4 = _mm512_set_epi32(31,29,27,25,23,21,19,17,
                                          15,13,11,9,7,5,3,1);
      sigma_i = _mm512_permutex2var_ps(t0, t4, t1);
      const __m512i t6 = _mm512_set_epi32(30,28,26,24,22,20,18,16,
                                          14,12,10,8,6,4,2,0);
      eps4_i = _mm512_permutex2var_ps(t0, t6, t1);
    }
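
    // Note: _paramEps4Sigma stores {eps4, sigma} as adjacent floats per vdW
    // type (see the scalar indexing in nbAVX512Pairs() below), so one
    // 64-bit gather fetches both values for 8 lanes and the two permutes
    // de-interleave them; a scalar sketch:
    //
    //   eps4_i[k]  = eps4sigma[type_i[k] * 2];
    //   sigma_i[k] = eps4sigma[type_i[k] * 2 + 1];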

    // Zero i-tile force vectors
    __m512 forceSlow_i_x, forceSlow_i_y, forceSlow_i_z;
    __m512 force_i_x = _mm512_setzero_ps();
    __m512 force_i_y = _mm512_setzero_ps();
    __m512 force_i_z = _mm512_setzero_ps();
    if (doSlow) {
      forceSlow_i_x = _mm512_setzero_ps();
      forceSlow_i_y = _mm512_setzero_ps();
      forceSlow_i_z = _mm512_setzero_ps();
    }

    for (int jtile=jtileStart;jtile < jtileEnd;jtile++) {

      #ifdef TILE_LIST_STAT_DEBUG
      num_jtiles++;
      bool jtile_empty = true;
      #endif

      // Load j-atom starting index
      int atomStart_j = jTiles.atomStart[jtile];
      const bool self = zeroShift && (atomStart_i == atomStart_j);
      const int shiftVal = (self) ? 1 : 0;

      // Load j-atom positions / charges
      const float * jptr = (float *)(xyzq_j + atomStart_j);
      const __m512 t0 = _mm512_loadu_ps(jptr);
      const __m512 t1 = _mm512_loadu_ps(jptr+16);
      const __m512 t2 = _mm512_loadu_ps(jptr+32);
      const __m512 t3 = _mm512_loadu_ps(jptr+48);
      const __m512i t4 = _mm512_set_epi32(29,25,21,17,13,9,5,1,
                                          28,24,20,16,12,8,4,0);
      const __m512 t5 = _mm512_permutex2var_ps(t0, t4, t1);
      const __m512i t6 = _mm512_set_epi32(28,24,20,16,12,8,4,0,
                                          29,25,21,17,13,9,5,1);
      const __m512 t7 = _mm512_permutex2var_ps(t2, t6, t3);
      const __mmask16 t9 = 0xFF00;
      const __m512i t5i = _mm512_castps_si512(t5);
      const __m512i t7i = _mm512_castps_si512(t7);
      __m512 x_j = _mm512_castsi512_ps(_mm512_mask_blend_epi32(t9, t5i, t7i));
      __m512 y_j = _mm512_shuffle_f32x4(t5, t7, 0x4E);
      const __m512i t12 = _mm512_set_epi32(31,27,23,19,15,11,7,3,
                                           30,26,22,18,14,10,6,2);
      const __m512 t13 = _mm512_permutex2var_ps(t0, t12, t1);
      const __m512i t14 = _mm512_set_epi32(30,26,22,18,14,10,6,2,
                                           31,27,23,19,15,11,7,3);
      const __m512 t15 = _mm512_permutex2var_ps(t2, t14, t3);
      const __m512i t13i = _mm512_castps_si512(t13);
      const __m512i t15i = _mm512_castps_si512(t15);
      __m512 z_j = _mm512_castsi512_ps(_mm512_mask_blend_epi32(t9, t13i,
                                                               t15i));
      __m512 q_j = _mm512_castsi512_ps(_mm512_shuffle_i32x4(t13i, t15i, 0x4E));

      __m512i excl, atomIndex_j;
      __mmask16 freeMask_j, sizeMask_j;
      if (doList) {
        // Check for early bail from i-tile bounding box

        // dx = max(0.f, abs(bbox_x - x_j) - bbox_wx)
        const __m512 dx_one = _mm512_abs_ps(
          _mm512_sub_ps(_mm512_set1_ps(bbox_x), x_j));
        const __m512 dx_two = _mm512_set1_ps(bbox_wx);
        const __mmask16 lxmask = _mm512_cmplt_ps_mask(dx_two, dx_one);
        const __m512 dx = _mm512_mask_sub_ps(_mm512_setzero_ps(), lxmask,
                                             dx_one, dx_two);
        // dy = max(0.f, abs(bbox_y - y_j) - bbox_wy)
        const __m512 dy_one = _mm512_abs_ps(
          _mm512_sub_ps(_mm512_set1_ps(bbox_y), y_j));
        const __m512 dy_two = _mm512_set1_ps(bbox_wy);
        const __mmask16 lymask = _mm512_cmplt_ps_mask(dy_two, dy_one);
        const __m512 dy = _mm512_mask_sub_ps(_mm512_setzero_ps(), lymask,
                                             dy_one, dy_two);
        // dz = max(0.f, abs(bbox_z - z_j) - bbox_wz)
        const __m512 dz_one = _mm512_abs_ps(
          _mm512_sub_ps(_mm512_set1_ps(bbox_z), z_j));
        const __m512 dz_two = _mm512_set1_ps(bbox_wz);
        const __mmask16 lzmask = _mm512_cmplt_ps_mask(dz_two, dz_one);
        const __m512 dz = _mm512_mask_sub_ps(_mm512_setzero_ps(), lzmask,
                                             dz_one, dz_two);
        // r2bb = dx*dx + dy*dy + dz*dz
        const __m512 r2bb = _mm512_fmadd_ps(dx,dx,_mm512_fmadd_ps(dy,dy,
                                            _mm512_mul_ps(dz, dz)));

        // If no atoms within bounding box, skip this neighbor tile
        __mmask16 m = _mm512_cmple_ps_mask(r2bb, _mm512_set1_ps(_plcutoff2));
        if (!m) continue;
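
        // Note: the test above is the standard point-to-box distance --
        // per component d = max(0, |center - p| - half_width) -- computed
        // with a compare mask and masked subtract in place of a max(); a
        // j atom can only have neighbors in this i tile if its distance
        // from the tile's bounding box is within the list build cutoff.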

        // Load atom indices
        atomIndex_j = _mm512_loadu_epi32(tiles_p1->atomIndex+atomStart_j);
        // Zero exclusion data
        excl = _mm512_setzero_epi32();

        // Predication for number of j atoms and free atoms in tile
        freeMask_j = 0xFFFF;
        int nFree_j = MAX(atomFreeSize_j - atomStart_j, 0);
        if (nFree_j < 16)
          freeMask_j >>= 16 - nFree_j;
        const int check_j = atomSize_j - atomStart_j;
        sizeMask_j = 0xFFFF;
        if (check_j < 16)
          sizeMask_j >>= 16 - check_j;

        jTileActive = false;

        // If doing a pair list instead of tile list, make sure we have
        // room to store neighbor atom indices
        if (NAMD_AVXTILES_PAIR_THRESHOLD > 0 && doIpairList) {
          int maxPairs = _numPairs[0];
          for (int z = 1; z < NAMD_AVXTILES_PAIR_THRESHOLD; z++)
            if (_numPairs[z] > maxPairs) maxPairs = _numPairs[z];
          if (maxPairs + 16 > _maxPairs) {
            reallocPairLists(0, 1.4 * _maxPairs);
          }
        }
      } else
        // Load tile exclusion data
        excl = _mm512_loadu_epi32(jTiles.excl + (jtile << 4));

      // Load j tile LJ types and epsilon/sigma for interp mode 1
      __m512i type_j = _mm512_loadu_epi32(tiles_p1->vdwTypes+atomStart_j);

      __m512 eps4_j, sigma_j;
      if (interpMode == 1) {
        const __m512 t0 = (__m512)_mm512_i32logather_pd(type_j, eps4sigma,
                                                        _MM_SCALE_8);
        const __m512i type_j2 = _mm512_shuffle_i32x4(type_j, type_j, 238);
        const __m512 t1 = (__m512)_mm512_i32logather_pd(type_j2, eps4sigma,
                                                        _MM_SCALE_8);
        const __m512i t4 = _mm512_set_epi32(31,29,27,25,23,21,19,17,
                                            15,13,11,9,7,5,3,1);
        sigma_j = _mm512_permutex2var_ps(t0, t4, t1);
        const __m512i t6 = _mm512_set_epi32(30,28,26,24,22,20,18,16,
                                            14,12,10,8,6,4,2,0);
        eps4_j = _mm512_permutex2var_ps(t0, t6, t1);
      }

      // Zero force vectors for j tile
      __m512 force_j_x = _mm512_setzero_ps();
      __m512 force_j_y = _mm512_setzero_ps();
      __m512 force_j_z = _mm512_setzero_ps();

      __m512 forceSlow_j_x, forceSlow_j_y, forceSlow_j_z;
      if (doSlow) {
        forceSlow_j_x = _mm512_setzero_ps();
        forceSlow_j_y = _mm512_setzero_ps();
        forceSlow_j_z = _mm512_setzero_ps();
      }

      int t = (self) ? 1 : 0;

      // WMB: If adding GBIS phase 2, need to check the diagonal for excluded

      // Predication for self-interactions to prevent duplicate interactions
      __mmask16 mask_j;
      if (doList) mask_j = 0xFFFF;

      // If doList, excl is 0; if self, the first (diagonal) bit is skipped
      if (!doList && !self) excl = _mm512_slli_epi32(excl, 1);

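      // Note: each 32-bit lane of excl holds the interaction bits for one
      // i atom across the 16 rotations of the t loop below; bit t set means
      // that lane's pair survived the list build at rotation t.  The loop
      // tests bit 0 and shifts the word right once per rotation, and the
      // left shift above compensates for the extra shift taken on the first
      // iteration when the diagonal is not skipped (non-self tiles start at
      // t = 0 rather than t = 1).
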
      // Skip self interaction
      if (self) {
        x_j = (__m512)_mm512_alignr_epi32((__m512i)x_j, (__m512i)x_j, 1);
        y_j = (__m512)_mm512_alignr_epi32((__m512i)y_j, (__m512i)y_j, 1);
        z_j = (__m512)_mm512_alignr_epi32((__m512i)z_j, (__m512i)z_j, 1);
        q_j = (__m512)_mm512_alignr_epi32((__m512i)q_j, (__m512i)q_j, 1);
        if (doList) {
          freeMask_j = (freeMask_j>>1) | (freeMask_j<<15);
          sizeMask_j = (sizeMask_j>>1) | (sizeMask_j<<15);
          mask_j >>= shiftVal;
          atomIndex_j = _mm512_alignr_epi32(atomIndex_j, atomIndex_j, 1);
        }
        if (interpMode == 1) {
          eps4_j = (__m512)_mm512_alignr_epi32((__m512i)eps4_j,
                                               (__m512i)eps4_j, 1);
          sigma_j = (__m512)_mm512_alignr_epi32((__m512i)sigma_j,
                                                (__m512i)sigma_j, 1);
        } else
          type_j = _mm512_alignr_epi32(type_j, type_j, 1);
        // force is zero, no need to rotate
      }
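
      // Note: for a self tile the j registers are pre-rotated by one lane
      // and t starts at 1, so an atom is never paired with itself; shifting
      // mask_j right once per rotation then drops the second, symmetric
      // occurrence of each pair, so every i/j pair in the tile is visited
      // exactly once during the list build.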

      for (; t < 16; t++) {
        // Exclusion bit for next pair
        excl = _mm512_srli_epi32(excl, 1);
        __mmask16 freeMask_ij;
        if (doList)
          // Predication for partial tiles and fixed atoms at ends
          freeMask_ij = (freeMask_j | freeMask_i) & (sizeMask_j & sizeMask_i);
        else {
          #ifdef TILE_LIST_STAT_DEBUG
          const __mmask16 exclmask = _mm512_cmpneq_epi32_mask(
            _mm512_and_epi32(excl, _mm512_set1_epi32(1)),
            _mm512_setzero_epi32());
          num_rotates++;
          if (!exclmask) {num_rotates_empty++; num_rotates_excl_empty++;}
          #endif
        }

        // ------------- Distance squared calculation and cutoff check
        __m512 dx, dy, dz, r2;
        __mmask16 r2mask;
        if (!doList || freeMask_ij) {
          // dx = x_j - x_i;
          dx = _mm512_sub_ps(x_j, x_i);
          // dy = y_j - y_i;
          dy = _mm512_sub_ps(y_j, y_i);
          // dz = z_j - z_i;
          dz = _mm512_sub_ps(z_j, z_i);
          // r2 = dx*dx + dy*dy + dz*dz;
          r2 = _mm512_fmadd_ps(dx,dx,_mm512_fmadd_ps(dy,dy,
                               _mm512_mul_ps(dz, dz)));

          #ifdef TILE_LIST_STAT_DEBUG
          if (!doList) {
            __mmask16 t2 = _mm512_cmple_ps_mask(r2,
              _mm512_set1_ps(_plcutoff2));
            t2 &= _mm512_cmpneq_epi32_mask(_mm512_and_epi32(excl,
              _mm512_set1_epi32(1)), _mm512_setzero_epi32());
            if (!t2) num_rotates_empty++; else jtile_empty = false;
          }
          #endif

          if (doList) {
            // Predication for atom pairs within build cutoff
            r2mask = _mm512_cmple_ps_mask(r2, _mm512_set1_ps(_plcutoff2)) &
                     mask_j;
            r2mask &= freeMask_ij;
          } else {
            // Predication for atom pairs within force cutoff
            r2mask = _mm512_cmple_ps_mask(r2, _mm512_set1_ps(_cutoff2));
            r2mask &= _mm512_cmpneq_epi32_mask(_mm512_and_epi32(excl,
                        _mm512_set1_epi32(1)), _mm512_setzero_epi32());
          }
        } else
          r2mask = 0;

        // ------------- Exclusions
        if (doList && r2mask) {
          if (_numModifiedAlloc - numModified < 16) {
            int newNum = static_cast<double>(numModified) *
              numTileLists / itileList;
            if (newNum < _numModifiedAlloc + 16)
              newNum = _numModifiedAlloc + 16;
            reallocModified(newNum);
          }
          if (_numExcludedAlloc - numExcluded < 16) {
            int newNum = static_cast<double>(numExcluded) *
              numTileLists / itileList;
            if (newNum < _numExcludedAlloc + 16)
              newNum = _numExcludedAlloc + 16;
            reallocExcluded(newNum);
          }

          // Predication for exclusions and cutoff
          __mmask16 excludedMask = _mm512_cmpge_epi32_mask(atomIndex_j,
                                                           exclIndex_i);
          excludedMask &= _mm512_cmple_epi32_mask(atomIndex_j, exclMaxDiff_i);
          excludedMask &= r2mask;
          __mmask16 scalarPos = 1;

          // Check each pair for exclusions
          for (int z = 0; z < 16; z++) {
            if (scalarPos & excludedMask) {
              #ifdef MEM_OPT_VERSION
              const char *exclFlags =
                mol->get_excl_check_for_idx(*((int*)&(atomExclIndex_i) + z))->
                  flags;
              #else
              const char *exclFlags =
                mol->get_excl_check_for_atom(*((int*)&(atomIndex_i) + z))->
                  flags;
              #endif
              #ifndef MEM_OPT_VERSION
              if (exclFlags)
              #endif
                exclFlags -= *((int*)&(exclIndex_i) + z);
              #ifndef MEM_OPT_VERSION
              else {
                #pragma noinline
                exclFlags = buildExclFlyList(itileList, z, atomIndex_i,
                                             atomSize_i - atomStart_i, mol);
              }
              #endif
              const int exclType = exclFlags[*((int*)&atomIndex_j + z)];
              if (exclType == 0)
                // Clear exclusion bit
                excludedMask &= ~scalarPos;
              else if (exclType == 2) {
                // Exclude from tiles and add pair to modified list
                _modified_i[numModified] = atomStart_i + z;
                int jpos = (z + t) & 15;
                _modified_j[numModified] = atomStart_j + jpos;
                if (*((float*)&r2 + z) <= _cutoff2)
                  numModified++;
              } else {
                // Exclude from tiles and add pair to fully excluded list
                _excluded_i[numExcluded] = atomStart_i + z;
                int jpos = (z + t) & 15;
                _excluded_j[numExcluded] = atomStart_j + jpos;
                if (*((float*)&r2 + z) <= _cutoff2)
                  numExcluded++;
              }
            }
            // Next pair
            scalarPos <<= 1;
          }
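
          // Note: exclType comes from the Molecule exclusion flags -- 0
          // leaves the pair in the tile (its bit is cleared from
          // excludedMask), 2 (EXCHCK_MOD) routes the pair to the modified
          // 1-4 list, and any other value routes it to the fully excluded
          // list; the dedicated kernels below replay those two lists.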

          // Note: For exclusions, use force cutoff and not list cutoff
          r2mask &= ~excludedMask;

          // If building a pair list, store each pair within cutoff
          if (NAMD_AVXTILES_PAIR_THRESHOLD > 0 &&
              doIpairList && r2mask) {
            for (int z = 0; z < NAMD_AVXTILES_PAIR_THRESHOLD; z++) {
              if (r2mask & 1) {
                const int jPos = atomStart_j + ((z+t) & 15);
                _pairLists[_pairStart[z]+_numPairs[z]] = jPos;
                _numPairs[z]++;
              }
              r2mask >>= 1;
            }
            r2mask = 0;
          } else {
            // Mark tile list as active if neighbors within cutoff
            excl = _mm512_mask_or_epi32(excl, r2mask, excl,
                                        _mm512_set1_epi32(0x8000));
            if (r2mask) jTileActive = true;
            // Redo predication for pairs within force cutoff rather than
            // build cutoff
            r2mask = _mm512_cmple_ps_mask(r2, _mm512_set1_ps(_cutoff2)) &
                     mask_j & freeMask_ij & ~excludedMask;
          }
        } // if doList

        // ------------- Force, Energy, Virial Calculations
        if (r2mask) {
          // kqq = q_i * q_j
          const __m512 kqq = _mm512_mul_ps(q_i, q_j);
          __m512 force, forceSlow;

          // Call force kernel
          if (interpMode == 1)
            forceEnergyInterp1<doEnergy, doSlow>(r2, kqq, force, forceSlow,
                energyVdw, energyElec, energySlow, r2mask, _paramC1,
                _paramC3, _paramSwitchOn2, _cutoff2, _paramMinvCut3,
                _paramCutUnder3, slowTable, slowEnergyTable, eps4_i, eps4_j,
                sigma_i, sigma_j);
          else
            forceEnergyInterp2<doEnergy, doSlow, interpMode>(r2, kqq, type_i,
                type_j, force, forceSlow, energyVdw, energyElec, energySlow,
                r2mask, _paramScale, _paramC1, _paramC3, _paramSwitchOn2,
                _cutoff2, _paramMinvCut3, _paramCutUnder3, fastTable,
                energyTable, slowTable, slowEnergyTable, ljTable,
                _paramLjWidth);

          // force_i_. += d. * force
          force_i_x = _mm512_mask_mov_ps(force_i_x, r2mask,
                        _mm512_fmadd_ps(dx, force, force_i_x));
          force_i_y = _mm512_mask_mov_ps(force_i_y, r2mask,
                        _mm512_fmadd_ps(dy, force, force_i_y));
          force_i_z = _mm512_mask_mov_ps(force_i_z, r2mask,
                        _mm512_fmadd_ps(dz, force, force_i_z));
          // force_j_. -= d. * force
          force_j_x = _mm512_mask_mov_ps(force_j_x, r2mask,
                        _mm512_fnmadd_ps(dx, force, force_j_x));
          force_j_y = _mm512_mask_mov_ps(force_j_y, r2mask,
                        _mm512_fnmadd_ps(dy, force, force_j_y));
          force_j_z = _mm512_mask_mov_ps(force_j_z, r2mask,
                        _mm512_fnmadd_ps(dz, force, force_j_z));
          if (doSlow) {
            // forceSlow_i_. += d. * forceSlow
            forceSlow_i_x = _mm512_mask_mov_ps(forceSlow_i_x, r2mask,
                              _mm512_fmadd_ps(dx, forceSlow, forceSlow_i_x));
            forceSlow_i_y = _mm512_mask_mov_ps(forceSlow_i_y, r2mask,
                              _mm512_fmadd_ps(dy, forceSlow, forceSlow_i_y));
            forceSlow_i_z = _mm512_mask_mov_ps(forceSlow_i_z, r2mask,
                              _mm512_fmadd_ps(dz, forceSlow, forceSlow_i_z));
            // forceSlow_j_. -= d. * forceSlow
            forceSlow_j_x = _mm512_mask_mov_ps(forceSlow_j_x, r2mask,
                              _mm512_fnmadd_ps(dx, forceSlow, forceSlow_j_x));
            forceSlow_j_y = _mm512_mask_mov_ps(forceSlow_j_y, r2mask,
                              _mm512_fnmadd_ps(dy, forceSlow, forceSlow_j_y));
            forceSlow_j_z = _mm512_mask_mov_ps(forceSlow_j_z, r2mask,
                              _mm512_fnmadd_ps(dz, forceSlow, forceSlow_j_z));
          }
        }
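
        // Note: fmadd accumulates +f*d into the i-tile vectors while fnmadd
        // accumulates -f*d into the j-tile vectors, applying Newton's third
        // law so each pair is evaluated only once.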

        // ------------- Next set of pairs for tile interactions
        x_j = (__m512)_mm512_alignr_epi32((__m512i)x_j, (__m512i)x_j, 1);
        y_j = (__m512)_mm512_alignr_epi32((__m512i)y_j, (__m512i)y_j, 1);
        z_j = (__m512)_mm512_alignr_epi32((__m512i)z_j, (__m512i)z_j, 1);
        q_j = (__m512)_mm512_alignr_epi32((__m512i)q_j, (__m512i)q_j, 1);
        if (doList) {
          freeMask_j = (freeMask_j>>1) | (freeMask_j<<15);
          sizeMask_j = (sizeMask_j>>1) | (sizeMask_j<<15);
          mask_j >>= shiftVal;
          atomIndex_j = _mm512_alignr_epi32(atomIndex_j, atomIndex_j, 1);
        }
        if (interpMode == 1) {
          eps4_j = (__m512)_mm512_alignr_epi32((__m512i)eps4_j,
                                               (__m512i)eps4_j, 1);
          sigma_j = (__m512)_mm512_alignr_epi32((__m512i)sigma_j,
                                                (__m512i)sigma_j, 1);
        } else
          type_j = _mm512_alignr_epi32(type_j, type_j, 1);
        force_j_x = (__m512)_mm512_alignr_epi32((__m512i)force_j_x,
                                                (__m512i)force_j_x, 1);
        force_j_y = (__m512)_mm512_alignr_epi32((__m512i)force_j_y,
                                                (__m512i)force_j_y, 1);
        force_j_z = (__m512)_mm512_alignr_epi32((__m512i)force_j_z,
                                                (__m512i)force_j_z, 1);
        if (doSlow) {
          forceSlow_j_x = (__m512)_mm512_alignr_epi32((__m512i)forceSlow_j_x,
                                                      (__m512i)forceSlow_j_x,
                                                      1);
          forceSlow_j_y = (__m512)_mm512_alignr_epi32((__m512i)forceSlow_j_y,
                                                      (__m512i)forceSlow_j_y,
                                                      1);
          forceSlow_j_z = (__m512)_mm512_alignr_epi32((__m512i)forceSlow_j_z,
                                                      (__m512i)forceSlow_j_z,
                                                      1);
        }
      } // for t

      // ------------- Accumulate j-forces in memory
      const __m512i tp0x = _mm512_set_epi32(0,19,11,3,0,18,10,2,
                                            0,17,9,1,0,16,8,0);
      const __m512i tp1x = _mm512_set_epi32(0,23,15,7,0,22,14,6,0,21,
                                            13,5,0,20,12,4);
      const __m512i tp2x = _mm512_set_epi32(0,27,11,3,0,26,10,2,
                                            0,25,9,1,0,24,8,0);
      const __m512i tp3x = _mm512_set_epi32(0,31,15,7,0,30,14,6,
                                            0,29,13,5,0,28,12,4);
      {
        float * jptr = (float *)(forces_j + atomStart_j);
        const __m512 v0 = _mm512_loadu_ps(jptr);
        const __m512 v1 = _mm512_loadu_ps(jptr + 16);
        const __m512 v2 = _mm512_loadu_ps(jptr + 32);
        const __m512 v3 = _mm512_loadu_ps(jptr + 48);
        const __m512 w1 = _mm512_shuffle_f32x4(force_j_x, force_j_y,
                                               0b01000100);
        const __m512 w2 = _mm512_shuffle_f32x4(force_j_x, force_j_y,
                                               0b11101110);
        __m512 tp0 = _mm512_permutex2var_ps(w1, tp0x, force_j_z);
        __m512 tp1 = _mm512_permutex2var_ps(w1, tp1x, force_j_z);
        __m512 tp2 = _mm512_permutex2var_ps(w2, tp2x, force_j_z);
        __m512 tp3 = _mm512_permutex2var_ps(w2, tp3x, force_j_z);
        tp0 = _mm512_add_ps(v0, tp0);
        tp1 = _mm512_add_ps(v1, tp1);
        tp2 = _mm512_add_ps(v2, tp2);
        tp3 = _mm512_add_ps(v3, tp3);
        _mm512_store_ps(jptr, tp0);
        _mm512_store_ps(jptr + 16, tp1);
        _mm512_store_ps(jptr + 32, tp2);
        _mm512_store_ps(jptr + 48, tp3);
      }
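
      // Note: the tp0x..tp3x permutes are the inverse of the load-side
      // transpose -- they spread the x/y/z force vectors back into the
      // 4-float-per-atom layout of AVXTilesForce so all 64 floats can be
      // read, accumulated, and written with four full-width loads and
      // stores instead of gather/scatter.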

      if (doSlow) {
        float * jptr = (float *)(forcesSlow_j + atomStart_j);
        const __m512 v0 = _mm512_loadu_ps(jptr);
        const __m512 v1 = _mm512_loadu_ps(jptr + 16);
        const __m512 v2 = _mm512_loadu_ps(jptr + 32);
        const __m512 v3 = _mm512_loadu_ps(jptr + 48);
        const __m512 w1 = _mm512_shuffle_f32x4(forceSlow_j_x, forceSlow_j_y,
                                               0b01000100);
        const __m512 w2 = _mm512_shuffle_f32x4(forceSlow_j_x, forceSlow_j_y,
                                               0b11101110);
        __m512 tp0 = _mm512_permutex2var_ps(w1, tp0x, forceSlow_j_z);
        __m512 tp1 = _mm512_permutex2var_ps(w1, tp1x, forceSlow_j_z);
        __m512 tp2 = _mm512_permutex2var_ps(w2, tp2x, forceSlow_j_z);
        __m512 tp3 = _mm512_permutex2var_ps(w2, tp3x, forceSlow_j_z);
        tp0 = _mm512_add_ps(v0, tp0);
        tp1 = _mm512_add_ps(v1, tp1);
        tp2 = _mm512_add_ps(v2, tp2);
        tp3 = _mm512_add_ps(v3, tp3);
        _mm512_store_ps(jptr, tp0);
        _mm512_store_ps(jptr + 16, tp1);
        _mm512_store_ps(jptr + 32, tp2);
        _mm512_store_ps(jptr + 48, tp3);
      }

      // ------------- Write exclusions to memory
      if (doList) {
        if (jTileActive) {
          int anyexcl = 65536;
          if (_mm512_cmpneq_epi32_mask(excl, _mm512_setzero_epi32()))
            anyexcl |= 1;
          jTiles.status[jtile] = anyexcl;
          // Store exclusions
          _mm512_store_epi32((void *)(jTiles.excl + (jtile << 4)),
                             excl);
          itileListLen += anyexcl;
        }
      } // if doList
      #ifdef TILE_LIST_STAT_DEBUG
      if (jtile_empty) num_jtiles_empty++;
      #endif
    } // jtile

    // ------------------------------------------------------

    // ------------- Accumulate i-forces in memory
    const __m512i tp0x = _mm512_set_epi32(0,19,11,3,0,18,10,2,
                                          0,17,9,1,0,16,8,0);
    const __m512i tp1x = _mm512_set_epi32(0,23,15,7,0,22,14,6,0,21,
                                          13,5,0,20,12,4);
    const __m512i tp2x = _mm512_set_epi32(0,27,11,3,0,26,10,2,
                                          0,25,9,1,0,24,8,0);
    const __m512i tp3x = _mm512_set_epi32(0,31,15,7,0,30,14,6,
                                          0,29,13,5,0,28,12,4);
    {
      float * iptr = (float *)(forces_i + atomStart_i);
      const __m512 v0 = _mm512_loadu_ps(iptr);
      const __m512 v1 = _mm512_loadu_ps(iptr + 16);
      const __m512 v2 = _mm512_loadu_ps(iptr + 32);
      const __m512 v3 = _mm512_loadu_ps(iptr + 48);
      const __m512 w1 = _mm512_shuffle_f32x4(force_i_x, force_i_y,
                                             0b01000100);
      const __m512 w2 = _mm512_shuffle_f32x4(force_i_x, force_i_y,
                                             0b11101110);
      __m512 tp0 = _mm512_permutex2var_ps(w1, tp0x, force_i_z);
      __m512 tp1 = _mm512_permutex2var_ps(w1, tp1x, force_i_z);
      __m512 tp2 = _mm512_permutex2var_ps(w2, tp2x, force_i_z);
      __m512 tp3 = _mm512_permutex2var_ps(w2, tp3x, force_i_z);
      tp0 = _mm512_add_ps(v0, tp0);
      tp1 = _mm512_add_ps(v1, tp1);
      tp2 = _mm512_add_ps(v2, tp2);
      tp3 = _mm512_add_ps(v3, tp3);
      _mm512_store_ps(iptr, tp0);
      _mm512_store_ps(iptr + 16, tp1);
      _mm512_store_ps(iptr + 32, tp2);
      _mm512_store_ps(iptr + 48, tp3);
    }

    if (doSlow) {
      float * iptr = (float *)(forcesSlow_i + atomStart_i);
      const __m512 v0 = _mm512_loadu_ps(iptr);
      const __m512 v1 = _mm512_loadu_ps(iptr + 16);
      const __m512 v2 = _mm512_loadu_ps(iptr + 32);
      const __m512 v3 = _mm512_loadu_ps(iptr + 48);
      const __m512 w1 = _mm512_shuffle_f32x4(forceSlow_i_x, forceSlow_i_y,
                                             0b01000100);
      const __m512 w2 = _mm512_shuffle_f32x4(forceSlow_i_x, forceSlow_i_y,
                                             0b11101110);
      __m512 tp0 = _mm512_permutex2var_ps(w1, tp0x, forceSlow_i_z);
      __m512 tp1 = _mm512_permutex2var_ps(w1, tp1x, forceSlow_i_z);
      __m512 tp2 = _mm512_permutex2var_ps(w2, tp2x, forceSlow_i_z);
      __m512 tp3 = _mm512_permutex2var_ps(w2, tp3x, forceSlow_i_z);
      tp0 = _mm512_add_ps(v0, tp0);
      tp1 = _mm512_add_ps(v1, tp1);
      tp2 = _mm512_add_ps(v2, tp2);
      tp3 = _mm512_add_ps(v3, tp3);
      _mm512_store_ps(iptr, tp0);
      _mm512_store_ps(iptr + 16, tp1);
      _mm512_store_ps(iptr + 32, tp2);
      _mm512_store_ps(iptr + 48, tp3);
    }

    if (doList)
      listDepth[itileList] = itileListLen;

    // Update net forces for virial
    if (doVirial) {
      // fNet_. += force_i_.
      fNet_x = _mm512_add_ps(fNet_x, force_i_x);
      fNet_y = _mm512_add_ps(fNet_y, force_i_y);
      fNet_z = _mm512_add_ps(fNet_z, force_i_z);
      if (doSlow) {
        // fNetSlow_. += forceSlow_i_.
        fNetSlow_x = _mm512_add_ps(fNetSlow_x, forceSlow_i_x);
        fNetSlow_y = _mm512_add_ps(fNetSlow_y, forceSlow_i_y);
        fNetSlow_z = _mm512_add_ps(fNetSlow_z, forceSlow_i_z);
      }
    } // if (doVirial)

    if (NAMD_AVXTILES_PAIR_THRESHOLD > 0 && doList) {
      _numPairLists = 0;
      for (int z = 0; z < NAMD_AVXTILES_PAIR_THRESHOLD; z++)
        if (_numPairs[z]) _numPairLists++;
    }
  } // for itileList

  if (doList) {
    _numModified = numModified;
    // WMB: Only needed for slow electrostatics
    _numExcluded = numExcluded;
    _exclusionChecksum = numModified + numExcluded;
  }

  #ifdef TILE_LIST_STAT_DEBUG
  if (!doList)
    printf("TILE_DBG: JTILES %d EMPTY %.2f ROTATES %d EMPTY %.2f EXCL %.2f\n",
           num_jtiles, double(num_jtiles_empty)/num_jtiles, num_rotates,
           double(num_rotates_empty)/num_rotates,
           double(num_rotates_excl_empty)/num_rotates);
  #endif
} // nbAVX512Tiles()

//---------------------------------------------------------------------------
// Calculations for unmodified/unexcluded from pair lists
//---------------------------------------------------------------------------

template <bool doEnergy, bool doVirial, bool doSlow, int interpMode>
void AVXTileLists::nbAVX512Pairs(__m512 &energyVdw, __m512 &energyElec,
                                 __m512 &energySlow, __m512 &fNet_x,
                                 __m512 &fNet_y, __m512 &fNet_z,
                                 __m512 &fNetSlow_x, __m512 &fNetSlow_y,
                                 __m512 &fNetSlow_z) {

  const AVXTiles::AVXTilesAtom* __restrict__ xyzq_i = tiles_p0->atoms;
  const AVXTiles::AVXTilesAtom* __restrict__ xyzq_j = tiles_p1->atoms;
  AVXTiles::AVXTilesForce* __restrict__ forces_i = tiles_p0->forces;
  AVXTiles::AVXTilesForce* __restrict__ forces_j = tiles_p1->forces;
  AVXTiles::AVXTilesForce* __restrict__ forcesSlow_i;
  AVXTiles::AVXTilesForce* __restrict__ forcesSlow_j;
  if (doSlow) {
    forcesSlow_i = tiles_p0->forcesSlow;
    forcesSlow_j = tiles_p1->forcesSlow;
  }

  const float* __restrict__ fastTable;
  const float* __restrict__ energyTable;
  const float * __restrict__ ljTable;
  const float * __restrict__ eps4sigma;
  // Interpolation for long-range splitting function
  if (interpMode == 3) {
    fastTable = _paramFastTable;
    energyTable = _paramFastEnergyTable;
  }
  // LJ mixing not performed within kernel
  if (interpMode > 1) ljTable = _paramLjTable;
  // LJ mixing performed within kernel
  if (interpMode == 1) eps4sigma = _paramEps4Sigma;

  const float* __restrict__ slowTable = _paramSlowTable;
  const float* __restrict__ slowEnergyTable = _paramSlowEnergyTable;

  for (int z = 0; z < NAMD_AVXTILES_PAIR_THRESHOLD; z++) {
    if (!_numPairs[z]) continue;

    // Zero i-tile force vectors
    __m512 force_i_x = _mm512_setzero_ps();
    __m512 force_i_y = _mm512_setzero_ps();
    __m512 force_i_z = _mm512_setzero_ps();
    __m512 forceSlow_i_x, forceSlow_i_y, forceSlow_i_z;
    if (doSlow) {
      forceSlow_i_x = _mm512_setzero_ps();
      forceSlow_i_y = _mm512_setzero_ps();
      forceSlow_i_z = _mm512_setzero_ps();
    }

    const int pairI = _pair_i[z];
    const int nPairs = _numPairs[z];
    // Load i-atom data
    const __m512 x_i = _mm512_set1_ps(xyzq_i[pairI].x + _shx);
    const __m512 y_i = _mm512_set1_ps(xyzq_i[pairI].y + _shy);
    const __m512 z_i = _mm512_set1_ps(xyzq_i[pairI].z + _shz);
    const __m512 q_i = _mm512_set1_ps(xyzq_i[pairI].q);
    const int scalarType_i = tiles_p0->vdwTypes[pairI];
    const __m512i type_i = _mm512_set1_epi32(scalarType_i);
    __m512 eps4_i, sigma_i;
    if (interpMode == 1) {
      eps4_i = _mm512_set1_ps(eps4sigma[scalarType_i*2]);
      sigma_i = _mm512_set1_ps(eps4sigma[scalarType_i*2+1]);
    }

    __mmask16 loopMask = 0xFFFF;
    int listPos = _pairStart[z];
    for (int mv = 0; mv < nPairs; mv += 16) {
      // Remainder predication
      if (nPairs - mv < 16)
        loopMask >>= (16 - (nPairs - mv));
      // Load j indices from pair list
      __m512i j = _mm512_loadu_epi32(_pairLists + listPos);
      listPos += 16;

      // Load j atom data
      const __m512i jt2 = _mm512_slli_epi32(j, 1);
      const __m512 t0 = (__m512)_mm512_mask_i32logather_pd(
        _mm512_undefined_pd(), loopMask, jt2, (float *)xyzq_j, _MM_SCALE_8);
      const __m512 t8 = (__m512)_mm512_mask_i32logather_pd(
        _mm512_undefined_pd(), loopMask, jt2, (float *)(xyzq_j) + 2,
        _MM_SCALE_8);
      const __m512i jt2_2 = _mm512_shuffle_i32x4(jt2, jt2, 238);
      const __mmask16 loopMask2 = loopMask >> 8;
      const __m512 t1 = (__m512)_mm512_mask_i32logather_pd(
        _mm512_undefined_pd(), loopMask2, jt2_2, (float *)xyzq_j, _MM_SCALE_8);
      const __m512 t9 = (__m512)_mm512_mask_i32logather_pd(
        _mm512_undefined_pd(), loopMask2, jt2_2, (float *)(xyzq_j) + 2,
        _MM_SCALE_8);
      const __m512i t4 = _mm512_set_epi32(31,29,27,25,23,21,19,17,
                                          15,13,11,9,7,5,3,1);
      const __m512 y_j = _mm512_permutex2var_ps(t0, t4, t1);
      const __m512 q_j = _mm512_permutex2var_ps(t8, t4, t9);
      const __m512i t6 = _mm512_set_epi32(30,28,26,24,22,20,18,16,
                                          14,12,10,8,6,4,2,0);
      const __m512 x_j = _mm512_permutex2var_ps(t0, t6, t1);
      const __m512 z_j = _mm512_permutex2var_ps(t8, t6, t9);
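
      // Note: because each atom record is four floats, index j*2 with an
      // 8-byte scale addresses atom j; each 64-bit gather therefore pulls
      // an adjacent {x, y} (or {z, q}) float pair for 8 lanes at once, and
      // the permutes de-interleave the pairs -- half as many gathers as a
      // per-component 32-bit approach.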

      // kqq = q_i * q_j
      const __m512 kqq = _mm512_mul_ps(q_i, q_j);

      // dx = x_i - x_j;
      const __m512 dx = _mm512_sub_ps(x_i, x_j);
      // dy = y_i - y_j;
      const __m512 dy = _mm512_sub_ps(y_i, y_j);
      // dz = z_i - z_j;
      const __m512 dz = _mm512_sub_ps(z_i, z_j);
      // r2 = dx*dx + dy*dy + dz*dz;
      const __m512 r2(_mm512_fmadd_ps(dx,dx,_mm512_fmadd_ps(dy,dy,
                      _mm512_mul_ps(dz, dz))));
      // Atoms within cutoff
      const __mmask16 r2mask = _mm512_cmple_ps_mask(r2,
                                 _mm512_set1_ps(_cutoff2)) & loopMask;

      // Load LJ types
      const __m512i type_j = _mm512_mask_i32gather_epi32(
        _mm512_undefined_epi32(), r2mask, j, tiles_p1->vdwTypes, _MM_SCALE_4);

      // Load eps and sigma
      __m512 eps4_j, sigma_j;
      if (interpMode == 1) {
        const __m512 t0 = (__m512)_mm512_mask_i32logather_pd(
          _mm512_undefined_pd(), r2mask, type_j, eps4sigma, _MM_SCALE_8);
        const __m512i type_j2 = _mm512_shuffle_i32x4(type_j, type_j, 238);
        const __mmask16 r2mask2 = r2mask >> 8;
        const __m512 t1 = (__m512)_mm512_mask_i32logather_pd(
          _mm512_undefined_pd(), r2mask2, type_j2, eps4sigma, _MM_SCALE_8);
        sigma_j = _mm512_permutex2var_ps(t0, t4, t1);
        eps4_j = _mm512_permutex2var_ps(t0, t6, t1);
      }

      // Force, Energy, Virial calculation
      __m512 force, forceSlow;
      if (interpMode == 1)
        forceEnergyInterp1<doEnergy, doSlow>(r2, kqq, force, forceSlow,
            energyVdw, energyElec, energySlow, r2mask, _paramC1, _paramC3,
            _paramSwitchOn2, _cutoff2, _paramMinvCut3, _paramCutUnder3,
            slowTable, slowEnergyTable, eps4_i, eps4_j, sigma_i, sigma_j);
      else
        forceEnergyInterp2<doEnergy, doSlow, interpMode>(r2, kqq, type_i,
            type_j, force, forceSlow, energyVdw, energyElec, energySlow,
            r2mask, _paramScale, _paramC1, _paramC3, _paramSwitchOn2,
            _cutoff2, _paramMinvCut3, _paramCutUnder3, fastTable, energyTable,
            slowTable, slowEnergyTable, ljTable, _paramLjWidth);

      // force_. = d. * force
      const __m512 force_x = _mm512_mul_ps(dx, force);
      const __m512 force_y = _mm512_mul_ps(dy, force);
      const __m512 force_z = _mm512_mul_ps(dz, force);
      // Accumulate j forces in memory
      const __m512i j4 = _mm512_slli_epi32(j, 2);
      __m512 ft0 = (__m512)_mm512_mask_i32logather_pd(_mm512_undefined_pd(),
        r2mask, j4, (float*)forces_j, _MM_SCALE_4);
      const __m512i j4_2 = _mm512_shuffle_i32x4(j4, j4, 238);
      const __mmask16 r2mask2 = r2mask >> 8;
      __m512 ft1 = (__m512)_mm512_mask_i32logather_pd(_mm512_undefined_pd(),
        r2mask2, j4_2, (float *)forces_j, _MM_SCALE_4);
      const __m512i ft4 = _mm512_set_epi32(23,7,22,6,21,5,20,4,
                                           19,3,18,2,17,1,16,0);
      const __m512 ft5 = _mm512_permutex2var_ps(force_x, ft4, force_y);
      ft0 = _mm512_add_ps(ft0, ft5);
      _mm512_mask_i32loscatter_pd((void *)forces_j, r2mask, j4, (__m512d)ft0,
                                  _MM_SCALE_4);
      const __m512i ft2 = _mm512_set_epi32(31,15,30,14,29,13,28,12,
                                           27,11,26,10,25,9,24,8);
      const __m512 ft3 = _mm512_permutex2var_ps(force_x, ft2, force_y);
      ft1 = _mm512_add_ps(ft1, ft3);
      _mm512_mask_i32loscatter_pd((void *)forces_j, r2mask2, j4_2,
                                  (__m512d)ft1, _MM_SCALE_4);
      __m512 mem3 = _mm512_mask_i32gather_ps(_mm512_undefined_ps(), r2mask,
                                             j4, (float *)(forces_j)+2,
                                             _MM_SCALE_4);
      mem3 = _mm512_add_ps(mem3, force_z);
      _mm512_mask_i32scatter_ps((float *)(forces_j)+2, r2mask, j4, mem3,
                                _MM_SCALE_4);
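
      // Note: the gather/add/scatter accumulation above is only safe
      // because the 16 j indices within one chunk of a lane's pair list are
      // distinct (each neighbor is stored once), so no two lanes update the
      // same atom's force in the same vector operation.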

      // force_i_. -= force_.
      force_i_x = _mm512_mask_sub_ps(force_i_x, r2mask, force_i_x, force_x);
      force_i_y = _mm512_mask_sub_ps(force_i_y, r2mask, force_i_y, force_y);
      force_i_z = _mm512_mask_sub_ps(force_i_z, r2mask, force_i_z, force_z);

      __m512 forceSlow_x, forceSlow_y, forceSlow_z;
      if (doSlow) {
        // forceSlow_. = d. * forceSlow
        forceSlow_x = _mm512_mul_ps(dx, forceSlow);
        forceSlow_y = _mm512_mul_ps(dy, forceSlow);
        forceSlow_z = _mm512_mul_ps(dz, forceSlow);
        // Accumulate j slow forces in memory
        // acc3(r2mask, (float*)forcesSlow_j, j, forceSlow_x,
        //      forceSlow_y, forceSlow_z);
        __m512 ft10 = (__m512)_mm512_mask_i32logather_pd(_mm512_undefined_pd(),
          r2mask, j4, (float*)forcesSlow_j, _MM_SCALE_4);
        __m512 ft11 = (__m512)_mm512_mask_i32logather_pd(_mm512_undefined_pd(),
          r2mask2, j4_2, (float *)forcesSlow_j, _MM_SCALE_4);
        const __m512 ft15 = _mm512_permutex2var_ps(forceSlow_x, ft4,
                                                   forceSlow_y);
        ft10 = _mm512_add_ps(ft10, ft15);
        _mm512_mask_i32loscatter_pd((void *)forcesSlow_j, r2mask, j4,
                                    (__m512d)ft10, _MM_SCALE_4);
        const __m512 ft13 = _mm512_permutex2var_ps(forceSlow_x, ft2,
                                                   forceSlow_y);
        ft11 = _mm512_add_ps(ft11, ft13);
        _mm512_mask_i32loscatter_pd((void *)forcesSlow_j, r2mask2, j4_2,
                                    (__m512d)ft11, _MM_SCALE_4);
        __m512 mem13 = _mm512_mask_i32gather_ps(_mm512_undefined_ps(), r2mask,
                                                j4, (float *)(forcesSlow_j)+2,
                                                _MM_SCALE_4);
        mem13 = _mm512_add_ps(mem13, forceSlow_z);
        _mm512_mask_i32scatter_ps((float *)(forcesSlow_j)+2, r2mask, j4, mem13,
                                  _MM_SCALE_4);

        // forceSlow_i_. -= forceSlow_.
        forceSlow_i_x = _mm512_mask_sub_ps(forceSlow_i_x, r2mask,
                                           forceSlow_i_x, forceSlow_x);
        forceSlow_i_y = _mm512_mask_sub_ps(forceSlow_i_y, r2mask,
                                           forceSlow_i_y, forceSlow_y);
        forceSlow_i_z = _mm512_mask_sub_ps(forceSlow_i_z, r2mask,
                                           forceSlow_i_z, forceSlow_z);
      }
    } // for mv

    // Reduction on i vectors, accumulate in memory, update fNet_. for virial
    float fI = _mm512_reduce_add_ps(force_i_x);
    forces_i[pairI].x += fI;
    if (doVirial) *((float*)&fNet_x) += fI;
    fI = _mm512_reduce_add_ps(force_i_y);
    forces_i[pairI].y += fI;
    if (doVirial) *((float*)&fNet_y) += fI;
    fI = _mm512_reduce_add_ps(force_i_z);
    forces_i[pairI].z += fI;
    if (doVirial) *((float*)&fNet_z) += fI;
    if (doSlow) {
      fI = _mm512_reduce_add_ps(forceSlow_i_x);
      forcesSlow_i[pairI].x += fI;
      if (doVirial) *((float*)&fNetSlow_x) += fI;
      fI = _mm512_reduce_add_ps(forceSlow_i_y);
      forcesSlow_i[pairI].y += fI;
      if (doVirial) *((float*)&fNetSlow_y) += fI;
      fI = _mm512_reduce_add_ps(forceSlow_i_z);
      forcesSlow_i[pairI].z += fI;
      if (doVirial) *((float*)&fNetSlow_z) += fI;
    }
  } // for z
} // nbAVX512Pairs()

//---------------------------------------------------------------------------
// Calculations for modified pairs
//---------------------------------------------------------------------------

template <bool doEnergy, bool doVirial, bool doSlow, int interpMode>
void AVXTileLists::nbAVX512Modified(__m512 &energyVdw, __m512 &energyElec,
                                    __m512 &energySlow, __m512 &fNet_x,
                                    __m512 &fNet_y, __m512 &fNet_z,
                                    __m512 &fNetSlow_x, __m512 &fNetSlow_y,
                                    __m512 &fNetSlow_z) {

  const AVXTiles::AVXTilesAtom* __restrict__ xyzq_i = tiles_p0->atoms;
  const AVXTiles::AVXTilesAtom* __restrict__ xyzq_j = tiles_p1->atoms;
  AVXTiles::AVXTilesForce* __restrict__ forces_i = tiles_p0->forces;
  AVXTiles::AVXTilesForce* __restrict__ forces_j = tiles_p1->forces;
  AVXTiles::AVXTilesForce* __restrict__ forcesSlow_i;
  AVXTiles::AVXTilesForce* __restrict__ forcesSlow_j;
  if (doSlow) {
    forcesSlow_i = tiles_p0->forcesSlow;
    forcesSlow_j = tiles_p1->forcesSlow;
  }

  const float* __restrict__ fastTable;
  const float* __restrict__ energyTable;
  // Interpolation for long-range splitting function
  if (interpMode == 3) {
    fastTable = _paramFastTable;
    energyTable = _paramFastEnergyTable;
  }

  const float * __restrict__ mod_table = _paramModifiedTable;
  const float * __restrict__ mode_table = _paramModifiedEnergyTable;

  const float * __restrict__ ljTable14;
  const float * __restrict__ eps4sigma14;
  if (interpMode > 1)
    ljTable14 = _paramLjTable + 2;
  else
    eps4sigma14 = _paramEps4Sigma14;

  __mmask16 loopMask = 0xFFFF;
  #pragma novector
  for (int mv = 0; mv < _numModified; mv += 16) {
    // Remainder predication
    if (_numModified - mv < 16)
      loopMask >>= (16 - (_numModified - mv));
    // Load i and j indices for pairs on the modified list
    const __m512i i = _mm512_loadu_epi32(_modified_i + mv);
    const __m512i j = _mm512_loadu_epi32(_modified_j + mv);

    // Load i atom data and shift coordinates
    const __m512i it2 = _mm512_slli_epi32(i, 1);
    const __m512 t0 = (__m512)_mm512_mask_i32logather_pd(
      _mm512_undefined_pd(), loopMask, it2, (float *)xyzq_i, _MM_SCALE_8);
    const __m512 t8 = (__m512)_mm512_mask_i32logather_pd(
      _mm512_undefined_pd(), loopMask, it2, (float *)(xyzq_i) + 2,
      _MM_SCALE_8);
    const __m512i it2_2 = _mm512_shuffle_i32x4(it2, it2, 238);
    const __mmask16 loopMask2 = loopMask >> 8;
    const __m512 t1 = (__m512)_mm512_mask_i32logather_pd(
      _mm512_undefined_pd(), loopMask2, it2_2, (float *)xyzq_i, _MM_SCALE_8);
    const __m512 t9 = (__m512)_mm512_mask_i32logather_pd(
      _mm512_undefined_pd(), loopMask2, it2_2, (float *)(xyzq_i) + 2,
      _MM_SCALE_8);
    const __m512i t4 = _mm512_set_epi32(31,29,27,25,23,21,19,17,
                                        15,13,11,9,7,5,3,1);
    const __m512 y_i = _mm512_add_ps(_mm512_permutex2var_ps(t0, t4, t1),
                                     _mm512_set1_ps(_shy));
    const __m512 q_i = _mm512_permutex2var_ps(t8, t4, t9);
    const __m512i t6 = _mm512_set_epi32(30,28,26,24,22,20,18,16,
                                        14,12,10,8,6,4,2,0);
    const __m512 x_i = _mm512_add_ps(_mm512_permutex2var_ps(t0, t6, t1),
                                     _mm512_set1_ps(_shx));
    const __m512 z_i = _mm512_add_ps(_mm512_permutex2var_ps(t8, t6, t9),
                                     _mm512_set1_ps(_shz));

    // Load j atom data
    const __m512i jt2 = _mm512_slli_epi32(j, 1);
    const __m512 t10 = (__m512)_mm512_mask_i32logather_pd(
      _mm512_undefined_pd(), loopMask, jt2, (float *)xyzq_j, _MM_SCALE_8);
    const __m512 t18 = (__m512)_mm512_mask_i32logather_pd(
      _mm512_undefined_pd(), loopMask, jt2, (float *)(xyzq_j) + 2,
      _MM_SCALE_8);
    const __m512i jt2_2 = _mm512_shuffle_i32x4(jt2, jt2, 238);
    const __m512 t11 = (__m512)_mm512_mask_i32logather_pd(
      _mm512_undefined_pd(), loopMask2, jt2_2, (float *)xyzq_j, _MM_SCALE_8);
    const __m512 t19 = (__m512)_mm512_mask_i32logather_pd(
      _mm512_undefined_pd(), loopMask2, jt2_2, (float *)(xyzq_j) + 2,
      _MM_SCALE_8);
    const __m512 y_j = _mm512_permutex2var_ps(t10, t4, t11);
    const __m512 q_j = _mm512_permutex2var_ps(t18, t4, t19);
    const __m512 x_j = _mm512_permutex2var_ps(t10, t6, t11);
    const __m512 z_j = _mm512_permutex2var_ps(t18, t6, t19);

    // kqq = q_i * q_j * _paramScale14
    const __m512 kqq = _mm512_mul_ps(q_i, _mm512_mul_ps(q_j,
                                     _mm512_set1_ps(_paramScale14)));
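
    // Note: modified (1-4) pairs use the scaled charge product
    // q_i * q_j * scale14 together with the separate 1-4 LJ parameter set
    // loaded below, the usual treatment of bonded 1-4 neighbors in the
    // force field.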

    // dx = x_i - x_j;
    const __m512 dx = _mm512_sub_ps(x_i, x_j);
    // dy = y_i - y_j;
    const __m512 dy = _mm512_sub_ps(y_i, y_j);
    // dz = z_i - z_j;
    const __m512 dz = _mm512_sub_ps(z_i, z_j);
    // r2 = dx*dx + dy*dy + dz*dz;
    const __m512 r2(_mm512_fmadd_ps(dx,dx,_mm512_fmadd_ps(dy,dy,
                    _mm512_mul_ps(dz, dz))));
    // Atoms within cutoff
    __mmask16 r2mask = _mm512_cmple_ps_mask(r2, _mm512_set1_ps(_cutoff2)) &
      loopMask;

    // Load LJ types
    const __m512i type_i = _mm512_mask_i32gather_epi32(
      _mm512_undefined_epi32(), r2mask, i, tiles_p0->vdwTypes, _MM_SCALE_4);
    const __m512i type_j = _mm512_mask_i32gather_epi32(
      _mm512_undefined_epi32(), r2mask, j, tiles_p1->vdwTypes, _MM_SCALE_4);

    // Load eps and sigma
    __m512 eps4_i14, sigma_i14, eps4_j14, sigma_j14;
    if (interpMode == 1) {
      const __m512 t0 = (__m512)_mm512_mask_i32logather_pd(
        _mm512_undefined_pd(), r2mask, type_i, eps4sigma14, _MM_SCALE_8);
      const __m512i type_i2 = _mm512_shuffle_i32x4(type_i, type_i, 238);
      const __mmask16 r2mask2 = r2mask >> 8;
      const __m512 t1 = (__m512)_mm512_mask_i32logather_pd(
        _mm512_undefined_pd(), r2mask2, type_i2, eps4sigma14, _MM_SCALE_8);
      const __m512 t10 = (__m512)_mm512_mask_i32logather_pd(
        _mm512_undefined_pd(), r2mask, type_j, eps4sigma14, _MM_SCALE_8);
      const __m512i type_j2 = _mm512_shuffle_i32x4(type_j, type_j, 238);
      const __m512 t11 = (__m512)_mm512_mask_i32logather_pd(
        _mm512_undefined_pd(), r2mask2, type_j2, eps4sigma14, _MM_SCALE_8);

      sigma_i14 = _mm512_permutex2var_ps(t0, t4, t1);
      sigma_j14 = _mm512_permutex2var_ps(t10, t4, t11);
      eps4_i14 = _mm512_permutex2var_ps(t0, t6, t1);
      eps4_j14 = _mm512_permutex2var_ps(t10, t6, t11);
    }

    // Force, Energy, Virial calculation
    __m512 force, forceSlow;
    if (interpMode == 1)
      forceEnergyInterp1<doEnergy, doSlow>(r2, kqq, force, forceSlow,
          energyVdw, energyElec, energySlow, r2mask, _paramC1, _paramC3,
          _paramSwitchOn2, _cutoff2, _paramMinvCut3, _paramCutUnder3,
          mod_table, mode_table, eps4_i14, eps4_j14, sigma_i14, sigma_j14);
    else
      forceEnergyInterp2<doEnergy, doSlow, interpMode>(r2, kqq, type_i, type_j,
          force, forceSlow, energyVdw, energyElec, energySlow, r2mask,
          _paramScale, _paramC1, _paramC3, _paramSwitchOn2, _cutoff2,
          _paramMinvCut3, _paramCutUnder3, fastTable, energyTable,
          mod_table, mode_table, ljTable14, _paramLjWidth);

    // force_i_. = d. * force
    const __m512 force_i_x = _mm512_mul_ps(dx, force);
    const __m512 force_i_y = _mm512_mul_ps(dy, force);
    const __m512 force_i_z = _mm512_mul_ps(dz, force);
    __m512 forceSlow_i_x, forceSlow_i_y, forceSlow_i_z;
    if (doSlow) {
      // forceSlow_i_. = d. * forceSlow
      forceSlow_i_x = _mm512_mul_ps(dx, forceSlow);
      forceSlow_i_y = _mm512_mul_ps(dy, forceSlow);
      forceSlow_i_z = _mm512_mul_ps(dz, forceSlow);
    }

    if (doVirial) {
      // fNet_. -= force_i_.
      fNet_x = _mm512_mask_sub_ps(fNet_x, r2mask, fNet_x, force_i_x);
      fNet_y = _mm512_mask_sub_ps(fNet_y, r2mask, fNet_y, force_i_y);
      fNet_z = _mm512_mask_sub_ps(fNet_z, r2mask, fNet_z, force_i_z);
      if (doSlow) {
        // fNetSlow_. -= forceSlow_i_.
        fNetSlow_x = _mm512_mask_sub_ps(fNetSlow_x, r2mask,
                                        fNetSlow_x, forceSlow_i_x);
        fNetSlow_y = _mm512_mask_sub_ps(fNetSlow_y, r2mask,
                                        fNetSlow_y, forceSlow_i_y);
        fNetSlow_z = _mm512_mask_sub_ps(fNetSlow_z, r2mask,
                                        fNetSlow_z, forceSlow_i_z);
      }
    }

    #pragma novector
    for (int z = 0; z < 16; z++) {
      // Skip if outside cutoff or remainder
      if (!(r2mask & 1)) {
        r2mask >>= 1;
        continue;
      }
      r2mask >>= 1;

      // WMB: Might be better to check if next i is same and update in mem once

      // Accumulate i and j forces in memory
      const int i_z = *((int*)&i + z);
      const int j_z = *((int*)&j + z);
      forces_i[i_z].x -= *((float*)&force_i_x + z);
      forces_i[i_z].y -= *((float*)&force_i_y + z);
      forces_i[i_z].z -= *((float*)&force_i_z + z);
      forces_j[j_z].x += *((float*)&force_i_x + z);
      forces_j[j_z].y += *((float*)&force_i_y + z);
      forces_j[j_z].z += *((float*)&force_i_z + z);
      if (doSlow) {
        forcesSlow_i[i_z].x -= *((float*)&forceSlow_i_x + z);
        forcesSlow_i[i_z].y -= *((float*)&forceSlow_i_y + z);
        forcesSlow_i[i_z].z -= *((float*)&forceSlow_i_z + z);
        forcesSlow_j[j_z].x += *((float*)&forceSlow_i_x + z);
        forcesSlow_j[j_z].y += *((float*)&forceSlow_i_y + z);
        forcesSlow_j[j_z].z += *((float*)&forceSlow_i_z + z);
      }
    }
  }
} // nbAVX512Modified()

//---------------------------------------------------------------------------
// Calculations for excluded pairs
//---------------------------------------------------------------------------

template <bool doEnergy, bool doVirial>
void AVXTileLists::nbAVX512Excluded(__m512 &energySlow, __m512 &fNetSlow_x,
                                    __m512 &fNetSlow_y, __m512 &fNetSlow_z) {

  const AVXTiles::AVXTilesAtom* __restrict__ xyzq_i = tiles_p0->atoms;
  const AVXTiles::AVXTilesAtom* __restrict__ xyzq_j = tiles_p1->atoms;
  AVXTiles::AVXTilesForce* __restrict__ forcesSlow_i = tiles_p0->forcesSlow;
  AVXTiles::AVXTilesForce* __restrict__ forcesSlow_j = tiles_p1->forcesSlow;

  const float * __restrict__ exclTable = _paramExcludedTable;
  const float * __restrict__ exclEtable = _paramExcludedEnergyTable;

  __mmask16 loopMask = 0xFFFF;
  for (int mv = 0; mv < _numExcluded; mv += 16) {
    // Remainder predication
    if (_numExcluded - mv < 16)
      loopMask >>= (16 - (_numExcluded - mv));
    // Load i and j indices for pairs on the fully excluded list
    const __m512i i = _mm512_loadu_epi32(_excluded_i + mv);
    const __m512i j = _mm512_loadu_epi32(_excluded_j + mv);

    // Load i atom data and shift coordinates
    const __m512i it2 = _mm512_slli_epi32(i, 1);
    const __m512 t0 = (__m512)_mm512_mask_i32logather_pd(
      _mm512_undefined_pd(), loopMask, it2, (float *)xyzq_i, _MM_SCALE_8);
    const __m512 t8 = (__m512)_mm512_mask_i32logather_pd(
      _mm512_undefined_pd(), loopMask, it2, (float *)(xyzq_i) + 2,
      _MM_SCALE_8);
    const __m512i it2_2 = _mm512_shuffle_i32x4(it2, it2, 238);
    const __mmask16 loopMask2 = loopMask >> 8;
    const __m512 t1 = (__m512)_mm512_mask_i32logather_pd(
      _mm512_undefined_pd(), loopMask2, it2_2, (float *)xyzq_i, _MM_SCALE_8);
    const __m512 t9 = (__m512)_mm512_mask_i32logather_pd(
      _mm512_undefined_pd(), loopMask2, it2_2, (float *)(xyzq_i) + 2,
      _MM_SCALE_8);
    const __m512i t4 = _mm512_set_epi32(31,29,27,25,23,21,19,17,
                                        15,13,11,9,7,5,3,1);
    const __m512 y_i = _mm512_add_ps(_mm512_permutex2var_ps(t0, t4, t1),
                                     _mm512_set1_ps(_shy));
    const __m512 q_i = _mm512_permutex2var_ps(t8, t4, t9);
    const __m512i t6 = _mm512_set_epi32(30,28,26,24,22,20,18,16,
                                        14,12,10,8,6,4,2,0);
    const __m512 x_i = _mm512_add_ps(_mm512_permutex2var_ps(t0, t6, t1),
                                     _mm512_set1_ps(_shx));
    const __m512 z_i = _mm512_add_ps(_mm512_permutex2var_ps(t8, t6, t9),
                                     _mm512_set1_ps(_shz));

    // Load j atom data
    const __m512i jt2 = _mm512_slli_epi32(j, 1);
    const __m512 t10 = (__m512)_mm512_mask_i32logather_pd(
      _mm512_undefined_pd(), loopMask, jt2, (float *)xyzq_j, _MM_SCALE_8);
    const __m512 t18 = (__m512)_mm512_mask_i32logather_pd(
      _mm512_undefined_pd(), loopMask, jt2, (float *)(xyzq_j) + 2,
      _MM_SCALE_8);
    const __m512i jt2_2 = _mm512_shuffle_i32x4(jt2, jt2, 238);
    const __m512 t11 = (__m512)_mm512_mask_i32logather_pd(
      _mm512_undefined_pd(), loopMask2, jt2_2, (float *)xyzq_j, _MM_SCALE_8);
    const __m512 t19 = (__m512)_mm512_mask_i32logather_pd(
      _mm512_undefined_pd(), loopMask2, jt2_2, (float *)(xyzq_j) + 2,
      _MM_SCALE_8);
    const __m512 y_j = _mm512_permutex2var_ps(t10, t4, t11);
    const __m512 q_j = _mm512_permutex2var_ps(t18, t4, t19);
    const __m512 x_j = _mm512_permutex2var_ps(t10, t6, t11);
    const __m512 z_j = _mm512_permutex2var_ps(t18, t6, t19);

    // kqq = q_i * q_j
    const __m512 kqq = _mm512_mul_ps(q_i, q_j);

    // dx = x_i - x_j;
    const __m512 dx = _mm512_sub_ps(x_i, x_j);
    // dy = y_i - y_j;
    const __m512 dy = _mm512_sub_ps(y_i, y_j);
    // dz = z_i - z_j;
    const __m512 dz = _mm512_sub_ps(z_i, z_j);
    // r2 = dx*dx + dy*dy + dz*dz;
    const __m512 r2(_mm512_fmadd_ps(dx,dx,_mm512_fmadd_ps(dy,dy,
                    _mm512_mul_ps(dz, dz))));
    // Atoms within cutoff
    __mmask16 r2mask = _mm512_cmple_ps_mask(r2, _mm512_set1_ps(_cutoff2)) &
      loopMask;

    // Force, Energy, Virial calculation
    const __m512 r_1 = _mm512_invsqrt_ps(r2);
    __m512 forceSlow, tableDiff, rTableDiff;
    __m512i table_int;
    getOmpSimdTableI(r_1, table_int, tableDiff, rTableDiff);
    forceEnergySlow512<doEnergy>(r2mask, kqq, exclTable, exclEtable,
                                 table_int, tableDiff, rTableDiff, forceSlow,
                                 energySlow);
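
    // Note: fully excluded pairs still need this pass when doSlow is set
    // because the long-range (PME) sum includes every charge pair; the
    // table lookup above applies the correction term for the excluded
    // pair inside the cutoff.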
1377 
1378  // forceSlow_i_. = d. * forceSlow
1379  const __m512 forceSlow_i_x = _mm512_mul_ps(dx, forceSlow);
1380  const __m512 forceSlow_i_y = _mm512_mul_ps(dy, forceSlow);
1381  const __m512 forceSlow_i_z = _mm512_mul_ps(dz, forceSlow);
1382  // fNetSlow_. -= forceSlow_i_.
1383  fNetSlow_x = _mm512_mask_sub_ps(fNetSlow_x, r2mask,
1384  fNetSlow_x, forceSlow_i_x);
1385  fNetSlow_y = _mm512_mask_sub_ps(fNetSlow_y, r2mask,
1386  fNetSlow_y, forceSlow_i_y);
1387  fNetSlow_z = _mm512_mask_sub_ps(fNetSlow_z, r2mask,
1388  fNetSlow_z, forceSlow_i_z);
1389 
1390  for (int z = 0; z < 16; z++) {
1391  // Skip if outside cutoff or remainder
1392  if (!(r2mask & 1)) {
1393  r2mask >>= 1;
1394  continue;
1395  }
1396  r2mask >>= 1;
1397 
1398  // WMB: Might be better to check if next i is same and update in mem once
1399  const int i_z = *((int*)&i + z);
1400  const int j_z = *((int*)&j + z);
1401  forcesSlow_i[i_z].x -= *((float*)&forceSlow_i_x + z);
1402  forcesSlow_i[i_z].y -= *((float*)&forceSlow_i_y + z);
1403  forcesSlow_i[i_z].z -= *((float*)&forceSlow_i_z + z);
1404  forcesSlow_j[j_z].x += *((float*)&forceSlow_i_x + z);
1405  forcesSlow_j[j_z].y += *((float*)&forceSlow_i_y + z);
1406  forcesSlow_j[j_z].z += *((float*)&forceSlow_i_z + z);
1407  } // for z
1408  }
1409 }
1410 
1411 template <bool doEnergy, bool doVirial, bool doSlow, bool doList,
1412  int interpMode>
1413 void AVXTileLists::doAll() {
1414 
1415  if (doList) build();
1416 
1417  __m512 energyVdw, energyElec, energySlow;
1418  __m512 fNet_x, fNet_y, fNet_z;
1419  __m512 fNetSlow_x, fNetSlow_y, fNetSlow_z;
1420 
1421  // Zero energy and virial vectors used in all routines
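      // These accumulators are carried by reference through every kernel
      // invoked below and reduced to scalars only once, at the end of this
      // routine.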
1422  if (doEnergy) {
1423  energyVdw = _mm512_setzero_ps();
1424  energyElec = _mm512_setzero_ps();
1425  if (doSlow) energySlow = _mm512_setzero_ps();
1426  }
1427  if (doVirial) {
1428  fNet_x = _mm512_setzero_ps();
1429  fNet_y = _mm512_setzero_ps();
1430  fNet_z = _mm512_setzero_ps();
1431  if (doSlow) {
1432  fNetSlow_x = _mm512_setzero_ps();
1433  fNetSlow_y = _mm512_setzero_ps();
1434  fNetSlow_z = _mm512_setzero_ps();
1435  }
1436  }
1437 
1438  // Calculations from tile lists
1439  nbAVX512Tiles<doEnergy, doVirial, doSlow, doList,
1440  interpMode>(energyVdw, energyElec, energySlow, fNet_x, fNet_y,
1441  fNet_z, fNetSlow_x, fNetSlow_y, fNetSlow_z);
1442 
1443  if (doList) delEmptyLists();
1444 
1445  // Calculations from pair lists
1446  if (NAMD_AVXTILES_PAIR_THRESHOLD > 0 && _numPairLists)
1447  nbAVX512Pairs<doEnergy, doVirial, doSlow,
1448  interpMode>(energyVdw, energyElec, energySlow, fNet_x,
1449  fNet_y, fNet_z, fNetSlow_x, fNetSlow_y,
1450  fNetSlow_z);
1451 
1452  // Calculations for modified exclusions
1453  if (_numModified)
1454  nbAVX512Modified<doEnergy, doVirial, doSlow,
1455  interpMode>(energyVdw, energyElec, energySlow, fNet_x,
1456  fNet_y, fNet_z, fNetSlow_x, fNetSlow_y,
1457  fNetSlow_z);
1458 
1459  // Calculations for full exclusions
1460  if (doSlow && _numExcluded)
1461  nbAVX512Excluded<doEnergy, doVirial>(energySlow, fNetSlow_x,
1462  fNetSlow_y, fNetSlow_z);
1463 
1464  // Reduce energy vectors into scalar
1465  if (doEnergy) {
1466  _energyVdw = _mm512_reduce_add_ps(energyVdw);
1467  if (interpMode > 1) _energyVdw *= _paramScale * 0.5f;
1468  _energyElec = _mm512_reduce_add_ps(energyElec);
1469  if (doSlow) _energySlow = _mm512_reduce_add_ps(energySlow);
1470  }
1471 
1472  // Reduce virial vectors into scalars
1473  if (doVirial) {
1474  _fNet_x = _mm512_reduce_add_ps(fNet_x);
1475  _fNet_y = _mm512_reduce_add_ps(fNet_y);
1476  _fNet_z = _mm512_reduce_add_ps(fNet_z);
1477  if (doSlow) {
1478  _fNetSlow_x = _mm512_reduce_add_ps(fNetSlow_x);
1479  _fNetSlow_y = _mm512_reduce_add_ps(fNetSlow_y);
1480  _fNetSlow_z = _mm512_reduce_add_ps(fNetSlow_z);
1481  }
1482  }
1483 
1484  // Touch the "tiles" data structures on the patches to mark them as
1485  // needing reduction into native NAMD data and zeroing before reuse
1486  if (doList) {
1487  if (_numLists || _numModified || _numExcluded ||
1488  (NAMD_AVXTILES_PAIR_THRESHOLD > 0 && _numPairLists)) {
1489  tiles_p0->touch();
1490  tiles_p1->touch();
1491  }
1492  } else {
1493  tiles_p0->touch();
1494  tiles_p1->touch();
1495  }
1496 }
1497 
1498 //---------------------------------------------------------------------------
1499 
1500 void AVXTileLists::nbForceAVX512(const int doEnergy, const int doVirial,
1501  const int doSlow, const int doList) {
1502 
1503 #define CALL(DOENERGY, DOVIRIAL, DOSLOW, DOPAIRLIST, INTERPMODE) \
1504  doAll<DOENERGY, DOVIRIAL, DOSLOW, DOPAIRLIST, INTERPMODE>();
1505 
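  // The branch chains below turn the four runtime flags plus the
  // interpolation mode into compile-time template arguments, so each of
  // the 16 x 3 = 48 combinations instantiates a separate, fully
  // specialized kernel with no per-pair feature branching.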
1506  if (_interpolationMode == 1) {
1507  if (!doEnergy && !doVirial && !doSlow && !doList) CALL(0, 0, 0, 0, 1);
1508  if (!doEnergy && !doVirial && doSlow && !doList) CALL(0, 0, 1, 0, 1);
1509  if (!doEnergy && doVirial && !doSlow && !doList) CALL(0, 1, 0, 0, 1);
1510  if (!doEnergy && doVirial && doSlow && !doList) CALL(0, 1, 1, 0, 1);
1511  if ( doEnergy && !doVirial && !doSlow && !doList) CALL(1, 0, 0, 0, 1);
1512  if ( doEnergy && !doVirial && doSlow && !doList) CALL(1, 0, 1, 0, 1);
1513  if ( doEnergy && doVirial && !doSlow && !doList) CALL(1, 1, 0, 0, 1);
1514  if ( doEnergy && doVirial && doSlow && !doList) CALL(1, 1, 1, 0, 1);
1515 
1516  if (!doEnergy && !doVirial && !doSlow && doList) CALL(0, 0, 0, 1, 1);
1517  if (!doEnergy && !doVirial && doSlow && doList) CALL(0, 0, 1, 1, 1);
1518  if (!doEnergy && doVirial && !doSlow && doList) CALL(0, 1, 0, 1, 1);
1519  if (!doEnergy && doVirial && doSlow && doList) CALL(0, 1, 1, 1, 1);
1520  if ( doEnergy && !doVirial && !doSlow && doList) CALL(1, 0, 0, 1, 1);
1521  if ( doEnergy && !doVirial && doSlow && doList) CALL(1, 0, 1, 1, 1);
1522  if ( doEnergy && doVirial && !doSlow && doList) CALL(1, 1, 0, 1, 1);
1523  if ( doEnergy && doVirial && doSlow && doList) CALL(1, 1, 1, 1, 1);
1524  } else if (_interpolationMode == 2) {
1525  if (!doEnergy && !doVirial && !doSlow && !doList) CALL(0, 0, 0, 0, 2);
1526  if (!doEnergy && !doVirial && doSlow && !doList) CALL(0, 0, 1, 0, 2);
1527  if (!doEnergy && doVirial && !doSlow && !doList) CALL(0, 1, 0, 0, 2);
1528  if (!doEnergy && doVirial && doSlow && !doList) CALL(0, 1, 1, 0, 2);
1529  if ( doEnergy && !doVirial && !doSlow && !doList) CALL(1, 0, 0, 0, 2);
1530  if ( doEnergy && !doVirial && doSlow && !doList) CALL(1, 0, 1, 0, 2);
1531  if ( doEnergy && doVirial && !doSlow && !doList) CALL(1, 1, 0, 0, 2);
1532  if ( doEnergy && doVirial && doSlow && !doList) CALL(1, 1, 1, 0, 2);
1533 
1534  if (!doEnergy && !doVirial && !doSlow && doList) CALL(0, 0, 0, 1, 2);
1535  if (!doEnergy && !doVirial && doSlow && doList) CALL(0, 0, 1, 1, 2);
1536  if (!doEnergy && doVirial && !doSlow && doList) CALL(0, 1, 0, 1, 2);
1537  if (!doEnergy && doVirial && doSlow && doList) CALL(0, 1, 1, 1, 2);
1538  if ( doEnergy && !doVirial && !doSlow && doList) CALL(1, 0, 0, 1, 2);
1539  if ( doEnergy && !doVirial && doSlow && doList) CALL(1, 0, 1, 1, 2);
1540  if ( doEnergy && doVirial && !doSlow && doList) CALL(1, 1, 0, 1, 2);
1541  if ( doEnergy && doVirial && doSlow && doList) CALL(1, 1, 1, 1, 2);
1542  } else {
1543  if (!doEnergy && !doVirial && !doSlow && !doList) CALL(0, 0, 0, 0, 3);
1544  if (!doEnergy && !doVirial && doSlow && !doList) CALL(0, 0, 1, 0, 3);
1545  if (!doEnergy && doVirial && !doSlow && !doList) CALL(0, 1, 0, 0, 3);
1546  if (!doEnergy && doVirial && doSlow && !doList) CALL(0, 1, 1, 0, 3);
1547  if ( doEnergy && !doVirial && !doSlow && !doList) CALL(1, 0, 0, 0, 3);
1548  if ( doEnergy && !doVirial && doSlow && !doList) CALL(1, 0, 1, 0, 3);
1549  if ( doEnergy && doVirial && !doSlow && !doList) CALL(1, 1, 0, 0, 3);
1550  if ( doEnergy && doVirial && doSlow && !doList) CALL(1, 1, 1, 0, 3);
1551 
1552  if (!doEnergy && !doVirial && !doSlow && doList) CALL(0, 0, 0, 1, 3);
1553  if (!doEnergy && !doVirial && doSlow && doList) CALL(0, 0, 1, 1, 3);
1554  if (!doEnergy && doVirial && !doSlow && doList) CALL(0, 1, 0, 1, 3);
1555  if (!doEnergy && doVirial && doSlow && doList) CALL(0, 1, 1, 1, 3);
1556  if ( doEnergy && !doVirial && !doSlow && doList) CALL(1, 0, 0, 1, 3);
1557  if ( doEnergy && !doVirial && doSlow && doList) CALL(1, 0, 1, 1, 3);
1558  if ( doEnergy && doVirial && !doSlow && doList) CALL(1, 1, 0, 1, 3);
1559  if ( doEnergy && doVirial && doSlow && doList) CALL(1, 1, 1, 1, 3);
1560  }
1561 
1562 #undef CALL
1563 }
1564 
1565 #endif // NAMD_AVXTILES