ComputeRestraintsCUDAKernel.cu
#ifdef NAMD_CUDA
#if __CUDACC_VER_MAJOR__ >= 11
#include <cub/cub.cuh>
#else
#include <namd_cub/cub.cuh>
#endif
#else // NAMD_HIP
#include <hip/hip_runtime.h>
#include <hipcub/hipcub.hpp>
#define cub hipcub
#endif // end NAMD_CUDA vs. NAMD_HIP

#include "HipDefines.h"

#include "ComputeRestraintsCUDAKernel.h"

#ifdef NODEGROUP_FORCE_REGISTER


#define PI 3.141592653589793

// Host function to update the rotation matrix for the rotating-constraints option.
// The angle is given in degrees; the axis v is expected to be a unit vector (it is not
// normalized here).
void vec_rotation_matrix(double angle, double3 v, cudaTensor& m){

  double mag, s, c;
  double xs, ys, zs, one_c;
  s = sin(angle * PI/180.0);
  c = cos(angle * PI/180.0);
  xs = v.x * s;
  ys = v.y * s;
  zs = v.z * s;
  one_c = 1.0 - c;

  mag = sqrt(v.x*v.x + v.y*v.y + v.z*v.z);

  if (mag == 0.0) {
    // Degenerate axis: return a 3x3 identity matrix
    m.xx = 1.0;
    m.xy = 0.0;
    m.xz = 0.0;
    m.yx = 0.0;
    m.yy = 1.0;
    m.yz = 0.0;
    m.zx = 0.0;
    m.zy = 0.0;
    m.zz = 1.0;
    return;
  }

  m.xx = (one_c * (v.x * v.x)) + c;
  m.xy = (one_c * (v.x * v.y)) - zs;
  m.xz = (one_c * (v.z * v.x)) + ys;

  m.yx = (one_c * (v.x * v.y)) + zs;
  m.yy = (one_c * (v.y * v.y)) + c;
  m.yz = (one_c * (v.y * v.z)) - xs;

  m.zx = (one_c * (v.z * v.x)) - ys;
  m.zy = (one_c * (v.y * v.z)) + xs;
  m.zz = (one_c * (v.z * v.z)) + c;
}
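
// For reference: with a unit axis k = (v.x, v.y, v.z) and angle t (in radians after the
// degree conversion above), the matrix assembled here is the standard axis-angle
// (Rodrigues) rotation matrix
//
//   R = cos(t) * I + sin(t) * [k]_x + (1 - cos(t)) * k k^T,
//
// where [k]_x is the skew-symmetric cross-product matrix of k: the xs/ys/zs terms are the
// sin(t) * k_i entries and the one_c products are the (1 - cos(t)) * k_i * k_j entries.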


template<bool T_DOENERGY>
__global__ void computeRestrainingForceKernel(
  const int currentTime,
  const int nConstrainedAtoms,
  const int consExp,
  const double consScaling,
  const bool movConsOn,
  const bool rotConsOn,
  const bool selConsOn,
  const bool spheConsOn,
  const bool consSelectX,
  const bool consSelectY,
  const bool consSelectZ,
  const double rotVel,
  const double3 rotAxis,
  const double3 rotPivot,
  const double3 moveVel,
  const double3 spheConsCenter,
  const int* __restrict d_constrainedSOA,
  const int* __restrict d_constrainedID,
  const double* __restrict d_pos_x,
  const double* __restrict d_pos_y,
  const double* __restrict d_pos_z,
  const double* __restrict d_k,
  const double* __restrict d_cons_x,
  const double* __restrict d_cons_y,
  const double* __restrict d_cons_z,
  double* __restrict f_normal_x,
  double* __restrict f_normal_y,
  double* __restrict f_normal_z,
  double* __restrict d_bcEnergy,
  double* __restrict h_bcEnergy,
  double3* __restrict d_netForce,
  double3* __restrict h_netForce,
  cudaTensor* __restrict d_virial,
  cudaTensor* __restrict h_virial,
  const Lattice lat,
  unsigned int* __restrict tbcatomic,
  cudaTensor rotMatrix
)
{

  int tid = threadIdx.x + (blockIdx.x * blockDim.x);

  int totaltb = gridDim.x;
  // Shared so that every thread in the block sees the flag set by thread 0
  __shared__ bool isLastBlockDone;

  if(threadIdx.x == 0){
    isLastBlockDone = false;
  }

  __syncthreads();

  double energy = 0;
  double3 r_netForce = {0, 0, 0};
  cudaTensor r_virial;
  r_virial.xx = 0.0; r_virial.xy = 0.0; r_virial.xz = 0.0;
  r_virial.yx = 0.0; r_virial.yy = 0.0; r_virial.yz = 0.0;
  r_virial.zx = 0.0; r_virial.zy = 0.0; r_virial.zz = 0.0;

  if(tid < nConstrainedAtoms){

    // Index of the constrained atom in the SOA data structure
    int indexC = d_constrainedID[tid];
    int soaID = d_constrainedSOA[indexC];

    // Reference (constraint) position of this atom
    double ref_x = d_cons_x[indexC];
    double ref_y = d_cons_y[indexC];
    double ref_z = d_cons_z[indexC];

    // JM: BAD BAD BAD -> UNCOALESCED GLOBAL MEMORY ACCESS
    double pos_x = d_pos_x[soaID];
    double pos_y = d_pos_y[soaID];
    double pos_z = d_pos_z[soaID];


    double k = d_k[indexC];
    k *= consScaling;
    // We could store consScaling * k up front instead of doing this multiply every call

    if(movConsOn){
      // Moving constraints: the reference position drifts with constant velocity moveVel
      ref_x += currentTime * moveVel.x;
      ref_y += currentTime * moveVel.y;
      ref_z += currentTime * moveVel.z;
    }
    else if(rotConsOn){
      // Rotating constraints: rotate the reference position about rotPivot using the
      // precomputed rotation matrix (matrix-vector product)
      double rx = ref_x - rotPivot.x;
      double ry = ref_y - rotPivot.y;
      double rz = ref_z - rotPivot.z;

      ref_x = rotMatrix.xx * rx + rotMatrix.xy * ry + rotMatrix.xz * rz;
      ref_y = rotMatrix.yx * rx + rotMatrix.yy * ry + rotMatrix.yz * rz;
      ref_z = rotMatrix.zx * rx + rotMatrix.zy * ry + rotMatrix.zz * rz;
    }

    // END moving and rotating constraints

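    // Spherical constraints restrain the atom only in the radial direction: the reference
    // position is replaced by a point at the same radius |ref - center|, but along the ray
    // from spheConsCenter through the atom's current (lattice-wrapped) position.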
    if(spheConsOn){
      // JM: This code sucks, but maybe it's not a very common use-case, so let's go with it for now
      double3 diff;
      diff.x = ref_x - spheConsCenter.x;
      diff.y = ref_y - spheConsCenter.y;
      diff.z = ref_z - spheConsCenter.z;
      // Distance of the reference position from the spherical center
      double refRad = sqrt(diff.x * diff.x + diff.y*diff.y + diff.z * diff.z); // Whoops

      // Reusing diff here as relPos: first store (global position - spherical center)
      diff = lat.delta(Vector(pos_x, pos_y, pos_z), spheConsCenter);

      refRad *= rsqrt(diff.x * diff.x + diff.y * diff.y + diff.z * diff.z); // 2x-whoops
      // now recalculate refPos:
      ref_x = spheConsCenter.x + diff.x * refRad;
      ref_y = spheConsCenter.y + diff.y * refRad;
      ref_z = spheConsCenter.z + diff.z * refRad;
    }

    // Calculate the RIJ vector as lattice.delta(ref, pos)
    double3 rij;

    rij = lat.delta(Vector(ref_x, ref_y, ref_z), Vector(pos_x, pos_y, pos_z));
    double3 vpos;
    vpos.x = ref_x - rij.x;
    vpos.y = ref_y - rij.y;
    vpos.z = ref_z - rij.z;

    if(selConsOn){
      // Selective constraints: zero out the Cartesian components that are not restrained
      rij.x *= (1.0 * consSelectX);
      rij.y *= (1.0 * consSelectY);
      rij.z *= (1.0 * consSelectZ);
    }

    double r2 = rij.x * rij.x + rij.y*rij.y + rij.z*rij.z;
    double r = sqrt(r2); // 3x-whoops

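    // The restraint below is E = k * r^consExp, with k already scaled by consScaling.
    // The force applied to the atom is therefore
    //   F = consExp * k * r^(consExp - 2) * rij,
    // which points from the atom toward the (possibly moved/rotated/projected) reference
    // position, since rij = ref - pos.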
    if (r > 0.0){
      double value = k * (pow(r, consExp)); // NOTE: consExp is an int, so a small multiply loop might be cheaper
      if (T_DOENERGY) energy = value;
      value *= consExp;
      value /= r2;
      rij.x *= value;
      rij.y *= value;
      rij.z *= value;

      // JM: BAD BAD BAD -> UNCOALESCED GLOBAL MEMORY ACCESS
      f_normal_x[soaID] += rij.x;
      f_normal_y[soaID] += rij.y;
      f_normal_z[soaID] += rij.z;
      r_netForce.x = rij.x;
      r_netForce.y = rij.y;
      r_netForce.z = rij.z;

      // Now we calculate the virial contribution
      // JM: is this virial symmetrical?
      r_virial.xx = rij.x * vpos.x;
      r_virial.xy = rij.x * vpos.y;
      r_virial.xz = rij.x * vpos.z;
      r_virial.yx = rij.y * vpos.x;
      r_virial.yy = rij.y * vpos.y;
      r_virial.yz = rij.y * vpos.z;
      r_virial.zx = rij.z * vpos.x;
      r_virial.zy = rij.z * vpos.y;
      r_virial.zz = rij.z * vpos.z;
    }
  }

#if 1
  if(T_DOENERGY){
    // Reduce energy, net force, and virial across the thread block
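    // The same BlockReduce temporary storage is reused for every Sum() below; CUB requires
    // a barrier between consecutive reductions that share temp_storage, hence the
    // __syncthreads() after each call.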
    typedef cub::BlockReduce<double, 128> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    energy = BlockReduce(temp_storage).Sum(energy);
    __syncthreads();

    r_netForce.x = BlockReduce(temp_storage).Sum(r_netForce.x);
    __syncthreads();
    r_netForce.y = BlockReduce(temp_storage).Sum(r_netForce.y);
    __syncthreads();
    r_netForce.z = BlockReduce(temp_storage).Sum(r_netForce.z);
    __syncthreads();
    r_virial.xx = BlockReduce(temp_storage).Sum(r_virial.xx);
    __syncthreads();
    r_virial.xy = BlockReduce(temp_storage).Sum(r_virial.xy);
    __syncthreads();
    r_virial.xz = BlockReduce(temp_storage).Sum(r_virial.xz);
    __syncthreads();
    r_virial.yx = BlockReduce(temp_storage).Sum(r_virial.yx);
    __syncthreads();
    r_virial.yy = BlockReduce(temp_storage).Sum(r_virial.yy);
    __syncthreads();
    r_virial.yz = BlockReduce(temp_storage).Sum(r_virial.yz);
    __syncthreads();
    r_virial.zx = BlockReduce(temp_storage).Sum(r_virial.zx);
    __syncthreads();
    r_virial.zy = BlockReduce(temp_storage).Sum(r_virial.zy);
    __syncthreads();
    r_virial.zz = BlockReduce(temp_storage).Sum(r_virial.zz);
    __syncthreads();

    if(threadIdx.x == 0){
      atomicAdd(d_bcEnergy, energy);
      atomicAdd(&(d_netForce->x), r_netForce.x);
      atomicAdd(&(d_netForce->y), r_netForce.y);
      atomicAdd(&(d_netForce->z), r_netForce.z);

      atomicAdd(&(d_virial->xx), r_virial.xx);
      atomicAdd(&(d_virial->xy), r_virial.xy);
      atomicAdd(&(d_virial->xz), r_virial.xz);

      atomicAdd(&(d_virial->yx), r_virial.yx);
      atomicAdd(&(d_virial->yy), r_virial.yy);
      atomicAdd(&(d_virial->yz), r_virial.yz);

      atomicAdd(&(d_virial->zx), r_virial.zx);
      atomicAdd(&(d_virial->zy), r_virial.zy);
      atomicAdd(&(d_virial->zz), r_virial.zz);

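      // "Last block done" pattern: the __threadfence() below makes this block's atomicAdds
      // visible device-wide before the counter is incremented.  atomicInc returns how many
      // blocks finished before this one, so the block that sees totaltb - 1 is the last to
      // finish and copies the totals to host-mapped memory further down.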
      __threadfence();
      unsigned int value = atomicInc(tbcatomic, totaltb);
      isLastBlockDone = (value == (totaltb - 1));
    }
  }
#endif

  __syncthreads();

  if(isLastBlockDone){
    if(threadIdx.x == 0){
      // Update host-mapped memory with the reduced totals
      h_bcEnergy[0] = d_bcEnergy[0];
      h_netForce->x = d_netForce->x;
      h_netForce->y = d_netForce->y;
      h_netForce->z = d_netForce->z;

      h_virial->xx = d_virial->xx;
      h_virial->xy = d_virial->xy;
      h_virial->xz = d_virial->xz;

      h_virial->yx = d_virial->yx;
      h_virial->yy = d_virial->yy;
      h_virial->yz = d_virial->yz;

      h_virial->zx = d_virial->zx;
      h_virial->zy = d_virial->zy;
      h_virial->zz = d_virial->zz;

      // Reset the device-side accumulators and the block counter for the next invocation
      d_bcEnergy[0] = 0;
      d_netForce->x = 0;
      d_netForce->y = 0;
      d_netForce->z = 0;

      d_virial->xx = 0;
      d_virial->xy = 0;
      d_virial->xz = 0;

      d_virial->yx = 0;
      d_virial->yy = 0;
      d_virial->yz = 0;

      d_virial->zx = 0;
      d_virial->zy = 0;
      d_virial->zz = 0;

      tbcatomic[0] = 0;
      __threadfence();
    }
  }
}

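// Host-side wrapper that launches the kernel above.  Note that h_bcEnergy, h_netForce, and
// h_virial are written directly by the kernel, so they are expected to point to host-mapped
// (zero-copy) memory, and d_tbcatomic must start out zeroed; the kernel resets the device
// accumulators and the counter itself after each reduction.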
void computeRestrainingForce(
  const int doEnergy,
  const int doVirial,
  const int currentTime,
  const int nConstrainedAtoms,
  const int consExp,
  const double consScaling,
  const bool movConsOn,
  const bool rotConsOn,
  const bool selConsOn,
  const bool spheConsOn,
  const bool consSelectX,
  const bool consSelectY,
  const bool consSelectZ,
  const double rotVel,
  const double3 rotAxis,
  const double3 rotPivot,
  const double3 moveVel,
  const double3 spheConsCenter,
  const int* d_constrainedSOA,
  const int* d_constrainedID,
  const double* d_pos_x,
  const double* d_pos_y,
  const double* d_pos_z,
  const double* d_k,
  const double* d_cons_x,
  const double* d_cons_y,
  const double* d_cons_z,
  double* d_f_normal_x,
  double* d_f_normal_y,
  double* d_f_normal_z,
  double* d_bcEnergy,
  double* h_bcEnergy,
  double3* d_netForce,
  double3* h_netForce,
  const Lattice* lat,
  cudaTensor* d_virial,
  cudaTensor* h_virial,
  cudaTensor rotationMatrix,
  unsigned int* d_tbcatomic,
  cudaStream_t stream
){

  const int blocks = 128;  // threads per block
  const int grid = (nConstrainedAtoms + blocks - 1) / blocks;  // number of thread blocks

  // Compute the rotation matrix for this timestep on the host; the angle rotVel * currentTime
  // is in degrees.  Hopefully this is fast enough.
  vec_rotation_matrix(rotVel * currentTime, rotAxis, rotationMatrix);

  if(doEnergy || doVirial){
    computeRestrainingForceKernel<true><<<grid, blocks, 0, stream>>>(
      currentTime,
      nConstrainedAtoms,
      consExp,
      consScaling,
      movConsOn,
      rotConsOn,
      selConsOn,
      spheConsOn,
      consSelectX,
      consSelectY,
      consSelectZ,
      rotVel,
      rotAxis,
      rotPivot,
      moveVel,
      spheConsCenter,
      d_constrainedSOA,
      d_constrainedID,
      d_pos_x,
      d_pos_y,
      d_pos_z,
      d_k,
      d_cons_x,
      d_cons_y,
      d_cons_z,
      d_f_normal_x,
      d_f_normal_y,
      d_f_normal_z,
      d_bcEnergy,
      h_bcEnergy,
      d_netForce,
      h_netForce,
      d_virial,
      h_virial,
      *lat,
      d_tbcatomic,
      rotationMatrix);

  } else {
    computeRestrainingForceKernel<false><<<grid, blocks, 0, stream>>>(
      currentTime,
      nConstrainedAtoms,
      consExp,
      consScaling,
      movConsOn,
      rotConsOn,
      selConsOn,
      spheConsOn,
      consSelectX,
      consSelectY,
      consSelectZ,
      rotVel,
      rotAxis,
      rotPivot,
      moveVel,
      spheConsCenter,
      d_constrainedSOA,
      d_constrainedID,
      d_pos_x,
      d_pos_y,
      d_pos_z,
      d_k,
      d_cons_x,
      d_cons_y,
      d_cons_z,
      d_f_normal_x,
      d_f_normal_y,
      d_f_normal_z,
      d_bcEnergy,
      h_bcEnergy,
      d_netForce,
      h_netForce,
      d_virial,
      h_virial,
      *lat,
      d_tbcatomic,
      rotationMatrix);
  }
}
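
// A minimal usage sketch (illustrative only; all variable names below are hypothetical
// placeholders, not part of NAMD): the wrapper is intended to be called once per timestep
// from the host integration path.  The launch is asynchronous on `stream`, so the
// host-mapped outputs (h_bcEnergy, h_netForce, h_virial) are only valid after the stream
// has been synchronized.
//
//   computeRestrainingForce(
//       doEnergy, doVirial, step,
//       nConstrainedAtoms, consExp, consScaling,
//       movConsOn, rotConsOn, selConsOn, spheConsOn,
//       consSelectX, consSelectY, consSelectZ,
//       rotVel, rotAxis, rotPivot, moveVel, spheConsCenter,
//       d_constrainedSOA, d_constrainedID,
//       d_pos_x, d_pos_y, d_pos_z, d_k,
//       d_cons_x, d_cons_y, d_cons_z,
//       d_f_normal_x, d_f_normal_y, d_f_normal_z,
//       d_bcEnergy, h_bcEnergy, d_netForce, h_netForce,
//       &lattice, d_virial, h_virial,
//       rotationMatrix, d_tbcatomic, stream);
//   cudaStreamSynchronize(stream);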

#endif // NODEGROUP_FORCE_REGISTER