NAMD
AVXTileLists.h
Go to the documentation of this file.
1 #ifndef AVXTILELISTS_H
2 #define AVXTILELISTS_H
3 
4 #include "AVXTiles.h"
5 #include "Lattice.h"
6 
7 #ifdef NAMD_AVXTILES
8 #include <immintrin.h>
9 
10 //
11 // Replace missing macros and intrinsics when building with gcc/clang
12 //
13 #if defined __INTEL_COMPILER || defined __INTEL_LLVM_COMPILER
14 
15 #define FORCEINLINE __forceinline
16 
17 #else // fixes for non-Intel compilers
18 
19 #define FORCEINLINE inline
20 
// Fallback definitions for names that the Intel compilers provide but that
// may be missing elsewhere. Each one is defined only when it is not already
// available as a macro.
#if !defined(_MM_SCALE_8)
#  define _MM_SCALE_8 8
#endif

#if !defined(_MM_SCALE_4)
#  define _MM_SCALE_4 4
#endif

// Route the Intel-style population count name to the SSE4.2-era intrinsic.
#if !defined(_popcnt32)
#  define _popcnt32 _mm_popcnt_u32
#endif
32 
// _mm512_invsqrt_ps comes from SVML and is unavailable here, so provide our
// own version: start from the 14-bit rsqrt14 approximation x0 of 1/sqrt(a)
// and perform one Newton-Raphson iteration,
//   x1 = 0.5 * x0 * (3 - a * x0 * x0),
// to reach approximately full single precision.
// Only intrinsics are used (no GNU vector-extension operators on __m512) so
// this fallback does not depend on compiler-specific operator support.
// NOTE(review): for a == 0 the initial estimate is +inf and the refinement
// evaluates 0 * inf, so the result is presumably NaN rather than +inf --
// confirm that callers never pass zero.
static inline __m512 _mm512_invsqrt_ps(__m512 a){
  // Initial low-precision estimate of 1/sqrt(a).
  const __m512 x0 = _mm512_rsqrt14_ps(a);
  // Newton-Raphson correction factor (3 - a * x0^2).
  const __m512 corr = _mm512_sub_ps(_mm512_set1_ps(3.0f),
                                    _mm512_mul_ps(a, _mm512_mul_ps(x0, x0)));
  return _mm512_mul_ps(_mm512_set1_ps(0.5f), _mm512_mul_ps(x0, corr));
}
41 
// Fallbacks for the "i32lo" double-precision gather/scatter intrinsics, which
// take a 512-bit index vector but consume only its low eight 32-bit indices.
// Each maps onto the corresponding intrinsic that takes a 256-bit index
// vector. The low half is obtained with _mm512_castsi512_si256, which needs
// only AVX512F and generates no instruction (extracting lane 0 with
// _mm512_extracti32x8_epi32 yields the same value but requires AVX512DQ).
#ifndef _mm512_i32logather_pd
#define _mm512_i32logather_pd(vindex, base_addr, scale) \
  _mm512_i32gather_pd(_mm512_castsi512_si256((vindex)), (base_addr), (scale))
#endif

#ifndef _mm512_mask_i32logather_pd
#define _mm512_mask_i32logather_pd(src, mask, vindex, base_addr, scale) \
  _mm512_mask_i32gather_pd((src), (mask), _mm512_castsi512_si256((vindex)), \
                           (base_addr), (scale))
#endif

#ifndef _mm512_i32loscatter_pd
#define _mm512_i32loscatter_pd(base_addr, vindex, a, scale) \
  _mm512_i32scatter_pd((base_addr), _mm512_castsi512_si256((vindex)), (a), \
                       (scale))
#endif

#ifndef _mm512_mask_i32loscatter_pd
#define _mm512_mask_i32loscatter_pd(base_addr, k, vindex, a, scale) \
  _mm512_mask_i32scatter_pd((base_addr), (k), \
                            _mm512_castsi512_si256((vindex)), (a), (scale))
#endif
57 
58 #endif // fixes for non-Intel compilers
59 
60 // --------------------------------------------------------------------------
61 // Preprocessor defines for hybrid tiles/pair list
62 // - Disable with NAMD_AVXTILES_PAIR_THRESHOLD=0 and no ORDER_PATCHES
63 // --------------------------------------------------------------------------
64 // Threshold in number of atoms in i-tile to switch to pair list
65 #define NAMD_AVXTILES_PAIR_THRESHOLD 4
66 // Initial allocation size for number of neighbors per atom in pair list
67 #define NAMD_AVXTILES_IPAIRCOUNT 300
68 // Order patch pair so first patch is one with smallest # atoms in last tile
69 #define NAMD_AVXTILES_ORDER_PATCHES
70 // --------------------------------------------------------------------------
71 
// Holds the neighbor ("j") tile data for all atoms in a patch, packed into
// parallel arrays.
class AVXJTiles {
 public:
  AVXJTiles();
  ~AVXJTiles();

  // Tile count most recently requested through realloc().
  inline int numTiles() const { return _numTiles; }
  // Value of the allocation counter that realloc() compares against.
  inline int maxTiles() const { return _numTilesAlloc; }

  // Record n as the current number of tiles and call _realloc() when n
  // exceeds the allocated count.
  // Returns true if _realloc() was called, false otherwise.
  inline bool realloc(const int n) {
    // Decide before calling _realloc(), which may update the allocation.
    const bool mustGrow = n > _numTilesAlloc;
    _numTiles = n;
    if (mustGrow) _realloc();
    return mustGrow;
  }

  // Per-tile bitwise exclusion data
  unsigned int *excl;
  // Per-tile starting index of the atoms in the neighbor tile
  int *atomStart;
  // Per-tile status, used when deleting empty tile lists on build steps
  int *status;

 private:
  int _numTiles;
  int _numTilesAlloc;
  void _realloc();
};
100 
101 
102 // Data and routines for storing/building tile lists and computing forces,
103 // virials, and energies from this data for patch pairs (including self).
104 // - Modified and excluded pairs are also processed here with separate loops
105 class AVXTileLists {
106  public:
107  struct List {
108  int atomStart_i;
109  int jtileStart;
110  };
111 
112  AVXTileLists();
113  ~AVXTileLists();
114 
115  // Simulation parameters are passed to be explicit about data used for
116  // tiles algorithm. Interpolation mode is set and documented in
117  // ComputeNonbondedUtil.
118  void setSimParams(const float scale, const float scale14, const float c1,
119  const float c3, const float switchOn2, float *fastTable,
120  float *fastEnergyTable, float *slowTable,
121  float *slowEnergyTable, float *eps4sigma,
122  float *eps4sigma14, float *ljTable,
123  const float ljTableWidth, float *modifiedTable,
124  float *modifiedEnergyTable, float *excludedTable,
125  float *excludedEnergyTable, const int interpolationMode);
126 
127  inline void atomUpdate(AVXTiles *patch0tiles, AVXTiles *patch1tiles) {
128  tiles_p0 = patch0tiles;
129  tiles_p1 = patch1tiles;
130 
131  // Patch reordering currently doesn't help perf unless using hybrid pairs
132  #ifdef NAMD_AVXTILES_ORDER_PATCHES
133  _patchOrder0 = 0;
134  _patchOrder1 = 1;
135  bool reorder = false;
136  const int rem0 = patch0tiles->numAtoms() & 15;
137  const int rem1 = patch1tiles->numAtoms() & 15;
138  if (rem1 && rem1 <= NAMD_AVXTILES_PAIR_THRESHOLD && rem1 < rem0)
139  reorder = true;
140  else if ((rem0 > NAMD_AVXTILES_PAIR_THRESHOLD || rem0 == 0) &&
141  patch1tiles->numAtoms() < patch0tiles->numAtoms())
142  reorder = true;
143  if (reorder) {
144  tiles_p0 = patch1tiles;
145  tiles_p1 = patch0tiles;
146  _patchOrder0 = 1;
147  _patchOrder1 = 0;
148  }
149  #endif
150 
151  realloc(tiles_p0->numTiles());
152  }
153 
154  inline void updateParams(const Lattice &lattice, const Vector &offset,
155  const double cutoff) {
156  _cutoff2 = cutoff * cutoff;
157  _paramMinvCut3 = -1.0 / (_cutoff2 * sqrt(_cutoff2));
158  _paramCutUnder3 = 3.0 / sqrt(_cutoff2);
159  _shx = offset.x*lattice.a().x + offset.y*lattice.b().x +
160  offset.z*lattice.c().x;
161  _shy = offset.x*lattice.a().y + offset.y*lattice.b().y +
162  offset.z*lattice.c().y;
163  _shz = offset.x*lattice.a().z + offset.y*lattice.b().z +
164  offset.z*lattice.c().z;
165  }
166 
167  inline void updateBuildInfo(const int step, const int minPart,
168  const int maxPart, const int numParts,
169  const double plcutoff) {
170  _lastBuild = step;
171  _minPart = minPart;
172  _maxPart = maxPart;
173  _numParts = numParts;
174  _plcutoff2 = plcutoff * plcutoff;
175  }
176 
177  inline int numLists() const { return _numLists; }
178  // Reallocate data for storing tile lists
179  inline void realloc(const int numLists) {
180  _numLists = numLists;
181  if (numLists > _numListsAlloc) _realloc();
182  }
183  // Reallocate data for storing modified pairs
184  inline void reallocModified(const int numModified) {
185  _numModified = numModified;
186  if (numModified > _numModifiedAlloc) _reallocModified();
187  }
188  // Reallocate data for storing excluded pairs
189  inline void reallocExcluded(const int numExcluded) {
190  _numExcluded = numExcluded;
191  if (numExcluded > _numExcludedAlloc) _reallocExcluded();
192  }
193  // Reallocate data for pair lists when using hybrid tile / pairlists
194  inline void reallocPairLists(const int numPairLists, const int maxPairs) {
195  if (numPairLists > _maxPairLists || maxPairs > _maxPairs)
196  _reallocPairLists(numPairLists, maxPairs);
197  }
198 
199  inline int exclusionChecksum() const { return _exclusionChecksum; }
200  inline float energyVdw() const { return _energyVdw; }
201  inline float energyElec() const { return _energyElec; }
202  inline float energySlow() const { return _energySlow; }
203  inline float virialXX() const { return _fNet_x * _shx; }
204  inline float virialXY() const { return _fNet_x * _shy; }
205  inline float virialXZ() const { return _fNet_x * _shz; }
206  inline float virialYY() const { return _fNet_y * _shy; }
207  inline float virialYZ() const { return _fNet_y * _shz; }
208  inline float virialZZ() const { return _fNet_z * _shz; }
209  inline float virialSlowXX() const { return _fNetSlow_x * _shx; }
210  inline float virialSlowXY() const { return _fNetSlow_x * _shy; }
211  inline float virialSlowXZ() const { return _fNetSlow_x * _shz; }
212  inline float virialSlowYY() const { return _fNetSlow_y * _shy; }
213  inline float virialSlowYZ() const { return _fNetSlow_y * _shz; }
214  inline float virialSlowZZ() const { return _fNetSlow_z * _shz; }
215 
216  #ifdef NAMD_AVXTILES_ORDER_PATCHES
217  inline int patchOrder0() const { return _patchOrder0; }
218  inline int patchOrder1() const { return _patchOrder1; }
219  #else
220  inline int patchOrder0() const { return 0; }
221  inline int patchOrder1() const { return 1; }
222  #endif
223 
224  // Build bounding boxes for tiles on both patches and build initial tile
225  // lists based on bounding boxes
226  // -- Paritioning for LB is based on number of neighbor tiles
227  void build();
228  // On build steps, delete any empty tile lists after refinement in force
229  // calculation based on atom distances.
230  void delEmptyLists();
231  // Calculate forces, virials, energies for tile lists, pair lists, and
232  // modified/excluded pairs
233  void nbForceAVX512(const int doEnergy, const int doVirial, const int doList,
234  const int doSlow);
235 
236  List *lists;
237  // Number of tile neighbors
238  unsigned int *listDepth;
239 
240  // Tiles data for each patch in pair
241  AVXTiles *tiles_p0, *tiles_p1;
242  // Neighbor tile data
243  AVXJTiles jTiles;
244 
245  private:
246  template <bool count, bool partitionMode>
247  int _buildBB();
248 
249  template <bool doEnergy, bool doVirial, bool doSlow,
250  bool doList, int interpMode>
251  FORCEINLINE void nbAVX512Tiles(__m512 &energyVdw, __m512 &energyElec,
252  __m512 &energySlow, __m512 &fNet_x,
253  __m512 &fNet_y, __m512 &fNet_z,
254  __m512 &fNetSlow_x, __m512 &fNetSlow_y,
255  __m512 &fNetSlow_z);
256  template <bool doEnergy, bool doVirial, bool doSlow, int interpMode>
257  FORCEINLINE void nbAVX512Pairs(__m512 &energyVdw, __m512 &energyElec,
258  __m512 &energySlow, __m512 &fNet_x,
259  __m512 &fNet_y, __m512 &fNet_z,
260  __m512 &fNetSlow_x, __m512 &fNetSlow_y,
261  __m512 &fNetSlow_z);
262  template <bool doEnergy, bool doVirial, bool doSlow, int interpMode>
263  inline void nbAVX512Modified(__m512 &energyVdw, __m512 &energyElec,
264  __m512 &energySlow, __m512 &fNet_x,
265  __m512 &fNet_y, __m512 &fNet_z,
266  __m512 &fNetSlow_x, __m512 &fNetSlow_y,
267  __m512 &fNetSlow_z);
268  template <bool doEnergy, bool doVirial>
269  inline void nbAVX512Excluded(__m512 &energySlow, __m512 &fNetSlow_x,
270  __m512 &fNetSlow_y, __m512 &fNetSlow_z);
271 
272  template <bool doEnergy, bool doVirial, bool doSlow,
273  bool doList, int interpMode>
274  void doAll();
275 
276  void _realloc();
277  void _reallocModified();
278  void _reallocExcluded();
279  void _reallocPairLists(const int numPairLists, const int maxPairs);
280 
281  float _cutoff2, _plcutoff2;
282 
283  float *_paramSlowTable, *_paramSlowEnergyTable;
284  // -------------- NOT USED WITH INTERPOLATION MODES 2 and 3
285  float *_paramEps4Sigma, *_paramEps4Sigma14;
286  // -------------- NOT USED WITH INTERPOLATION MODE 3
287  float _paramMinvCut3, _paramCutUnder3;
288  // --------------
289  float *_paramModifiedTable, *_paramModifiedEnergyTable;
290  float *_paramExcludedTable, *_paramExcludedEnergyTable;
291 
292  // -------------- NOT USED WITH INTERPOLATION MODES 1 and 2
293  float *_paramFastTable, *_paramFastEnergyTable;
294  const float *_paramLjTable;
295  int _paramLjWidth;
296  // --------------
297 
298  float _shx, _shy, _shz;
299  float _paramScale, _paramScale14, _paramC1, _paramC3, _paramSwitchOn2;
300  int _numLists, _numListsAlloc;
301  int _numModified, _numModifiedAlloc, _numExcluded, _numExcludedAlloc;
302  int *_modified_i, *_modified_j, *_excluded_i, *_excluded_j;
303 
304  int _numPairLists, _maxPairLists, _maxPairs;
305  int *_pair_i, *_numPairs, *_pairStart, *_pairLists;
306 
307  float _fNet_x, _fNet_y, _fNet_z, _fNetSlow_x, _fNetSlow_y, _fNetSlow_z;
308  int _exclusionChecksum;
309  float _energyVdw, _energyElec, _energySlow;
310 
311  int _interpolationMode, _minPart, _maxPart, _numParts, _lastBuild;
312 
313  #ifdef NAMD_AVXTILES_ORDER_PATCHES
314  int _patchOrder0, _patchOrder1;
315  #endif
316 
317  #ifndef MEM_OPT_VERSION
318  char * _exclFlyListBuffer;
319  char * _exclFlyLists[16];
320  const int32 * _fullExcl[16], * _modExcl[16];
321  int _lastFlyListTile;
322  const char * buildExclFlyList(const int itileList, const int z,
323  const __m512i &atomIndex_i, const int n,
324  void *mol);
325  #endif
326 };
327 
328 #endif // NAMD_AVXTILES
329 #endif // AVXTILELISTS_H
NAMD_HOST_DEVICE Vector c() const
Definition: Lattice.h:270
Definition: Vector.h:72
int32_t int32
Definition: common.h:38
BigReal z
Definition: Vector.h:74
BigReal x
Definition: Vector.h:74
NAMD_HOST_DEVICE Vector b() const
Definition: Lattice.h:269
BigReal y
Definition: Vector.h:74
NAMD_HOST_DEVICE Vector a() const
Definition: Lattice.h:268