NAMD
ComputeGroupRes2GroupCUDAKernel.cu
Go to the documentation of this file.
1 #ifdef NAMD_CUDA
2 #if __CUDACC_VER_MAJOR__ >= 11
3 #include <cub/cub.cuh>
4 #else
5 #include <namd_cub/cub.cuh>
6 #endif
7 #endif
8 
9 #ifdef NAMD_HIP
10 #include <hip/hip_runtime.h>
11 #include <hipcub/hipcub.hpp>
12 #define cub hipcub
13 #endif
14 
15 #include "ComputeGroupRes2GroupCUDAKernel.h"
16 #include "ComputeCOMCudaKernel.h"
17 #include "HipDefines.h"
18 
19 #ifdef NODEGROUP_FORCE_REGISTER
20 
/*! Compute restraint force, virial, and energy for a two-group COM restraint
 on large groups (totalNumRestrained > 1024 atoms; smaller groups are handled
 by computeSmallGroupRestraint2GroupsKernel). The COM of group 2 is restrained
 relative to the COM of group 1; the resulting force acts on group 2 and the
 opposite force on group 1, distributed over atoms in proportion to mass.

 The group COMs must be computed before this kernel runs and passed in
 group1COM / group2COM:
   - T_MGPUON == 0: group1COM / group2COM already hold the COMs.
   - T_MGPUON != 0: group1COM / group2COM hold mass-weighted position sums,
     which are scaled here by inv_group1_mass / inv_group2_mass. The last
     block to finish zeroes them for the next iteration.

 Launch layout:
   - blockDim.x must be 128 (the virial reduction is cub::BlockReduce<double, 128>).
   - The grid must provide at least totalNumRestrained threads.
   - Thread tIdx handles atom groupAtomsSOAIndex[tIdx]. Threads
     [0, numRestrainedGroup1) are group 1; [numRestrainedGroup1,
     totalNumRestrained) are group 2.
   - Every in-range thread redundantly evaluates the same group force.

 Reductions:
   - d_virial accumulates per-block partial virials via atomicAdd and must be
     zero on entry. The last block adds it into h_extVirial and re-zeroes it.
   - d_tbcatomic[0] is the block-completion counter. It must be zero on entry
     and is reset to zero by the last block.
   - When T_DOENERGY, thread 0 of the last block writes h_resEnergy,
     h_diffCOM (COM1 -> COM2 distance along resDirection) and h_resForce (force
     on group 2). NOTE(review): the h_-prefixed outputs are presumably
     host-mapped buffers written directly from the device — confirm with caller.

 This function also calculates the distance from COM of the
 group 1 to COM of the group 2. */
template<int T_DOENERGY, int T_DOVIRIAL, int T_USEMAGNITUDE, int T_MGPUON>
__global__ void computeLargeGroupRestraint2GroupsKernel(
  const int numRestrainedGroup1,
  const int totalNumRestrained,
  const int restraintExp,
  const double restraintK,
  const double3 resCenterVec,
  const double3 resDirection,
  const double inv_group1_mass,
  const double inv_group2_mass,
  const int* __restrict groupAtomsSOAIndex,
  const Lattice lat,
  const char3* __restrict transform,
  const float* __restrict mass,
  const double* __restrict pos_x,
  const double* __restrict pos_y,
  const double* __restrict pos_z,
  double* __restrict f_normal_x,
  double* __restrict f_normal_y,
  double* __restrict f_normal_z,
  cudaTensor* __restrict d_virial,
  cudaTensor* __restrict h_extVirial,
  double* __restrict h_resEnergy,
  double3* __restrict h_resForce,
  double3* __restrict group1COM, // on device if mgpu
  double3* __restrict group2COM, // on device if mgpu
  double3* __restrict h_diffCOM,
  unsigned int* __restrict d_tbcatomic)
{
  int tIdx = threadIdx.x + blockIdx.x * blockDim.x;
  int totaltb = gridDim.x;
  bool isLastBlockDone = false;

  int SOAindex;
  double m = 0;
  double energy = 0.0;
  double3 diffCOM = {0, 0, 0};
  double3 group_f = {0, 0, 0};
  double3 pos = {0, 0, 0};
  double3 f = {0, 0, 0};
  cudaTensor r_virial;
  r_virial.xx = 0.0; r_virial.xy = 0.0; r_virial.xz = 0.0;
  r_virial.yx = 0.0; r_virial.yy = 0.0; r_virial.yz = 0.0;
  r_virial.zx = 0.0; r_virial.zy = 0.0; r_virial.zz = 0.0;

  if(tIdx < totalNumRestrained) {
    SOAindex = groupAtomsSOAIndex[tIdx];
    m = mass[SOAindex];
    // Here for consistency with ComputeGroupRes1, we calculate
    // distance from com1 to com2 along specific restraint dimension,
    // so force is acting on group 2
    if(T_MGPUON){
      // multi-GPU inputs are mass-weighted sums: scale to get the COMs
      diffCOM.x = (group2COM->x*inv_group2_mass - group1COM->x*inv_group1_mass) * resDirection.x;
      diffCOM.y = (group2COM->y*inv_group2_mass - group1COM->y*inv_group1_mass) * resDirection.y;
      diffCOM.z = (group2COM->z*inv_group2_mass - group1COM->z*inv_group1_mass) * resDirection.z;
    }
    else
    {
      diffCOM.x = (group2COM->x - group1COM->x) * resDirection.x;
      diffCOM.y = (group2COM->y - group1COM->y) * resDirection.y;
      diffCOM.z = (group2COM->z - group1COM->z) * resDirection.z;
    }
    // Calculate the minimum image distance
    diffCOM = lat.delta_from_diff(diffCOM);

    if (T_USEMAGNITUDE) {
      // Calculate the difference from equilibrium restraint distance
      double comVal = sqrt(diffCOM.x*diffCOM.x + diffCOM.y*diffCOM.y + diffCOM.z*diffCOM.z);
      double centerVal = sqrt(resCenterVec.x*resCenterVec.x + resCenterVec.y*resCenterVec.y +
                              resCenterVec.z*resCenterVec.z);

      double distDiff = (comVal - centerVal);
      double distSqDiff = distDiff * distDiff;
      // When the two COMs coincide the force direction is undefined; an
      // unguarded 1.0/0.0 would turn every atom force into NaN. Use zero
      // force in that degenerate case instead.
      double invCOMVal = (comVal > 0.0) ? 1.0 / comVal : 0.0;

      // Calculate energy and force on group of atoms
      if(distSqDiff > 0.0) { // To avoid numerical error
        // Energy = k * (r - r_eq)^n
        energy = restraintK * distSqDiff;
        for (int n = 2; n < restraintExp; n += 2) {
          energy *= distSqDiff;
        }
        // Force = -k * n * (r - r_eq)^(n-1)
        double force = -energy * restraintExp / distDiff;
        // calculate force along COM difference
        group_f.x = force * diffCOM.x * invCOMVal;
        group_f.y = force * diffCOM.y * invCOMVal;
        group_f.z = force * diffCOM.z * invCOMVal;
      }
    } else {
      // Calculate the difference from equilibrium restraint distance vector
      // along specific restraint dimension
      double3 resDist;
      resDist.x = (diffCOM.x - resCenterVec.x) * resDirection.x;
      resDist.y = (diffCOM.y - resCenterVec.y) * resDirection.y;
      resDist.z = (diffCOM.z - resCenterVec.z) * resDirection.z;
      // Wrap the distance difference (diffCOM - resCenterVec)
      resDist = lat.delta_from_diff(resDist);

      double distSqDiff = resDist.x*resDist.x + resDist.y*resDist.y + resDist.z*resDist.z;

      // Calculate energy and force on group of atoms
      if(distSqDiff > 0.0) { // To avoid numerical error
        // Energy = k * (r - r_eq)^n
        energy = restraintK * distSqDiff;
        for (int n = 2; n < restraintExp; n += 2) {
          energy *= distSqDiff;
        }
        // Force = -k * n * (r - r_eq)^(n-1) x (r - r_eq)/|r - r_eq|
        double force = -energy * restraintExp / distSqDiff;
        group_f.x = force * resDist.x;
        group_f.y = force * resDist.y;
        group_f.z = force * resDist.z;
      }
    }

    // calculate the force on each atom of the group
    if (tIdx < numRestrainedGroup1) {
      // threads [0 , numGroup1Atoms) calculate force for group 1
      // We use negative because force is calculated for group 2
      f.x = -group_f.x * m * inv_group1_mass;
      f.y = -group_f.y * m * inv_group1_mass;
      f.z = -group_f.z * m * inv_group1_mass;
    } else {
      // threads [numGroup1Atoms , totalNumRestrained) calculate force for group 2
      f.x = group_f.x * m * inv_group2_mass;
      f.y = group_f.y * m * inv_group2_mass;
      f.z = group_f.z * m * inv_group2_mass;
    }
    // apply the bias to each atom in group
    f_normal_x[SOAindex] += f.x;
    f_normal_y[SOAindex] += f.y;
    f_normal_z[SOAindex] += f.z;
    // Virial is based on applied force on each atom
    if(T_DOVIRIAL) {
      // positions must be unwrapped for virial calculation
      pos.x = pos_x[SOAindex];
      pos.y = pos_y[SOAindex];
      pos.z = pos_z[SOAindex];
      char3 tr = transform[SOAindex];
      pos = lat.reverse_transform(pos, tr);
      r_virial.xx = f.x * pos.x;
      r_virial.xy = f.x * pos.y;
      r_virial.xz = f.x * pos.z;
      r_virial.yx = f.y * pos.x;
      r_virial.yy = f.y * pos.y;
      r_virial.yz = f.y * pos.z;
      r_virial.zx = f.z * pos.x;
      r_virial.zy = f.z * pos.y;
      r_virial.zz = f.z * pos.z;
    }
  }
  // All threads of this block have read group1COM/group2COM before
  // thread 0 signals completion below.
  __syncthreads();

  if(T_DOENERGY || T_DOVIRIAL) {
    if(T_DOVIRIAL) {
      // Reduce virial values in the thread block (requires blockDim.x == 128)
      typedef cub::BlockReduce<double, 128> BlockReduce;
      __shared__ typename BlockReduce::TempStorage temp_storage;

      r_virial.xx = BlockReduce(temp_storage).Sum(r_virial.xx);
      __syncthreads();
      r_virial.xy = BlockReduce(temp_storage).Sum(r_virial.xy);
      __syncthreads();
      r_virial.xz = BlockReduce(temp_storage).Sum(r_virial.xz);
      __syncthreads();

      r_virial.yx = BlockReduce(temp_storage).Sum(r_virial.yx);
      __syncthreads();
      r_virial.yy = BlockReduce(temp_storage).Sum(r_virial.yy);
      __syncthreads();
      r_virial.yz = BlockReduce(temp_storage).Sum(r_virial.yz);
      __syncthreads();

      r_virial.zx = BlockReduce(temp_storage).Sum(r_virial.zx);
      __syncthreads();
      r_virial.zy = BlockReduce(temp_storage).Sum(r_virial.zy);
      __syncthreads();
      r_virial.zz = BlockReduce(temp_storage).Sum(r_virial.zz);
      __syncthreads();
    }

    if(threadIdx.x == 0) {
      if(T_DOVIRIAL) {
        // thread 0 adds the reduced virial values into device memory
        atomicAdd(&(d_virial->xx), r_virial.xx);
        atomicAdd(&(d_virial->xy), r_virial.xy);
        atomicAdd(&(d_virial->xz), r_virial.xz);

        atomicAdd(&(d_virial->yx), r_virial.yx);
        atomicAdd(&(d_virial->yy), r_virial.yy);
        atomicAdd(&(d_virial->yz), r_virial.yz);

        atomicAdd(&(d_virial->zx), r_virial.zx);
        atomicAdd(&(d_virial->zy), r_virial.zy);
        atomicAdd(&(d_virial->zz), r_virial.zz);
      }
      // Make this block's atomics visible before counting it as done
      __threadfence();
      unsigned int value = atomicInc(&d_tbcatomic[0], totaltb);
      isLastBlockDone = (value == (totaltb -1));
    }

    __syncthreads();

    if(isLastBlockDone) {
      // Thread 0 of the last block will set the host values
      if(threadIdx.x == 0) {
        if(T_DOENERGY) {
          h_resEnergy[0] = energy;    // restraint energy for each group, needed for output
          h_diffCOM->x = diffCOM.x;   // distance between COM of two restrained groups
          h_diffCOM->y = diffCOM.y;   // distance between COM of two restrained groups
          h_diffCOM->z = diffCOM.z;   // distance between COM of two restrained groups
          h_resForce->x = group_f.x;  // restraint force on group 2
          h_resForce->y = group_f.y;  // restraint force on group 2
          h_resForce->z = group_f.z;  // restraint force on group 2
        }
        if(T_DOVIRIAL) {
          // Add virial values to host memory.
          // We use add, since there may be multiple restraint groups
          h_extVirial->xx += d_virial->xx;
          h_extVirial->xy += d_virial->xy;
          h_extVirial->xz += d_virial->xz;
          h_extVirial->yx += d_virial->yx;
          h_extVirial->yy += d_virial->yy;
          h_extVirial->yz += d_virial->yz;
          h_extVirial->zx += d_virial->zx;
          h_extVirial->zy += d_virial->zy;
          h_extVirial->zz += d_virial->zz;

          //reset the device virial value
          d_virial->xx = 0;
          d_virial->xy = 0;
          d_virial->xz = 0;

          d_virial->yx = 0;
          d_virial->yy = 0;
          d_virial->yz = 0;

          d_virial->zx = 0;
          d_virial->zy = 0;
          d_virial->zz = 0;
        }
        //resets atomic counter
        d_tbcatomic[0] = 0;
        __threadfence();
      }
    }
  }
  else if(T_MGPUON)
  {// need lastBlockDone for MGPU
    if(threadIdx.x == 0){
      __threadfence();
      unsigned int value = atomicInc(&d_tbcatomic[0], totaltb);
      isLastBlockDone = (value == (totaltb -1));
    }
  }
  __syncthreads();
  if(T_MGPUON){
    if(isLastBlockDone){
      if(threadIdx.x == 0){
        // zero out for next iteration
        group1COM->x = 0.0;
        group1COM->y = 0.0;
        group1COM->z = 0.0;
        group2COM->x = 0.0;
        group2COM->y = 0.0;
        group2COM->z = 0.0;
        //resets atomic counter
        d_tbcatomic[0] = 0;
        __threadfence();
      }
    }
  }
}
303 
304 
/*! Compute restraint force, virial, and energy for a two-group COM restraint
 on small groups (totalNumRestrained <= 1024 atoms; larger groups are handled
 by computeLargeGroupRestraint2GroupsKernel). The COM of group 2 is restrained
 relative to the COM of group 1; the resulting force acts on group 2 and the
 opposite force on group 1, distributed over atoms in proportion to mass.

 Launch layout:
   - Exactly one block of 1024 threads (the reductions use
     cub::BlockReduce<double, 1024>).
   - Thread tIdx handles atom groupAtomsSOAIndex[tIdx]. Threads
     [0, numRestrainedGroup1) are group 1; [numRestrainedGroup1,
     totalNumRestrained) are group 2.

 COM source:
   - T_MGPUON == 0: the mass-weighted position sums are reduced in this
     kernel from the unwrapped coordinates; group1COM / group2COM are not read.
   - T_MGPUON != 0: group1COM / group2COM already hold the mass-weighted
     sums. They are scaled by the inverse masses here, and zeroed by the last
     block for the next iteration (d_tbcatomic[0] must be zero on entry and
     is reset to zero).

 Outputs (thread 0):
   - h_extVirial is accumulated (+=) when T_DOVIRIAL.
   - h_resEnergy, h_diffCOM (COM1 -> COM2 distance along resDirection) and
     h_resForce (force on group 2) are written when T_DOENERGY.
   - NOTE(review): the h_-prefixed outputs are presumably host-mapped buffers
     written directly from the device — confirm with caller.

 This function also calculates the distance from COM of the
 group 1 to COM of the group 2. */
template<int T_DOENERGY, int T_DOVIRIAL, int T_USEMAGNITUDE, int T_MGPUON>
__global__ void computeSmallGroupRestraint2GroupsKernel(
  const int numRestrainedGroup1,
  const int totalNumRestrained,
  const int restraintExp,
  const double restraintK,
  const double3 resCenterVec,
  const double3 resDirection,
  const double inv_group1_mass,
  const double inv_group2_mass,
  const int* __restrict groupAtomsSOAIndex,
  const Lattice lat,
  const char3* __restrict transform,
  const float* __restrict mass,
  const double* __restrict pos_x,
  const double* __restrict pos_y,
  const double* __restrict pos_z,
  double* __restrict f_normal_x,
  double* __restrict f_normal_y,
  double* __restrict f_normal_z,
  cudaTensor* __restrict h_extVirial,
  double* __restrict h_resEnergy,
  double3* __restrict h_resForce,
  double3* __restrict h_diffCOM,
  double3* __restrict group1COM, // on device in Multi GPU
  double3* __restrict group2COM,
  unsigned int* __restrict d_tbcatomic)
{
  int tIdx = threadIdx.x + blockIdx.x * blockDim.x;
  __shared__ double3 sh_com1;
  __shared__ double3 sh_com2;
  int totaltb = gridDim.x;
  double m = 0;
  double energy = 0.0;
  double3 com1 = {0, 0, 0};
  double3 com2 = {0, 0, 0};
  double3 diffCOM = {0, 0, 0};
  double3 group_f = {0, 0, 0};
  double3 pos = {0, 0, 0};
  double3 f = {0, 0, 0};
  cudaTensor r_virial;
  r_virial.xx = 0.0; r_virial.xy = 0.0; r_virial.xz = 0.0;
  r_virial.yx = 0.0; r_virial.yy = 0.0; r_virial.yz = 0.0;
  r_virial.zx = 0.0; r_virial.zy = 0.0; r_virial.zz = 0.0;
  int SOAindex;
  if(T_MGPUON)
  {
    // multi-GPU: mass-weighted sums were reduced across devices beforehand
    com1.x = group1COM->x;
    com1.y = group1COM->y;
    com1.z = group1COM->z;
    com2.x = group2COM->x;
    com2.y = group2COM->y;
    com2.z = group2COM->z;
  }
  if(tIdx < totalNumRestrained){
    SOAindex = groupAtomsSOAIndex[tIdx];

    m = mass[SOAindex]; // Cast from float to double here
    pos.x = pos_x[SOAindex];
    pos.y = pos_y[SOAindex];
    pos.z = pos_z[SOAindex];

    // unwrap the coordinate; used for the COM and for the virial
    char3 tr = transform[SOAindex];
    pos = lat.reverse_transform(pos, tr);
    if(!T_MGPUON){
      // First -> recalculate center of mass contributions.
      if (tIdx < numRestrainedGroup1) {
        // com2 stays zero for group 1 atoms
        com1.x = pos.x * m;
        com1.y = pos.y * m;
        com1.z = pos.z * m;
      } else {
        // com1 stays zero for group 2 atoms
        com2.x = pos.x * m;
        com2.y = pos.y * m;
        com2.z = pos.z * m;
      }
    }
  }

  // reduce the (mass * position) values for group 1 and 2 in the thread block
  typedef cub::BlockReduce<double, 1024> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  if(!T_MGPUON){
    com1.x = BlockReduce(temp_storage).Sum(com1.x);
    __syncthreads();
    com1.y = BlockReduce(temp_storage).Sum(com1.y);
    __syncthreads();
    com1.z = BlockReduce(temp_storage).Sum(com1.z);
    __syncthreads();
    com2.x = BlockReduce(temp_storage).Sum(com2.x);
    __syncthreads();
    com2.y = BlockReduce(temp_storage).Sum(com2.y);
    __syncthreads();
    com2.z = BlockReduce(temp_storage).Sum(com2.z);
    __syncthreads();
  }
  // Thread 0 holds the reduced sums and publishes the COMs of group 1 and 2
  if(threadIdx.x == 0){
    sh_com1.x = com1.x * inv_group1_mass;
    sh_com1.y = com1.y * inv_group1_mass;
    sh_com1.z = com1.z * inv_group1_mass;
    sh_com2.x = com2.x * inv_group2_mass;
    sh_com2.y = com2.y * inv_group2_mass;
    sh_com2.z = com2.z * inv_group2_mass;
  }
  __syncthreads();

  if(tIdx < totalNumRestrained) {
    // Here for consistency with distanceZ, we calculate
    // distance from com1 to com2 along specific restraint dimension,
    // so force is acting on group 2
    diffCOM.x = (sh_com2.x - sh_com1.x) * resDirection.x;
    diffCOM.y = (sh_com2.y - sh_com1.y) * resDirection.y;
    diffCOM.z = (sh_com2.z - sh_com1.z) * resDirection.z;
    // Calculate the minimum image distance
    diffCOM = lat.delta_from_diff(diffCOM);

    if (T_USEMAGNITUDE) {
      // Calculate the difference from equilibrium restraint distance
      double comVal = sqrt(diffCOM.x*diffCOM.x + diffCOM.y*diffCOM.y + diffCOM.z*diffCOM.z);
      double centerVal = sqrt(resCenterVec.x*resCenterVec.x + resCenterVec.y*resCenterVec.y +
                              resCenterVec.z*resCenterVec.z);

      double distDiff = (comVal - centerVal);
      double distSqDiff = distDiff * distDiff;
      // When the two COMs coincide the force direction is undefined; an
      // unguarded 1.0/0.0 would turn every atom force into NaN. Use zero
      // force in that degenerate case instead.
      double invCOMVal = (comVal > 0.0) ? 1.0 / comVal : 0.0;

      // Calculate energy and force on group of atoms
      if(distSqDiff > 0.0) { // To avoid numerical error
        // Energy = k * (r - r_eq)^n
        energy = restraintK * distSqDiff;
        for (int n = 2; n < restraintExp; n += 2) {
          energy *= distSqDiff;
        }
        // Force = -k * n * (r - r_eq)^(n-1)
        double force = -energy * restraintExp / distDiff;
        // calculate force along COM difference
        group_f.x = force * diffCOM.x * invCOMVal;
        group_f.y = force * diffCOM.y * invCOMVal;
        group_f.z = force * diffCOM.z * invCOMVal;
      }
    } else {
      // Calculate the difference from equilibrium restraint distance vector
      // along specific restraint dimension
      double3 resDist;
      resDist.x = (diffCOM.x - resCenterVec.x) * resDirection.x;
      resDist.y = (diffCOM.y - resCenterVec.y) * resDirection.y;
      resDist.z = (diffCOM.z - resCenterVec.z) * resDirection.z;
      // Wrap the distance difference (diffCOM - resCenterVec)
      resDist = lat.delta_from_diff(resDist);

      double distSqDiff = resDist.x*resDist.x + resDist.y*resDist.y + resDist.z*resDist.z;

      // Calculate energy and force on group of atoms
      if(distSqDiff > 0.0) { // To avoid numerical error
        // Energy = k * (r - r_eq)^n
        energy = restraintK * distSqDiff;
        for (int n = 2; n < restraintExp; n += 2) {
          energy *= distSqDiff;
        }
        // Force = -k * n * (r - r_eq)^(n-1) x (r - r_eq)/|r - r_eq|
        double force = -energy * restraintExp / distSqDiff;
        group_f.x = force * resDist.x;
        group_f.y = force * resDist.y;
        group_f.z = force * resDist.z;
      }
    }

    // calculate the force on each atom of the group
    if (tIdx < numRestrainedGroup1) {
      // threads [0 , numGroup1Atoms) calculate force for group 1
      // We use negative because force is calculated for group 2
      f.x = -group_f.x * m * inv_group1_mass;
      f.y = -group_f.y * m * inv_group1_mass;
      f.z = -group_f.z * m * inv_group1_mass;
    } else {
      // threads [numGroup1Atoms , totalNumRestrained) calculate force for group 2
      f.x = group_f.x * m * inv_group2_mass;
      f.y = group_f.y * m * inv_group2_mass;
      f.z = group_f.z * m * inv_group2_mass;
    }

    // apply the bias to each atom in group
    f_normal_x[SOAindex] += f.x;
    f_normal_y[SOAindex] += f.y;
    f_normal_z[SOAindex] += f.z;
    // Virial is based on applied force on each atom
    if(T_DOVIRIAL){
      // pos was unwrapped above, as required for the virial
      r_virial.xx = f.x * pos.x;
      r_virial.xy = f.x * pos.y;
      r_virial.xz = f.x * pos.z;
      r_virial.yx = f.y * pos.x;
      r_virial.yy = f.y * pos.y;
      r_virial.yz = f.y * pos.z;
      r_virial.zx = f.z * pos.x;
      r_virial.zy = f.z * pos.y;
      r_virial.zz = f.z * pos.z;
    }
  }
  __syncthreads();

  if(T_DOENERGY || T_DOVIRIAL) {
    if(T_DOVIRIAL){
      // Reduce virial values in the thread block
      r_virial.xx = BlockReduce(temp_storage).Sum(r_virial.xx);
      __syncthreads();
      r_virial.xy = BlockReduce(temp_storage).Sum(r_virial.xy);
      __syncthreads();
      r_virial.xz = BlockReduce(temp_storage).Sum(r_virial.xz);
      __syncthreads();

      r_virial.yx = BlockReduce(temp_storage).Sum(r_virial.yx);
      __syncthreads();
      r_virial.yy = BlockReduce(temp_storage).Sum(r_virial.yy);
      __syncthreads();
      r_virial.yz = BlockReduce(temp_storage).Sum(r_virial.yz);
      __syncthreads();

      r_virial.zx = BlockReduce(temp_storage).Sum(r_virial.zx);
      __syncthreads();
      r_virial.zy = BlockReduce(temp_storage).Sum(r_virial.zy);
      __syncthreads();
      r_virial.zz = BlockReduce(temp_storage).Sum(r_virial.zz);
      __syncthreads();
    }

    // thread zero updates the restraints energy and force
    if(threadIdx.x == 0){
      if(T_DOVIRIAL){
        // Add virial values to host memory.
        // We use add, since there may be multiple restraint groups
        h_extVirial->xx += r_virial.xx;
        h_extVirial->xy += r_virial.xy;
        h_extVirial->xz += r_virial.xz;
        h_extVirial->yx += r_virial.yx;
        h_extVirial->yy += r_virial.yy;
        h_extVirial->yz += r_virial.yz;
        h_extVirial->zx += r_virial.zx;
        h_extVirial->zy += r_virial.zy;
        h_extVirial->zz += r_virial.zz;
      }
      if (T_DOENERGY) {
        h_resEnergy[0] = energy;    // restraint energy for each group, needed for output
        h_diffCOM->x = diffCOM.x;   // distance between two COM of restrained groups
        h_diffCOM->y = diffCOM.y;   // distance between two COM of restrained groups
        h_diffCOM->z = diffCOM.z;   // distance between two COM of restrained groups
        h_resForce->x = group_f.x;  // restraint force on group
        h_resForce->y = group_f.y;  // restraint force on group
        h_resForce->z = group_f.z;  // restraint force on group
      }
    }
  }
  if(T_MGPUON)
  {
    // Only thread 0 takes part in the completion counter, so the last-block
    // check and the reset can be done by that thread alone.
    if(threadIdx.x == 0){
      __threadfence();
      unsigned int value = atomicInc(&d_tbcatomic[0], totaltb);
      const bool isLastBlockDone = (value == (totaltb -1));
      if(isLastBlockDone)
      { // zero out for next iteration
        group1COM->x = 0.0;
        group1COM->y = 0.0;
        group1COM->z = 0.0;

        group2COM->x = 0.0;
        group2COM->y = 0.0;
        group2COM->z = 0.0;
        //resets atomic counter
        d_tbcatomic[0] = 0;
        __threadfence();
      }
    }
  }
}
592 
/*! Compute restraint force, energy, and virial for a two-group COM restraint:
 the COM of group 2 is restrained relative to the COM of group 1, and the
 equal-and-opposite force is distributed over the atoms of both groups in
 proportion to their mass.

 Dispatch:
   - totalNumRestrained > 1024:
       - The group COMs are prepared first. On a single GPU,
         compute2COMKernel<128> writes them to h_group1COM / h_group2COM. In
         multi-GPU mode, computeDistCOMKernelMgpu combines the peer buffers
         into d_group1COM / d_group2COM.
       - computeLargeGroupRestraint2GroupsKernel then runs on 128-thread blocks.
   - Otherwise:
       - A single 1024-thread block of computeSmallGroupRestraint2GroupsKernel
         runs. It reduces the COMs itself on a single GPU, or reads the
         peer-combined values in multi-GPU mode.

 All launches are enqueued on `stream`; nothing is synchronized here.
 doEnergy, doVirial and useMagnitude are treated as booleans (any nonzero
 value means "on"). */
void computeGroupRestraint_2Group(
  const bool mGpuOn,
  const int useMagnitude,
  const int doEnergy,
  const int doVirial,
  const int numRestrainedGroup1,
  const int totalNumRestrained,
  const int restraintExp,
  const double restraintK,
  const double3 resCenterVec,
  const double3 resDirection,
  const double inv_group1_mass,
  const double inv_group2_mass,
  const int* d_groupAtomsSOAIndex,
  const Lattice &lat,
  const char3* d_transform,
  const float* d_mass,
  const double* d_pos_x,
  const double* d_pos_y,
  const double* d_pos_z,
  double* d_f_normal_x,
  double* d_f_normal_y,
  double* d_f_normal_z,
  cudaTensor* d_virial,
  cudaTensor* h_extVirial,
  double* h_resEnergy,
  double3* h_resForce,
  double3* h_group1COM,
  double3* h_group2COM,
  double3* h_diffCOM,
  double3* d_group1COM,
  double3* d_group2COM,
  double3** d_peer1COM,
  double3** d_peer2COM,
  unsigned int* d_tbcatomic,
  const int numDevices,
  cudaStream_t stream)
{
  // These sizes must match the cub::BlockReduce instantiations in the kernels
  // (BlockReduce<double, 128> in the large-group kernel and
  // BlockReduce<double, 1024> in the small-group kernel).
  constexpr int kLargeGroupBlockSize = 128;
  constexpr int kSmallGroupMaxAtoms = 1024;

  // Normalize each flag to 0/1 before packing: a nonzero value other than 1
  // would otherwise mis-decode as another flag or produce options > 15,
  // silently matching no case below and skipping the restraint entirely.
  const int options = (doEnergy ? 1 : 0)
                    + ((doVirial ? 1 : 0) << 1)
                    + ((useMagnitude ? 1 : 0) << 2)
                    + ((mGpuOn ? 1 : 0) << 3);
  double3* COM1Ptr = (mGpuOn) ? d_group1COM : h_group1COM;
  double3* COM2Ptr = (mGpuOn) ? d_group2COM : h_group2COM;
  const bool useLargeGroupKernel = (totalNumRestrained > kSmallGroupMaxAtoms);
  const int blocks = useLargeGroupKernel ? kLargeGroupBlockSize : kSmallGroupMaxAtoms;
  const int grid = useLargeGroupKernel ? (totalNumRestrained + blocks - 1) / blocks : 1;

  // Prepare the group COM inputs (same stream order as the restraint launch).
  if (mGpuOn) {
    // Combine the per-device COM contributions from all peers.
    // NOTE(review): presumably sums mass-weighted partial COMs into
    // d_group1COM / d_group2COM — confirm in ComputeCOMCudaKernel.
    computeDistCOMKernelMgpu<<<1, numDevices, 0, stream>>>(d_peer1COM,
                                                           d_group1COM,
                                                           numDevices);
    computeDistCOMKernelMgpu<<<1, numDevices, 0, stream>>>(d_peer2COM,
                                                           d_group2COM,
                                                           numDevices);
  } else if (useLargeGroupKernel) {
    // first calculate the COM for restraint groups and store it in
    // h_group1COM and h_group2COM (the small-group kernel computes its own)
    compute2COMKernel<kLargeGroupBlockSize><<<grid, blocks, 0, stream>>>(
      numRestrainedGroup1,
      totalNumRestrained,
      inv_group1_mass,
      inv_group2_mass,
      lat,
      d_mass,
      d_pos_x,
      d_pos_y,
      d_pos_z,
      d_transform,
      d_groupAtomsSOAIndex,
      d_group1COM,
      d_group2COM,
      h_group1COM,
      h_group2COM,
      d_tbcatomic);
  }

  if (useLargeGroupKernel) {
#define CALL_LARGE_GROUP_RES(DOENERGY, DOVIRIAL, USEMAGNITUDE, MGPUON) \
    computeLargeGroupRestraint2GroupsKernel<DOENERGY, DOVIRIAL, USEMAGNITUDE, MGPUON> \
      <<<grid, blocks, 0, stream>>>( \
      numRestrainedGroup1, totalNumRestrained, \
      restraintExp, restraintK, resCenterVec, resDirection, \
      inv_group1_mass, inv_group2_mass, d_groupAtomsSOAIndex, \
      lat, d_transform, d_mass, d_pos_x, d_pos_y, d_pos_z, \
      d_f_normal_x, d_f_normal_y, d_f_normal_z, d_virial, \
      h_extVirial, h_resEnergy, h_resForce, COM1Ptr, \
      COM2Ptr, h_diffCOM, d_tbcatomic);
    switch(options) {
      case 0: CALL_LARGE_GROUP_RES(0, 0, 0, 0); break;
      case 1: CALL_LARGE_GROUP_RES(1, 0, 0, 0); break;
      case 2: CALL_LARGE_GROUP_RES(0, 1, 0, 0); break;
      case 3: CALL_LARGE_GROUP_RES(1, 1, 0, 0); break;
      case 4: CALL_LARGE_GROUP_RES(0, 0, 1, 0); break;
      case 5: CALL_LARGE_GROUP_RES(1, 0, 1, 0); break;
      case 6: CALL_LARGE_GROUP_RES(0, 1, 1, 0); break;
      case 7: CALL_LARGE_GROUP_RES(1, 1, 1, 0); break;
      case 8: CALL_LARGE_GROUP_RES(0, 0, 0, 1); break;
      case 9: CALL_LARGE_GROUP_RES(1, 0, 0, 1); break;
      case 10: CALL_LARGE_GROUP_RES(0, 1, 0, 1); break;
      case 11: CALL_LARGE_GROUP_RES(1, 1, 0, 1); break;
      case 12: CALL_LARGE_GROUP_RES(0, 0, 1, 1); break;
      case 13: CALL_LARGE_GROUP_RES(1, 0, 1, 1); break;
      case 14: CALL_LARGE_GROUP_RES(0, 1, 1, 1); break;
      case 15: CALL_LARGE_GROUP_RES(1, 1, 1, 1); break;
    }
#undef CALL_LARGE_GROUP_RES
  } else {
    // For small group of restrained atom, we can just launch
    // a single threadblock
#define CALL_SMALL_GROUP_RES(DOENERGY, DOVIRIAL, USEMAGNITUDE, MGPUON) \
    computeSmallGroupRestraint2GroupsKernel<DOENERGY, DOVIRIAL, USEMAGNITUDE, MGPUON> \
      <<<grid, blocks, 0, stream>>>( \
      numRestrainedGroup1, totalNumRestrained, \
      restraintExp, restraintK, resCenterVec, resDirection, \
      inv_group1_mass, inv_group2_mass, d_groupAtomsSOAIndex, \
      lat, d_transform, d_mass, d_pos_x, d_pos_y, d_pos_z, \
      d_f_normal_x, d_f_normal_y, d_f_normal_z, \
      h_extVirial, h_resEnergy, h_resForce, h_diffCOM, \
      COM1Ptr, COM2Ptr, d_tbcatomic);
    switch(options) {
      case 0: CALL_SMALL_GROUP_RES(0, 0, 0, 0); break;
      case 1: CALL_SMALL_GROUP_RES(1, 0, 0, 0); break;
      case 2: CALL_SMALL_GROUP_RES(0, 1, 0, 0); break;
      case 3: CALL_SMALL_GROUP_RES(1, 1, 0, 0); break;
      case 4: CALL_SMALL_GROUP_RES(0, 0, 1, 0); break;
      case 5: CALL_SMALL_GROUP_RES(1, 0, 1, 0); break;
      case 6: CALL_SMALL_GROUP_RES(0, 1, 1, 0); break;
      case 7: CALL_SMALL_GROUP_RES(1, 1, 1, 0); break;
      case 8: CALL_SMALL_GROUP_RES(0, 0, 0, 1); break;
      case 9: CALL_SMALL_GROUP_RES(1, 0, 0, 1); break;
      case 10: CALL_SMALL_GROUP_RES(0, 1, 0, 1); break;
      case 11: CALL_SMALL_GROUP_RES(1, 1, 0, 1); break;
      case 12: CALL_SMALL_GROUP_RES(0, 0, 1, 1); break;
      case 13: CALL_SMALL_GROUP_RES(1, 0, 1, 1); break;
      case 14: CALL_SMALL_GROUP_RES(0, 1, 1, 1); break;
      case 15: CALL_SMALL_GROUP_RES(1, 1, 1, 1); break;
    }
#undef CALL_SMALL_GROUP_RES
  }
}
744 
745 #endif // NODEGROUP_FORCE_REGISTER