NAMD
ComputeSMDCUDAKernel.cu
Go to the documentation of this file.
1 #ifdef NAMD_CUDA
2 #if __CUDACC_VER_MAJOR__ >= 11
3 #include <cub/cub.cuh>
4 #else
5 #include <namd_cub/cub.cuh>
6 #endif
7 #endif
8 
9 #ifdef NAMD_HIP
10 #include <hip/hip_runtime.h>
11 #include <hipcub/hipcub.hpp>
12 #define cub hipcub
13 #endif
14 
15 #include "ComputeSMDCUDAKernel.h"
16 #include "ComputeCOMCudaKernel.h"
17 #include "HipDefines.h"
18 
19 #ifdef NODEGROUP_FORCE_REGISTER
20 
21 
/*! Apply the SMD bias force to a large atom group (numSMDAtoms > 1024) and,
    on energy steps, accumulate the virial of that bias.

    Expected launch: 1D grid of 128-thread blocks, one thread per SMD atom
    (the block reductions below are instantiated for exactly 128 threads).
    The current center of mass must be computed before this kernel runs and is
    passed in: it is read from h_curCOM, or, when T_MGPUON is set, taken from
    d_curCOM and scaled by inv_group_mass.

    T_DOENERGY: also evaluate the bias energy and the per-atom virial; the
                block sums are added atomically into d_virial and the last
                block to finish sets the host values (h_extEnergy, h_extForce,
                h_extVirial) and zeroes d_virial.
    T_MGPUON:   the last block to finish stores the COM in h_curCOM and zeroes
                d_curCOM.

    tbcatomic[0] counts finished blocks and is reset to 0 by the last block.
    d_peerCOM is not referenced by this kernel. */
template<bool T_DOENERGY, bool T_MGPUON>
__global__ void computeSMDForceWithCOMKernel(
  const int numSMDAtoms,
  const Lattice lat,
  const double inv_group_mass,
  const double k,
  const double k2,
  const double velocity,
  const double3 direction,
  const int currentTime,
  const double3 origCM,
  const float * __restrict mass,
  const double* __restrict pos_x,
  const double* __restrict pos_y,
  const double* __restrict pos_z,
  const char3* __restrict transform,
  double* __restrict f_normal_x,
  double* __restrict f_normal_y,
  double* __restrict f_normal_z,
  const int* __restrict smdAtomsSOAIndex,
  cudaTensor* __restrict d_virial,
  double3* __restrict h_curCOM,
  double3* __restrict d_curCOM,
  double3** __restrict d_peerCOM,
  double* __restrict h_extEnergy,
  double3* __restrict h_extForce,
  cudaTensor* __restrict h_extVirial,
  unsigned int* __restrict tbcatomic)
{
  const int globalIdx = threadIdx.x + blockIdx.x * blockDim.x;
  const int numBlocks = gridDim.x;
  const bool blockLeader = (threadIdx.x == 0);
  bool isLastBlockDone = false;
  int soaIdx;

  double energy = 0.0;
  double3 group_f = {0, 0, 0};
  double3 atomPos = {0, 0, 0};
  double3 atomForce = {0, 0, 0};

  // Per-thread virial contribution; stays zero for threads without an atom.
  cudaTensor vir;
  vir.xx = 0.0; vir.xy = 0.0; vir.xz = 0.0;
  vir.yx = 0.0; vir.yy = 0.0; vir.yz = 0.0;
  vir.zx = 0.0; vir.zy = 0.0; vir.zz = 0.0;

  // Center of mass supplied by the caller.
  double3 cm = {h_curCOM->x, h_curCOM->y, h_curCOM->z};
  if (T_MGPUON) {
    // d_curCOM is scaled by the inverse group mass to obtain the COM.
    cm.x = d_curCOM->x * inv_group_mass;
    cm.y = d_curCOM->y * inv_group_mass;
    cm.z = d_curCOM->z * inv_group_mass;
  }

  if (globalIdx < numSMDAtoms) {
    soaIdx = smdAtomsSOAIndex[globalIdx];

    // Gathered through the SOA index, so these accesses are not coalesced.
    const double atomMass = mass[soaIdx];  // float -> double
    atomPos.x = pos_x[soaIdx];
    atomPos.y = pos_y[soaIdx];
    atomPos.z = pos_z[soaIdx];

    // Displacement of the COM from its reference and its projection on the
    // pulling direction.
    double3 dCOM;
    dCOM.x = cm.x - origCM.x;
    dCOM.y = cm.y - origCM.y;
    dCOM.z = cm.z - origCM.z;
    const double proj = dCOM.x*direction.x + dCOM.y*direction.y +
                        dCOM.z*direction.z;

    // Bias on the whole group: spring along the direction (k) plus the
    // transverse restraint (k2).
    const double stretch = (velocity*currentTime - proj);
    group_f.x = k*stretch*direction.x + k2*(proj*direction.x - dCOM.x);
    group_f.y = k*stretch*direction.y + k2*(proj*direction.y - dCOM.y);
    group_f.z = k*stretch*direction.z + k2*(proj*direction.z - dCOM.z);

    // Mass-weighted share of the group force for this atom.
    atomForce.x = group_f.x * atomMass * inv_group_mass;
    atomForce.y = group_f.y * atomMass * inv_group_mass;
    atomForce.z = group_f.z * atomMass * inv_group_mass;

    f_normal_x[soaIdx] += atomForce.x;
    f_normal_y[soaIdx] += atomForce.y;
    f_normal_z[soaIdx] += atomForce.z;

    if (T_DOENERGY) {
      // Energy of the restraint along the direction ...
      energy = 0.5*k*stretch*stretch;
      // ... plus the energy of the transverse restraint.
      energy += 0.5*k2*(dCOM.x*dCOM.x + dCOM.y*dCOM.y +
                        dCOM.z*dCOM.z - proj*proj);

      // The virial uses unwrapped coordinates.
      const char3 wrap = transform[soaIdx];
      atomPos = lat.reverse_transform(atomPos, wrap);
      vir.xx = atomForce.x * atomPos.x;
      vir.xy = atomForce.x * atomPos.y;
      vir.xz = atomForce.x * atomPos.z;
      vir.yx = atomForce.y * atomPos.x;
      vir.yy = atomForce.y * atomPos.y;
      vir.yz = atomForce.y * atomPos.z;
      vir.zx = atomForce.z * atomPos.x;
      vir.zy = atomForce.z * atomPos.y;
      vir.zz = atomForce.z * atomPos.z;
    }
  }
  __syncthreads();

  if (T_DOENERGY) {
    typedef cub::BlockReduce<double, 128> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;

    // One block-wide sum per tensor component; the barrier after each sum
    // lets the shared scratch space be reused by the next one.
    vir.xx = BlockReduce(temp_storage).Sum(vir.xx);
    __syncthreads();
    vir.xy = BlockReduce(temp_storage).Sum(vir.xy);
    __syncthreads();
    vir.xz = BlockReduce(temp_storage).Sum(vir.xz);
    __syncthreads();

    vir.yx = BlockReduce(temp_storage).Sum(vir.yx);
    __syncthreads();
    vir.yy = BlockReduce(temp_storage).Sum(vir.yy);
    __syncthreads();
    vir.yz = BlockReduce(temp_storage).Sum(vir.yz);
    __syncthreads();

    vir.zx = BlockReduce(temp_storage).Sum(vir.zx);
    __syncthreads();
    vir.zy = BlockReduce(temp_storage).Sum(vir.zy);
    __syncthreads();
    vir.zz = BlockReduce(temp_storage).Sum(vir.zz);
    __syncthreads();

    if (blockLeader) {
      // Fold this block's sums into the device-wide accumulator.
      atomicAdd(&(d_virial->xx), vir.xx);
      atomicAdd(&(d_virial->xy), vir.xy);
      atomicAdd(&(d_virial->xz), vir.xz);

      atomicAdd(&(d_virial->yx), vir.yx);
      atomicAdd(&(d_virial->yy), vir.yy);
      atomicAdd(&(d_virial->yz), vir.yz);

      atomicAdd(&(d_virial->zx), vir.zx);
      atomicAdd(&(d_virial->zy), vir.zy);
      atomicAdd(&(d_virial->zz), vir.zz);

      // Publish the adds before taking a ticket; the block that draws the
      // final ticket is the last one to finish.
      __threadfence();
      const unsigned int ticket = atomicInc(&tbcatomic[0], numBlocks);
      isLastBlockDone = (ticket == (numBlocks -1));
    }

    __syncthreads();
    // Last block sets the host values.
    if (isLastBlockDone && blockLeader) {
      h_extEnergy[0] = energy;
      h_extForce->x = group_f.x;
      h_extForce->y = group_f.y;
      h_extForce->z = group_f.z;

      h_extVirial->xx = d_virial->xx;
      h_extVirial->xy = d_virial->xy;
      h_extVirial->xz = d_virial->xz;
      h_extVirial->yx = d_virial->yx;
      h_extVirial->yy = d_virial->yy;
      h_extVirial->yz = d_virial->yz;
      h_extVirial->zx = d_virial->zx;
      h_extVirial->zy = d_virial->zy;
      h_extVirial->zz = d_virial->zz;

      // Reset the device accumulator.
      d_virial->xx = 0;
      d_virial->xy = 0;
      d_virial->xz = 0;

      d_virial->yx = 0;
      d_virial->yy = 0;
      d_virial->yz = 0;

      d_virial->zx = 0;
      d_virial->zy = 0;
      d_virial->zz = 0;
    }
  }
  else {
    // Non-energy steps still need to know which block finishes last.
    if (blockLeader) {
      __threadfence();
      const unsigned int ticket = atomicInc(&tbcatomic[0], numBlocks);
      isLastBlockDone = (ticket == (numBlocks -1));
    }
  }
  __syncthreads();

  if (isLastBlockDone && blockLeader) {
    if (T_MGPUON) {
      h_curCOM->x = cm.x;
      h_curCOM->y = cm.y;
      h_curCOM->z = cm.z;
      d_curCOM->x = 0.0;
      d_curCOM->y = 0.0;
      d_curCOM->z = 0.0;
    }
    // Reset the block counter.
    tbcatomic[0] = 0;
    __threadfence();
  }
}
230 
231 
/*! Compute the COM, SMD bias force and (on energy steps) virial for a small
    atom group (numSMDAtoms <= 1024).

    Expected launch: a single block of 1024 threads, one thread per SMD atom
    (the block reductions below are instantiated for exactly 1024 threads).

    T_MGPUON:   the COM is not reduced here; thread 0 takes it from d_curCM,
                scaled by inv_group_mass. The last block to finish zeroes
                d_curCM and resets tbcatomic[0].
    T_DOENERGY: thread 0 also writes the COM to h_curCM and the bias energy,
                group force and reduced virial to h_extEnergy, h_extForce and
                h_extVirial.

    d_peerCOM is not referenced by this kernel. */
template<bool T_DOENERGY, bool T_MGPUON>
__global__ void computeSMDForceKernel(
  const int numSMDAtoms,
  const Lattice lat,
  const double inv_group_mass,
  const double k,
  const double k2,
  const double velocity,
  const double3 direction,
  const int currentTime,
  const double3 origCM,
  const float * __restrict mass,
  const double* __restrict pos_x,
  const double* __restrict pos_y,
  const double* __restrict pos_z,
  const char3* __restrict transform,
  double* __restrict f_normal_x,
  double* __restrict f_normal_y,
  double* __restrict f_normal_z,
  const int* __restrict smdAtomsSOAIndex,
  double3* __restrict h_curCM,
  double3* __restrict d_curCM,
  double3** __restrict d_peerCOM,
  double* __restrict h_extEnergy,
  double3* __restrict h_extForce,
  cudaTensor* __restrict h_extVirial,
  unsigned int* __restrict tbcatomic)
{
  // Written by thread 0 and read by the whole block after the barrier.
  __shared__ double3 group_f;
  __shared__ double energy;

  // Scratch space shared by every block reduction in this kernel.
  typedef cub::BlockReduce<double, 1024> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  const int globalIdx = threadIdx.x + blockIdx.x * blockDim.x;
  const int numBlocks = gridDim.x;
  const bool blockLeader = (threadIdx.x == 0);
  const bool ownsAtom = (globalIdx < numSMDAtoms);
  bool isLastBlockDone = false;
  int soaIdx;

  double atomMass = 0;
  double3 cm = {0, 0, 0};
  double3 atomPos = {0, 0, 0};
  double3 atomForce = {0, 0, 0};

  // Per-thread virial contribution; stays zero for threads without an atom.
  cudaTensor vir;
  vir.xx = 0.0; vir.xy = 0.0; vir.xz = 0.0;
  vir.yx = 0.0; vir.yy = 0.0; vir.yz = 0.0;
  vir.zx = 0.0; vir.zy = 0.0; vir.zz = 0.0;

  // In the mGpuOn case the COM is calculated across devices and passed in,
  // so the per-atom contribution below is only formed when !T_MGPUON.
  if (ownsAtom) {
    soaIdx = smdAtomsSOAIndex[globalIdx];
    atomMass = mass[soaIdx];  // float -> double
    atomPos.x = pos_x[soaIdx];
    atomPos.y = pos_y[soaIdx];
    atomPos.z = pos_z[soaIdx];

    // Unwrap the coordinate before it enters the COM (and later the virial).
    const char3 wrap = transform[soaIdx];
    atomPos = lat.reverse_transform(atomPos, wrap);
    if (!T_MGPUON) {
      cm.x = atomPos.x * atomMass;
      cm.y = atomPos.y * atomMass;
      cm.z = atomPos.z * atomMass;
    }
  }

  // Sum the mass-weighted positions over the block; thread 0 uses the result.
  if (!T_MGPUON) {
    cm.x = BlockReduce(temp_storage).Sum(cm.x);
    __syncthreads();
    cm.y = BlockReduce(temp_storage).Sum(cm.y);
    __syncthreads();
    cm.z = BlockReduce(temp_storage).Sum(cm.z);
    __syncthreads();
  }

  // Thread 0 finishes the COM and derives the group force (and energy).
  if (blockLeader) {
    if (T_MGPUON) {
      cm.x = d_curCM->x * inv_group_mass;
      cm.y = d_curCM->y * inv_group_mass;
      cm.z = d_curCM->z * inv_group_mass;
    }
    else {
      // Divide the mass-weighted sum by the group mass.
      cm.x *= inv_group_mass;
      cm.y *= inv_group_mass;
      cm.z *= inv_group_mass;
    }

    // Displacement of the COM from its reference and its projection on the
    // pulling direction.
    double3 dCOM;
    dCOM.x = cm.x - origCM.x;
    dCOM.y = cm.y - origCM.y;
    dCOM.z = cm.z - origCM.z;
    const double proj = dCOM.x*direction.x + dCOM.y*direction.y +
                        dCOM.z*direction.z;

    // Bias on the whole group: spring along the direction (k) plus the
    // transverse restraint (k2).
    const double stretch = (velocity*currentTime - proj);
    group_f.x = k*stretch*direction.x + k2*(proj*direction.x - dCOM.x);
    group_f.y = k*stretch*direction.y + k2*(proj*direction.y - dCOM.y);
    group_f.z = k*stretch*direction.z + k2*(proj*direction.z - dCOM.z);
    if (T_DOENERGY) {
      // Energy of the restraint along the direction ...
      energy = 0.5*k*stretch*stretch;
      // ... plus the energy of the transverse restraint.
      energy += 0.5*k2*(dCOM.x*dCOM.x + dCOM.y*dCOM.y +
                        dCOM.z*dCOM.z - proj*proj);
    }
  }
  __syncthreads();  // group_f (and energy) are now visible to the block

  if (ownsAtom) {
    // Mass-weighted share of the group force for this atom.
    atomForce.x = group_f.x * atomMass * inv_group_mass;
    atomForce.y = group_f.y * atomMass * inv_group_mass;
    atomForce.z = group_f.z * atomMass * inv_group_mass;

    f_normal_x[soaIdx] += atomForce.x;
    f_normal_y[soaIdx] += atomForce.y;
    f_normal_z[soaIdx] += atomForce.z;

    if (T_DOENERGY) {
      vir.xx = atomForce.x * atomPos.x;
      vir.xy = atomForce.x * atomPos.y;
      vir.xz = atomForce.x * atomPos.z;
      vir.yx = atomForce.y * atomPos.x;
      vir.yy = atomForce.y * atomPos.y;
      vir.yz = atomForce.y * atomPos.z;
      vir.zx = atomForce.z * atomPos.x;
      vir.zy = atomForce.z * atomPos.y;
      vir.zz = atomForce.z * atomPos.z;
    }
  }

  if (T_MGPUON && blockLeader) {
    // Take a ticket; the block drawing the final one cleans up at the end.
    __threadfence();
    const unsigned int ticket = atomicInc(&tbcatomic[0], numBlocks);
    isLastBlockDone = (ticket == (numBlocks -1));
  }

  if (T_DOENERGY) {
    // One block-wide sum per tensor component; the barrier after each sum
    // lets the shared scratch space be reused by the next one.
    vir.xx = BlockReduce(temp_storage).Sum(vir.xx);
    __syncthreads();
    vir.xy = BlockReduce(temp_storage).Sum(vir.xy);
    __syncthreads();
    vir.xz = BlockReduce(temp_storage).Sum(vir.xz);
    __syncthreads();

    vir.yx = BlockReduce(temp_storage).Sum(vir.yx);
    __syncthreads();
    vir.yy = BlockReduce(temp_storage).Sum(vir.yy);
    __syncthreads();
    vir.yz = BlockReduce(temp_storage).Sum(vir.yz);
    __syncthreads();

    vir.zx = BlockReduce(temp_storage).Sum(vir.zx);
    __syncthreads();
    vir.zy = BlockReduce(temp_storage).Sum(vir.zy);
    __syncthreads();
    vir.zz = BlockReduce(temp_storage).Sum(vir.zz);
    __syncthreads();

    if (blockLeader) {
      // Thread 0 publishes the COM, bias energy, bias force and virial.
      h_curCM->x = cm.x;
      h_curCM->y = cm.y;
      h_curCM->z = cm.z;
      h_extEnergy[0] = energy;
      h_extForce->x = group_f.x;
      h_extForce->y = group_f.y;
      h_extForce->z = group_f.z;

      h_extVirial->xx = vir.xx;
      h_extVirial->xy = vir.xy;
      h_extVirial->xz = vir.xz;
      h_extVirial->yx = vir.yx;
      h_extVirial->yy = vir.yy;
      h_extVirial->yz = vir.yz;
      h_extVirial->zx = vir.zx;
      h_extVirial->zy = vir.zy;
      h_extVirial->zz = vir.zz;
    }
  }

  // Last block cleans up for the next iteration.
  if (T_MGPUON && isLastBlockDone && blockLeader) {
    d_curCM->x = 0.0;
    d_curCM->y = 0.0;
    d_curCM->z = 0.0;
    // Reset the block counter.
    tbcatomic[0] = 0;
    __threadfence();
  }
}
432 
/*! Launch the kernels that compute the SMD force and virial on a group of
    atoms. All launches are issued on `stream`.

    Groups with more than 1024 atoms use 128-thread blocks: a COM kernel
    (computeCOMKernel, or computeDistCOMKernelMgpu when mGpuOn is set) runs
    first, followed by computeSMDForceWithCOMKernel. Smaller groups use one
    1024-thread block running computeSMDForceKernel, preceded by
    computeDistCOMKernelMgpu when mGpuOn is set.

    doEnergy and mGpuOn select the kernel template instantiation.
    deviceIndex is not used here. */
void computeSMDForce(
  const Lattice &lat,
  const double inv_group_mass,
  const double spring_constant,
  const double transverse_spring_constant,
  const double velocity,
  const double3 direction,
  const int doEnergy,
  const int currentTime,
  const bool mGpuOn,
  const double3 origCM,
  const float* d_mass,
  const double* d_pos_x,
  const double* d_pos_y,
  const double* d_pos_z,
  const char3* d_transform,
  double * d_f_normal_x,
  double * d_f_normal_y,
  double * d_f_normal_z,
  const int numSMDAtoms,
  const int* d_smdAtomsSOAIndex,
  double3* d_curCM,
  double3* h_curCM,
  double3** d_peerCOM,
  cudaTensor* d_extVirial,
  double* h_extEnergy,
  double3* h_extForce,
  cudaTensor* h_extVirial,
  unsigned int* d_tbcatomic,
  const int numDevices,
  const int deviceIndex,
  cudaStream_t stream)
{
  // `blocks` is the number of threads per block, `grid` the number of blocks.
  const bool largeGroup = (numSMDAtoms > 1024);
  const int blocks = largeGroup ? 128 : 1024;
  const int grid = largeGroup ? (numSMDAtoms + blocks - 1) / blocks : 1;

// Multi-block force kernel; expects the COM to be available already.
#define CALL_WITH_COM(DOENERGY, MGPUON) \
  computeSMDForceWithCOMKernel<DOENERGY, MGPUON> \
    <<< grid, blocks, 0 , stream >>> \
    (numSMDAtoms, lat, inv_group_mass, spring_constant, \
     transverse_spring_constant, velocity, direction, currentTime, \
     origCM, d_mass, d_pos_x, d_pos_y, d_pos_z, d_transform, \
     d_f_normal_x, d_f_normal_y, d_f_normal_z, d_smdAtomsSOAIndex, \
     d_extVirial, h_curCM, d_curCM, d_peerCOM, h_extEnergy, h_extForce, \
     h_extVirial, d_tbcatomic)

// Single-block kernel that also reduces the COM itself.
#define CALL(DOENERGY, MGPUON) \
  computeSMDForceKernel<DOENERGY, MGPUON> \
    <<<grid, blocks, 0, stream>>> \
    (numSMDAtoms, lat, inv_group_mass, spring_constant, \
     transverse_spring_constant, velocity, direction, currentTime, \
     origCM, d_mass, d_pos_x, d_pos_y, d_pos_z, d_transform, \
     d_f_normal_x, d_f_normal_y, d_f_normal_z, d_smdAtomsSOAIndex, \
     h_curCM, d_curCM, d_peerCOM, h_extEnergy, h_extForce, h_extVirial, \
     d_tbcatomic)

  if (largeGroup) {
    if (mGpuOn) {
      // Sum up the COMs across devices to this device.
      computeDistCOMKernelMgpu<<<grid, blocks, 0, stream>>>(
        d_peerCOM, d_curCM, numDevices);
    }
    else {
      // First calculate the COM of the SMD group and store it in h_curCM.
      computeCOMKernel<128><<<grid, blocks, 0, stream>>>(
        numSMDAtoms, inv_group_mass, lat, d_mass,
        d_pos_x, d_pos_y, d_pos_z,
        d_transform, d_smdAtomsSOAIndex,
        d_curCM, h_curCM, d_tbcatomic);
    }

    if (doEnergy) {
      if (mGpuOn) { CALL_WITH_COM(true, true); }
      else        { CALL_WITH_COM(true, false); }
    }
    else {
      if (mGpuOn) { CALL_WITH_COM(false, true); }
      else        { CALL_WITH_COM(false, false); }
    }
  }
  else {
    if (mGpuOn) {
      // Sum up the COMs across devices to this device.
      computeDistCOMKernelMgpu<<<grid, blocks, 0, stream>>>(
        d_peerCOM, d_curCM, numDevices);
    }

    if (doEnergy) {
      if (mGpuOn) { CALL(true, true); }
      else        { CALL(true, false); }
    }
    else {
      if (mGpuOn) { CALL(false, true); }
      else        { CALL(false, false); }
    }
  }
#undef CALL_WITH_COM
#undef CALL
}
536 
537 
/*! Launch initPeerCOMKernel on `stream` as a single block of numDevices
    threads, forwarding numDevices, deviceIndex, d_peerPool and d_peerCOM. */
void initPeerCOMmgpu(
  const int numDevices,
  const int deviceIndex,
  double3** d_peerPool,
  double3* d_peerCOM,
  cudaStream_t stream)
{
  // One block; one thread per device.
  initPeerCOMKernel<<<1, numDevices, 0, stream>>>(
    numDevices, deviceIndex, d_peerPool, d_peerCOM);
}
552 
553 
/*! Called in an earlier phase to handle the multi-device COM.

    Zeroes d_peerCOM (one double3) and then launches computeCOMKernelMgpu on
    `stream`: 128-thread blocks tiled over the atoms when numSMDAtoms > 1024,
    otherwise a single 1024-thread block.

    The zeroing is enqueued on `stream` as well, so it is ordered before the
    kernel by stream order and does not go through the default stream. */
void computeCOMSMDMgpu(
  const int numSMDAtoms,
  const Lattice &lat,
  const float* d_mass,
  const double* d_pos_x,
  const double* d_pos_y,
  const double* d_pos_z,
  const char3* d_transform,
  const int* d_smdAtomsSOAIndex,
  double3* d_peerCOM,
  double3** d_peer_curCM,
  unsigned int* d_tbcatomic,
  const int numDevices,
  const int deviceIndex,
  cudaStream_t stream)
{
  // Block it up if large, otherwise all in one go.
  // `blocks` is the number of threads per block, `grid` the number of blocks.
  const int blocks = (numSMDAtoms > 1024) ? 128 : 1024;
  const int grid = (numSMDAtoms > 1024) ? (numSMDAtoms + blocks - 1) / blocks : 1;

  // Initialize the device memory to zero, in the same stream as the kernel
  // that follows.
  cudaCheck(cudaMemsetAsync(d_peerCOM, 0, sizeof(double3), stream));

  // The template argument must match the launch block size.
  if (numSMDAtoms > 1024) {
    computeCOMKernelMgpu<128><<<grid, blocks, 0, stream>>>(
      numSMDAtoms,
      lat, d_mass,
      d_pos_x, d_pos_y, d_pos_z,
      d_transform,
      d_smdAtomsSOAIndex,
      d_peer_curCM,
      numDevices,
      deviceIndex,
      d_tbcatomic);
  }
  else {
    computeCOMKernelMgpu<1024><<<grid, blocks, 0, stream>>>(
      numSMDAtoms,
      lat, d_mass,
      d_pos_x, d_pos_y, d_pos_z,
      d_transform,
      d_smdAtomsSOAIndex,
      d_peer_curCM,
      numDevices,
      deviceIndex,
      d_tbcatomic);
  }
}
597 
598 #endif // NODEGROUP_FORCE_REGISTER