NAMD
ComputeSMDCUDAKernel.cu
#ifdef NAMD_CUDA
#if __CUDACC_VER_MAJOR__ >= 11
#include <cub/cub.cuh>
#else
#include <namd_cub/cub.cuh>
#endif
#endif

#ifdef NAMD_HIP
#include <hip/hip_runtime.h>
#include <hipcub/hipcub.hpp>
#define cub hipcub
#endif

#include "ComputeSMDCUDAKernel.h"
#include "ComputeCOMCudaKernel.h"
#include "HipDefines.h"

#ifdef NODEGROUP_FORCE_REGISTER


/*! Calculate the SMD force and virial for a large atom group (numSMDAtoms > 1024).
    Multiple thread blocks are launched to perform this operation.
    The current COM (curCOM) must already be computed and passed to this function. */
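/* The SMD bias applied here restrains the group's center of mass to a reference
   point that moves with constant velocity along `direction` (n):
     U = 0.5*k *(v*t - d.n)^2              restraint along the pulling direction
       + 0.5*k2*|d - (d.n)*n|^2            transverse restraint
   where d = R_com - R_orig. The group force computed below is -dU/dR_com, and it
   is distributed over the atoms in proportion to their masses. */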
template<bool T_DOENERGY>
__global__ void computeSMDForceWithCOMKernel(
  const int numSMDAtoms,
  const Lattice lat,
  const double inv_group_mass,
  const double k,
  const double k2,
  const double velocity,
  const double3 direction,
  const int currentTime,
  const double3 origCM,
  const float* __restrict mass,
  const double* __restrict pos_x,
  const double* __restrict pos_y,
  const double* __restrict pos_z,
  const char3* __restrict transform,
  double* __restrict f_normal_x,
  double* __restrict f_normal_y,
  double* __restrict f_normal_z,
  const int* __restrict smdAtomsSOAIndex,
  cudaTensor* __restrict d_virial,
  const double3* __restrict curCOM,
  double* __restrict h_extEnergy,
  double3* __restrict h_extForce,
  cudaTensor* __restrict h_extVirial,
  unsigned int* __restrict tbcatomic)
{
  int tid = threadIdx.x + blockIdx.x * blockDim.x;
  int totaltb = gridDim.x;
  bool isLastBlockDone = false;

  double3 group_f = {0, 0, 0};
  double energy = 0.0;
  double3 pos = {0, 0, 0};
  double3 f = {0, 0, 0};
  cudaTensor r_virial;
  r_virial.xx = 0.0; r_virial.xy = 0.0; r_virial.xz = 0.0;
  r_virial.yx = 0.0; r_virial.yy = 0.0; r_virial.yz = 0.0;
  r_virial.zx = 0.0; r_virial.zy = 0.0; r_virial.zz = 0.0;
  int SOAindex;
  if(tid < numSMDAtoms){
    // Look up this atom in the SOA arrays. The current center of mass has
    // already been computed by computeCOMKernel and is passed in as curCOM.
    SOAindex = smdAtomsSOAIndex[tid];

    // uncoalesced memory access: too bad!
    double m = mass[SOAindex]; // Cast from float to double here
    pos.x = pos_x[SOAindex];
    pos.y = pos_y[SOAindex];
    pos.z = pos_z[SOAindex];

    // calculate the displacement of the COM along the pulling direction
    double3 diffCOM;
    diffCOM.x = curCOM->x - origCM.x;
    diffCOM.y = curCOM->y - origCM.y;
    diffCOM.z = curCOM->z - origCM.z;
    double diff = diffCOM.x*direction.x + diffCOM.y*direction.y +
                  diffCOM.z*direction.z;

    // with the current center of mass in hand, calculate the bias force on the group
    double preFactor = (velocity*currentTime - diff);
    group_f.x = k*preFactor*direction.x + k2*(diff*direction.x - diffCOM.x);
    group_f.y = k*preFactor*direction.y + k2*(diff*direction.y - diffCOM.y);
    group_f.z = k*preFactor*direction.z + k2*(diff*direction.z - diffCOM.z);

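    // Distribute the group force over the atoms in proportion to their masses,
    // f_i = (m_i / M_group) * F_group, so the per-atom forces sum to F_group and
    // every atom in the group receives the same acceleration.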
    // calculate the force on each atom
    f.x = group_f.x * m * inv_group_mass;
    f.y = group_f.y * m * inv_group_mass;
    f.z = group_f.z * m * inv_group_mass;

    // apply the bias
    f_normal_x[SOAindex] += f.x;
    f_normal_y[SOAindex] += f.y;
    f_normal_z[SOAindex] += f.z;
    if(T_DOENERGY){
      // energy for restraint along the direction
      energy = 0.5*k*preFactor*preFactor;
      // energy for transverse restraint
      energy += 0.5*k2*(diffCOM.x*diffCOM.x + diffCOM.y*diffCOM.y +
                        diffCOM.z*diffCOM.z - diff*diff);
      // unwrap coordinates before calculating the virial
      char3 t = transform[SOAindex];
      pos = lat.reverse_transform(pos, t);
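      // Per-atom virial contribution: the outer product of the applied force
      // with the unwrapped atom position, reduced across the block below.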
      r_virial.xx = f.x * pos.x;
      r_virial.xy = f.x * pos.y;
      r_virial.xz = f.x * pos.z;
      r_virial.yx = f.y * pos.x;
      r_virial.yy = f.y * pos.y;
      r_virial.yz = f.y * pos.z;
      r_virial.zx = f.z * pos.x;
      r_virial.zy = f.z * pos.y;
      r_virial.zz = f.z * pos.z;
    }
  }
  __syncthreads();

  if(T_DOENERGY){
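    // Sum each of the nine virial components over the 128 threads of this block.
    // cub::BlockReduce reuses the same shared-memory temp_storage for every call,
    // so a __syncthreads() is required between consecutive reductions.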
    typedef cub::BlockReduce<double, 128> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;

    r_virial.xx = BlockReduce(temp_storage).Sum(r_virial.xx);
    __syncthreads();
    r_virial.xy = BlockReduce(temp_storage).Sum(r_virial.xy);
    __syncthreads();
    r_virial.xz = BlockReduce(temp_storage).Sum(r_virial.xz);
    __syncthreads();

    r_virial.yx = BlockReduce(temp_storage).Sum(r_virial.yx);
    __syncthreads();
    r_virial.yy = BlockReduce(temp_storage).Sum(r_virial.yy);
    __syncthreads();
    r_virial.yz = BlockReduce(temp_storage).Sum(r_virial.yz);
    __syncthreads();

    r_virial.zx = BlockReduce(temp_storage).Sum(r_virial.zx);
    __syncthreads();
    r_virial.zy = BlockReduce(temp_storage).Sum(r_virial.zy);
    __syncthreads();
    r_virial.zz = BlockReduce(temp_storage).Sum(r_virial.zz);
    __syncthreads();

    if(threadIdx.x == 0){
      atomicAdd(&(d_virial->xx), r_virial.xx);
      atomicAdd(&(d_virial->xy), r_virial.xy);
      atomicAdd(&(d_virial->xz), r_virial.xz);

      atomicAdd(&(d_virial->yx), r_virial.yx);
      atomicAdd(&(d_virial->yy), r_virial.yy);
      atomicAdd(&(d_virial->yz), r_virial.yz);

      atomicAdd(&(d_virial->zx), r_virial.zx);
      atomicAdd(&(d_virial->zy), r_virial.zy);
      atomicAdd(&(d_virial->zz), r_virial.zz);

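      // Completion detection: after publishing this block's partial virial,
      // atomically bump the global counter. The block that sees the count reach
      // (totaltb - 1) knows every other block has finished, and it copies the
      // accumulated totals out to the host-visible buffers below.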
      __threadfence();
      unsigned int value = atomicInc(&tbcatomic[0], totaltb);
      isLastBlockDone = (value == (totaltb - 1));
    }

    __syncthreads();
    // Last block will set the host values
    if(isLastBlockDone){
      if(threadIdx.x == 0){
        h_extEnergy[0] = energy;
        h_extForce->x = group_f.x;
        h_extForce->y = group_f.y;
        h_extForce->z = group_f.z;

        h_extVirial->xx = d_virial->xx;
        h_extVirial->xy = d_virial->xy;
        h_extVirial->xz = d_virial->xz;
        h_extVirial->yx = d_virial->yx;
        h_extVirial->yy = d_virial->yy;
        h_extVirial->yz = d_virial->yz;
        h_extVirial->zx = d_virial->zx;
        h_extVirial->zy = d_virial->zy;
        h_extVirial->zz = d_virial->zz;

        // reset the device virial value
        d_virial->xx = 0;
        d_virial->xy = 0;
        d_virial->xz = 0;

        d_virial->yx = 0;
        d_virial->yy = 0;
        d_virial->yz = 0;

        d_virial->zx = 0;
        d_virial->zy = 0;
        d_virial->zz = 0;
        // resets atomic counter
        tbcatomic[0] = 0;
        __threadfence();
      }
    }
  }
}


/*! Calculate the SMD force, virial, and COM for a small atom group (numSMDAtoms <= 1024).
    A single thread block is launched to perform this operation.
    The current COM is calculated here and stored in h_curCM. */
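/* Because the whole group fits in one 1024-thread block, the center of mass,
   the bias force, and the virial can all be computed in a single kernel using
   block-wide reductions; no global atomics or inter-block synchronization are
   needed, unlike the multi-block path above. */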
template<bool T_DOENERGY>
__global__ void computeSMDForceKernel(
  const int numSMDAtoms,
  const Lattice lat,
  const double inv_group_mass,
  const double k,
  const double k2,
  const double velocity,
  const double3 direction,
  const int currentTime,
  const double3 origCM,
  const float* __restrict mass,
  const double* __restrict pos_x,
  const double* __restrict pos_y,
  const double* __restrict pos_z,
  const char3* __restrict transform,
  double* __restrict f_normal_x,
  double* __restrict f_normal_y,
  double* __restrict f_normal_z,
  const int* __restrict smdAtomsSOAIndex,
  double3* __restrict h_curCM,
  double* __restrict h_extEnergy,
  double3* __restrict h_extForce,
  cudaTensor* __restrict h_extVirial)
{
  __shared__ double3 group_f;
  __shared__ double energy;
  int tid = threadIdx.x + blockIdx.x * blockDim.x;
  double m = 0;
  double3 cm = {0, 0, 0};
  double3 pos = {0, 0, 0};
  double3 f = {0, 0, 0};
  cudaTensor r_virial;
  r_virial.xx = 0.0; r_virial.xy = 0.0; r_virial.xz = 0.0;
  r_virial.yx = 0.0; r_virial.yy = 0.0; r_virial.yz = 0.0;
  r_virial.zx = 0.0; r_virial.zy = 0.0; r_virial.zz = 0.0;
  int SOAindex;
  if(tid < numSMDAtoms){
    // First -> recalculate the center of mass. Every thread computes its atom's
    // mass-weighted position; the block reduction below sums them on thread zero.
    SOAindex = smdAtomsSOAIndex[tid];

    // uncoalesced memory access: too bad!
    m = mass[SOAindex]; // Cast from float to double here
    pos.x = pos_x[SOAindex];
    pos.y = pos_y[SOAindex];
    pos.z = pos_z[SOAindex];

    // unwrap the coordinate to calculate COM
    char3 t = transform[SOAindex];
    pos = lat.reverse_transform(pos, t);

    cm.x = pos.x * m;
    cm.y = pos.y * m;
    cm.z = pos.z * m;
  }

  // now reduce the mass-weighted positions; the sums land on thread zero
  typedef cub::BlockReduce<double, 1024> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  cm.x = BlockReduce(temp_storage).Sum(cm.x);
  __syncthreads();
  cm.y = BlockReduce(temp_storage).Sum(cm.y);
  __syncthreads();
  cm.z = BlockReduce(temp_storage).Sum(cm.z);
  __syncthreads();

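  // cub::BlockReduce returns the block-wide aggregate only on thread 0, so only
  // thread 0 finishes the COM and writes the group force and energy into shared
  // memory for the rest of the block to read after the next __syncthreads().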
  // Thread zero computes the current COM, the group force, and (optionally) the energy
  if(threadIdx.x == 0){
    cm.x *= inv_group_mass; // calculates the current center of mass
    cm.y *= inv_group_mass; // calculates the current center of mass
    cm.z *= inv_group_mass; // calculates the current center of mass

    // calculate the displacement of the COM along the pulling direction
    double3 diffCOM;
    diffCOM.x = cm.x - origCM.x;
    diffCOM.y = cm.y - origCM.y;
    diffCOM.z = cm.z - origCM.z;
    double diff = diffCOM.x*direction.x + diffCOM.y*direction.y +
                  diffCOM.z*direction.z;

    // Ok so we've calculated the new center of mass, now we can calculate the bias
    double preFactor = (velocity*currentTime - diff);
    group_f.x = k*preFactor*direction.x + k2*(diff*direction.x - diffCOM.x);
    group_f.y = k*preFactor*direction.y + k2*(diff*direction.y - diffCOM.y);
    group_f.z = k*preFactor*direction.z + k2*(diff*direction.z - diffCOM.z);
    if(T_DOENERGY) {
      // energy for restraint along the direction
      energy = 0.5*k*preFactor*preFactor;
      // energy for transverse restraint
      energy += 0.5*k2*(diffCOM.x*diffCOM.x + diffCOM.y*diffCOM.y +
                        diffCOM.z*diffCOM.z - diff*diff);
    }
  }
  __syncthreads();

  if(tid < numSMDAtoms){
    // calculate the force on each atom
    f.x = group_f.x * m * inv_group_mass;
    f.y = group_f.y * m * inv_group_mass;
    f.z = group_f.z * m * inv_group_mass;

    // apply the bias
    f_normal_x[SOAindex] += f.x;
    f_normal_y[SOAindex] += f.y;
    f_normal_z[SOAindex] += f.z;
    if(T_DOENERGY){
      r_virial.xx = f.x * pos.x;
      r_virial.xy = f.x * pos.y;
      r_virial.xz = f.x * pos.z;
      r_virial.yx = f.y * pos.x;
      r_virial.yy = f.y * pos.y;
      r_virial.yz = f.y * pos.z;
      r_virial.zx = f.z * pos.x;
      r_virial.zy = f.z * pos.y;
      r_virial.zz = f.z * pos.z;
    }
  }
  if(T_DOENERGY){
    r_virial.xx = BlockReduce(temp_storage).Sum(r_virial.xx);
    __syncthreads();
    r_virial.xy = BlockReduce(temp_storage).Sum(r_virial.xy);
    __syncthreads();
    r_virial.xz = BlockReduce(temp_storage).Sum(r_virial.xz);
    __syncthreads();

    r_virial.yx = BlockReduce(temp_storage).Sum(r_virial.yx);
    __syncthreads();
    r_virial.yy = BlockReduce(temp_storage).Sum(r_virial.yy);
    __syncthreads();
    r_virial.yz = BlockReduce(temp_storage).Sum(r_virial.yz);
    __syncthreads();

    r_virial.zx = BlockReduce(temp_storage).Sum(r_virial.zx);
    __syncthreads();
    r_virial.zy = BlockReduce(temp_storage).Sum(r_virial.zy);
    __syncthreads();
    r_virial.zz = BlockReduce(temp_storage).Sum(r_virial.zz);
    __syncthreads();

    if(threadIdx.x == 0){
      // thread zero updates the value
      h_curCM->x = cm.x; // update current center of mass
      h_curCM->y = cm.y; // update current center of mass
      h_curCM->z = cm.z; // update current center of mass

      h_extEnergy[0] = energy;   // bias energy
      h_extForce->x = group_f.x; // bias force
      h_extForce->y = group_f.y;
      h_extForce->z = group_f.z;

      h_extVirial->xx = r_virial.xx;
      h_extVirial->xy = r_virial.xy;
      h_extVirial->xz = r_virial.xz;
      h_extVirial->yx = r_virial.yx;
      h_extVirial->yy = r_virial.yy;
      h_extVirial->yz = r_virial.yz;
      h_extVirial->zx = r_virial.zx;
      h_extVirial->zy = r_virial.zy;
      h_extVirial->zz = r_virial.zz;
    }
  }
}

/*! Compute the SMD force and virial on a group of atoms. */
void computeSMDForce(
  const Lattice &lat,
  const double inv_group_mass,
  const double spring_constant,
  const double transverse_spring_constant,
  const double velocity,
  const double3 direction,
  const int doEnergy,
  const int currentTime,
  const double3 origCM,
  const float* d_mass,
  const double* d_pos_x,
  const double* d_pos_y,
  const double* d_pos_z,
  const char3* d_transform,
  double* d_f_normal_x,
  double* d_f_normal_y,
  double* d_f_normal_z,
  const int numSMDAtoms,
  const int* d_smdAtomsSOAIndex,
  double3* d_curCM,
  double3* h_curCM,
  cudaTensor* d_extVirial,
  double* h_extEnergy,
  double3* h_extForce,
  cudaTensor* h_extVirial,
  unsigned int* d_tbcatomic,
  cudaStream_t stream)
{
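  // Dispatch on group size: for more than 1024 atoms the COM is first reduced
  // across multiple blocks by computeCOMKernel, and a second multi-block kernel
  // applies the bias; for 1024 atoms or fewer a single block does everything.
  // Note that `blocks` below is the thread count per block and `grid` is the
  // number of blocks. The h_* pointers are written directly from the kernels,
  // so they are presumably host-mapped (zero-copy) buffers set up by the caller.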
  if (numSMDAtoms > 1024) {
    const int blocks = 128;
    const int grid = (numSMDAtoms + blocks - 1) / blocks;
    // first calculate the COM for the SMD group and store it in h_curCM
    computeCOMKernel<128><<<grid, blocks, 0, stream>>>(
      numSMDAtoms,
      inv_group_mass,
      lat,
      d_mass,
      d_pos_x,
      d_pos_y,
      d_pos_z,
      d_transform,
      d_smdAtomsSOAIndex,
      d_curCM,
      h_curCM,
      d_tbcatomic);

    if(doEnergy){
      computeSMDForceWithCOMKernel<true><<<grid, blocks, 0, stream>>>(
        numSMDAtoms,
        lat,
        inv_group_mass,
        spring_constant,
        transverse_spring_constant,
        velocity,
        direction,
        currentTime,
        origCM,
        d_mass,
        d_pos_x,
        d_pos_y,
        d_pos_z,
        d_transform,
        d_f_normal_x,
        d_f_normal_y,
        d_f_normal_z,
        d_smdAtomsSOAIndex,
        d_extVirial,
        h_curCM,
        h_extEnergy,
        h_extForce,
        h_extVirial,
        d_tbcatomic);
    } else {
      computeSMDForceWithCOMKernel<false><<<grid, blocks, 0, stream>>>(
        numSMDAtoms,
        lat,
        inv_group_mass,
        spring_constant,
        transverse_spring_constant,
        velocity,
        direction,
        currentTime,
        origCM,
        d_mass,
        d_pos_x,
        d_pos_y,
        d_pos_z,
        d_transform,
        d_f_normal_x,
        d_f_normal_y,
        d_f_normal_z,
        d_smdAtomsSOAIndex,
        d_extVirial,
        h_curCM,
        h_extEnergy,
        h_extForce,
        h_extVirial,
        d_tbcatomic);
    }
  } else {
    // SMD groups usually comprise a small number of atoms, so we can get away
    // with launching a single thread block.
    const int blocks = 1024;
    const int grid = 1;

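    // The block size of 1024 matches the cub::BlockReduce<double, 1024> used
    // inside computeSMDForceKernel and covers every atom of the group, since
    // this branch is only taken when numSMDAtoms <= 1024.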
    if(doEnergy){
      computeSMDForceKernel<true><<<grid, blocks, 0, stream>>>(
        numSMDAtoms,
        lat,
        inv_group_mass,
        spring_constant,
        transverse_spring_constant,
        velocity,
        direction,
        currentTime,
        origCM,
        d_mass,
        d_pos_x,
        d_pos_y,
        d_pos_z,
        d_transform,
        d_f_normal_x,
        d_f_normal_y,
        d_f_normal_z,
        d_smdAtomsSOAIndex,
        h_curCM,
        h_extEnergy,
        h_extForce,
        h_extVirial);
    } else {
      computeSMDForceKernel<false><<<grid, blocks, 0, stream>>>(
        numSMDAtoms,
        lat,
        inv_group_mass,
        spring_constant,
        transverse_spring_constant,
        velocity,
        direction,
        currentTime,
        origCM,
        d_mass,
        d_pos_x,
        d_pos_y,
        d_pos_z,
        d_transform,
        d_f_normal_x,
        d_f_normal_y,
        d_f_normal_z,
        d_smdAtomsSOAIndex,
        h_curCM,
        h_extEnergy,
        h_extForce,
        h_extVirial);
    }
  }
}

#endif // NODEGROUP_FORCE_REGISTER