NAMD
ComputeGroupRes2GroupCUDAKernel.cu
Go to the documentation of this file.
1 #ifdef NAMD_CUDA
2 #if __CUDACC_VER_MAJOR__ >= 11
3 #include <cub/cub.cuh>
4 #else
5 #include <namd_cub/cub.cuh>
6 #endif
7 #endif
8 
9 #ifdef NAMD_HIP
10 #include <hip/hip_runtime.h>
11 #include <hipcub/hipcub.hpp>
12 #define cub hipcub
13 #endif
14 
15 #include "ComputeGroupRes2GroupCUDAKernel.h"
16 #include "ComputeCOMCudaKernel.h"
17 #include "HipDefines.h"
18 
19 #ifdef NODEGROUP_FORCE_REGISTER
20 
/*! Compute restraint force, virial, and energy for a large restrained
 system (totalNumRestrained > 1024 atoms, spread over several thread
 blocks), due to restraining the COM of group 2 (h_group2COM) relative
 to the COM of group 1 (h_group1COM). The reaction force is applied to
 group 1 with opposite sign.

 Preconditions:
  - blockDim.x must be 128 (matches the cub::BlockReduce size below).
  - The COMs of group 1 and 2 must already be calculated and stored in
    h_group1COM and h_group2COM before this kernel runs.
  - Threads [0, numRestrainedGroup1) handle group 1 atoms and threads
    [numRestrainedGroup1, totalNumRestrained) handle group 2 atoms, via
    groupAtomsSOAIndex.
  - Assumes the grid is ceil(totalNumRestrained / 128) blocks, so thread 0
    of every block maps to a restrained atom (its energy/force/diffCOM
    values are then valid when it publishes them).
  - When T_DOENERGY or T_DOVIRIAL is set, d_tbcatomic[0] is used as a
    block-completion counter. It is presumably 0 on entry; the last block
    resets it to 0. d_virial is accumulated with atomics, folded into
    h_extVirial by the last block, and then zeroed (so it is presumably
    expected to be zero on entry as well).
  - The h_* pointers are read/written from device code, so they must be
    device-accessible (presumably mapped pinned host memory — confirm with
    the caller).

 This function also stores the (minimum-image, direction-masked) distance
 vector from the COM of group 1 to the COM of group 2 in h_diffCOM when
 T_DOENERGY is set. */
template<int T_DOENERGY, int T_DOVIRIAL, int T_USEMAGNITUDE>
__global__ void computeLargeGroupRestraint2GroupsKernel(
  const int numRestrainedGroup1,
  const int totalNumRestrained,
  const int restraintExp,
  const double restraintK,
  const double3 resCenterVec,
  const double3 resDirection,
  const double inv_group1_mass,
  const double inv_group2_mass,
  const int* __restrict groupAtomsSOAIndex,
  const Lattice lat,
  const char3* __restrict transform,
  const float* __restrict mass,
  const double* __restrict pos_x,
  const double* __restrict pos_y,
  const double* __restrict pos_z,
  double* __restrict f_normal_x,
  double* __restrict f_normal_y,
  double* __restrict f_normal_z,
  cudaTensor* __restrict d_virial,
  cudaTensor* __restrict h_extVirial,
  double* __restrict h_resEnergy,
  double3* __restrict h_resForce,
  const double3* __restrict h_group1COM,
  const double3* __restrict h_group2COM,
  double3* __restrict h_diffCOM,
  unsigned int* __restrict d_tbcatomic)
{
  int tIdx = threadIdx.x + blockIdx.x * blockDim.x;
  int totaltb = gridDim.x;
  bool isLastBlockDone = false;

  int SOAindex;
  double m = 0;
  double energy = 0.0;
  double3 diffCOM = {0, 0, 0};
  double3 group_f = {0, 0, 0};
  double3 pos = {0, 0, 0};
  double3 f = {0, 0, 0};
  cudaTensor r_virial;
  r_virial.xx = 0.0; r_virial.xy = 0.0; r_virial.xz = 0.0;
  r_virial.yx = 0.0; r_virial.yy = 0.0; r_virial.yz = 0.0;
  r_virial.zx = 0.0; r_virial.zy = 0.0; r_virial.zz = 0.0;

  if(tIdx < totalNumRestrained) {
    SOAindex = groupAtomsSOAIndex[tIdx];
    m = mass[SOAindex];
    // For consistency with ComputeGroupRes1, measure the distance from
    // com1 to com2 along the selected restraint dimensions, so the
    // computed force acts on group 2.
    diffCOM.x = (h_group2COM->x - h_group1COM->x) * resDirection.x;
    diffCOM.y = (h_group2COM->y - h_group1COM->y) * resDirection.y;
    diffCOM.z = (h_group2COM->z - h_group1COM->z) * resDirection.z;
    // Minimum image distance
    diffCOM = lat.delta_from_diff(diffCOM);

    if (T_USEMAGNITUDE) {
      // Restrain the scalar COM distance |diffCOM| to |resCenterVec|.
      double comVal = sqrt(diffCOM.x*diffCOM.x + diffCOM.y*diffCOM.y + diffCOM.z*diffCOM.z);
      double centerVal = sqrt(resCenterVec.x*resCenterVec.x + resCenterVec.y*resCenterVec.y +
        resCenterVec.z*resCenterVec.z);

      double distDiff = (comVal - centerVal);
      double distSqDiff = distDiff * distDiff;

      if(distSqDiff > 0.0) { // To avoid numerical error
        // Energy = k * (r - r_eq)^n
        energy = restraintK * distSqDiff;
        for (int n = 2; n < restraintExp; n += 2) {
          energy *= distSqDiff;
        }
        // The force points along diffCOM / |diffCOM|, which is undefined when
        // the two COMs coincide. Skip the force in that case rather than
        // producing 0 * inf = NaN on every restrained atom.
        if (comVal > 0.0) {
          double invCOMVal = 1.0 / comVal;
          // Force = -k * n * (r - r_eq)^(n-1)
          double force = -energy * restraintExp / distDiff;
          // Project the force onto the COM difference direction
          group_f.x = force * diffCOM.x * invCOMVal;
          group_f.y = force * diffCOM.y * invCOMVal;
          group_f.z = force * diffCOM.z * invCOMVal;
        }
      }
    } else {
      // Restrain the COM difference vector to resCenterVec along the
      // selected restraint dimensions.
      double3 resDist;
      resDist.x = (diffCOM.x - resCenterVec.x) * resDirection.x;
      resDist.y = (diffCOM.y - resCenterVec.y) * resDirection.y;
      resDist.z = (diffCOM.z - resCenterVec.z) * resDirection.z;
      // Wrap the distance difference (diffCOM - resCenterVec)
      resDist = lat.delta_from_diff(resDist);

      double distSqDiff = resDist.x*resDist.x + resDist.y*resDist.y + resDist.z*resDist.z;

      if(distSqDiff > 0.0) { // To avoid numerical error
        // Energy = k * (r - r_eq)^n
        energy = restraintK * distSqDiff;
        for (int n = 2; n < restraintExp; n += 2) {
          energy *= distSqDiff;
        }
        // Force = -k * n * (r - r_eq)^(n-1) x (r - r_eq)/|r - r_eq|
        double force = -energy * restraintExp / distSqDiff;
        group_f.x = force * resDist.x;
        group_f.y = force * resDist.y;
        group_f.z = force * resDist.z;
      }
    }

    // Distribute the group force over atoms in proportion to their mass.
    if (tIdx < numRestrainedGroup1) {
      // Threads [0, numRestrainedGroup1) handle group 1. Use the opposite sign,
      // because group_f is the force on group 2.
      f.x = -group_f.x * m * inv_group1_mass;
      f.y = -group_f.y * m * inv_group1_mass;
      f.z = -group_f.z * m * inv_group1_mass;
    } else {
      // Threads [numRestrainedGroup1, totalNumRestrained) handle group 2
      f.x = group_f.x * m * inv_group2_mass;
      f.y = group_f.y * m * inv_group2_mass;
      f.z = group_f.z * m * inv_group2_mass;
    }
    // Apply the bias to this atom
    f_normal_x[SOAindex] += f.x;
    f_normal_y[SOAindex] += f.y;
    f_normal_z[SOAindex] += f.z;
    // The virial is based on the force applied to each atom
    if(T_DOVIRIAL) {
      // Positions must be unwrapped for the virial calculation
      pos.x = pos_x[SOAindex];
      pos.y = pos_y[SOAindex];
      pos.z = pos_z[SOAindex];
      char3 tr = transform[SOAindex];
      pos = lat.reverse_transform(pos, tr);
      r_virial.xx = f.x * pos.x;
      r_virial.xy = f.x * pos.y;
      r_virial.xz = f.x * pos.z;
      r_virial.yx = f.y * pos.x;
      r_virial.yy = f.y * pos.y;
      r_virial.yz = f.y * pos.z;
      r_virial.zx = f.z * pos.x;
      r_virial.zy = f.z * pos.y;
      r_virial.zz = f.z * pos.z;
    }
  }
  __syncthreads();

  if(T_DOENERGY || T_DOVIRIAL) {
    if(T_DOVIRIAL) {
      // Reduce the virial values within the thread block. Only thread 0
      // receives the valid block-wide sum.
      typedef cub::BlockReduce<double, 128> BlockReduce;
      __shared__ typename BlockReduce::TempStorage temp_storage;

      r_virial.xx = BlockReduce(temp_storage).Sum(r_virial.xx);
      __syncthreads();
      r_virial.xy = BlockReduce(temp_storage).Sum(r_virial.xy);
      __syncthreads();
      r_virial.xz = BlockReduce(temp_storage).Sum(r_virial.xz);
      __syncthreads();

      r_virial.yx = BlockReduce(temp_storage).Sum(r_virial.yx);
      __syncthreads();
      r_virial.yy = BlockReduce(temp_storage).Sum(r_virial.yy);
      __syncthreads();
      r_virial.yz = BlockReduce(temp_storage).Sum(r_virial.yz);
      __syncthreads();

      r_virial.zx = BlockReduce(temp_storage).Sum(r_virial.zx);
      __syncthreads();
      r_virial.zy = BlockReduce(temp_storage).Sum(r_virial.zy);
      __syncthreads();
      r_virial.zz = BlockReduce(temp_storage).Sum(r_virial.zz);
      __syncthreads();
    }

    if(threadIdx.x == 0) {
      if(T_DOVIRIAL) {
        // Thread 0 adds the block-reduced virial into device memory
        atomicAdd(&(d_virial->xx), r_virial.xx);
        atomicAdd(&(d_virial->xy), r_virial.xy);
        atomicAdd(&(d_virial->xz), r_virial.xz);

        atomicAdd(&(d_virial->yx), r_virial.yx);
        atomicAdd(&(d_virial->yy), r_virial.yy);
        atomicAdd(&(d_virial->yz), r_virial.yz);

        atomicAdd(&(d_virial->zx), r_virial.zx);
        atomicAdd(&(d_virial->zy), r_virial.zy);
        atomicAdd(&(d_virial->zz), r_virial.zz);
      }
      // Make this block's writes visible device-wide before signaling
      // completion, so the last block reads the fully accumulated d_virial.
      __threadfence();
      unsigned int value = atomicInc(&d_tbcatomic[0], totaltb);
      isLastBlockDone = (value == (totaltb -1));
    }

    __syncthreads();

    // isLastBlockDone is a per-thread local that only thread 0 ever sets,
    // so only thread 0 of the last finishing block enters here.
    if(isLastBlockDone) {
      if(threadIdx.x == 0) {
        if(T_DOENERGY) {
          h_resEnergy[0] = energy;   // restraint energy, needed for output
          h_diffCOM->x = diffCOM.x;  // distance between COMs of the two restrained groups
          h_diffCOM->y = diffCOM.y;
          h_diffCOM->z = diffCOM.z;
          h_resForce->x = group_f.x; // restraint force on group 2
          h_resForce->y = group_f.y;
          h_resForce->z = group_f.z;
        }
        if(T_DOVIRIAL) {
          // Accumulate (not overwrite): several restraint groups may
          // contribute to the same external virial.
          h_extVirial->xx += d_virial->xx;
          h_extVirial->xy += d_virial->xy;
          h_extVirial->xz += d_virial->xz;
          h_extVirial->yx += d_virial->yx;
          h_extVirial->yy += d_virial->yy;
          h_extVirial->yz += d_virial->yz;
          h_extVirial->zx += d_virial->zx;
          h_extVirial->zy += d_virial->zy;
          h_extVirial->zz += d_virial->zz;

          // Reset the device virial for the next call
          d_virial->xx = 0;
          d_virial->xy = 0;
          d_virial->xz = 0;

          d_virial->yx = 0;
          d_virial->yy = 0;
          d_virial->yz = 0;

          d_virial->zx = 0;
          d_virial->zy = 0;
          d_virial->zz = 0;
        }
        // Reset the block-completion counter for the next launch
        d_tbcatomic[0] = 0;
        __threadfence();
      }
    }
  }
}
270 
271 
/*! Compute restraint force, virial, and energy for a small restrained
 system (totalNumRestrained <= 1024 atoms), due to restraining the COM of
 group 2 relative to the COM of group 1. Both COMs are recomputed here
 from the current (unwrapped) positions. The reaction force is applied to
 group 1 with opposite sign.

 Preconditions:
  - Must be launched as a single block (gridDim.x == 1) of exactly 1024
    threads: the COM reduction is block-wide (cub::BlockReduce<double, 1024>)
    and h_extVirial is updated with a plain, non-atomic +=.
  - Threads [0, numRestrainedGroup1) handle group 1 atoms and threads
    [numRestrainedGroup1, totalNumRestrained) handle group 2 atoms, via
    groupAtomsSOAIndex.
  - The h_* pointers are written from device code, so they must be
    device-accessible (presumably mapped pinned host memory — confirm with
    the caller).

 This function also stores the (minimum-image, direction-masked) distance
 vector from the COM of group 1 to the COM of group 2 in h_diffCOM when
 T_DOENERGY is set. */
template<int T_DOENERGY, int T_DOVIRIAL, int T_USEMAGNITUDE>
__global__ void computeSmallGroupRestraint2GroupsKernel(
  const int numRestrainedGroup1,
  const int totalNumRestrained,
  const int restraintExp,
  const double restraintK,
  const double3 resCenterVec,
  const double3 resDirection,
  const double inv_group1_mass,
  const double inv_group2_mass,
  const int* __restrict groupAtomsSOAIndex,
  const Lattice lat,
  const char3* __restrict transform,
  const float* __restrict mass,
  const double* __restrict pos_x,
  const double* __restrict pos_y,
  const double* __restrict pos_z,
  double* __restrict f_normal_x,
  double* __restrict f_normal_y,
  double* __restrict f_normal_z,
  cudaTensor* __restrict h_extVirial,
  double* __restrict h_resEnergy,
  double3* __restrict h_resForce,
  double3* __restrict h_diffCOM)
{
  int tIdx = threadIdx.x + blockIdx.x * blockDim.x;
  __shared__ double3 sh_com1;
  __shared__ double3 sh_com2;

  double m = 0;
  double energy = 0.0;
  double3 com1 = {0, 0, 0};
  double3 com2 = {0, 0, 0};
  double3 diffCOM = {0, 0, 0};
  double3 group_f = {0, 0, 0};
  double3 pos = {0, 0, 0};
  double3 f = {0, 0, 0};
  cudaTensor r_virial;
  r_virial.xx = 0.0; r_virial.xy = 0.0; r_virial.xz = 0.0;
  r_virial.yx = 0.0; r_virial.yy = 0.0; r_virial.yz = 0.0;
  r_virial.zx = 0.0; r_virial.zy = 0.0; r_virial.zz = 0.0;
  int SOAindex;

  if(tIdx < totalNumRestrained){
    // First, recompute the centers of mass.
    SOAindex = groupAtomsSOAIndex[tIdx];

    m = mass[SOAindex]; // Cast from float to double here
    pos.x = pos_x[SOAindex];
    pos.y = pos_y[SOAindex];
    pos.z = pos_z[SOAindex];

    // Unwrap the coordinate before accumulating the COM
    char3 tr = transform[SOAindex];
    pos = lat.reverse_transform(pos, tr);

    // Each thread contributes m * pos to exactly one group; the other
    // group's accumulator stays zero.
    if (tIdx < numRestrainedGroup1) {
      com1.x = pos.x * m;
      com1.y = pos.y * m;
      com1.z = pos.z * m;
    } else {
      com2.x = pos.x * m;
      com2.y = pos.y * m;
      com2.z = pos.z * m;
    }
  }

  // Reduce the (mass * position) values for groups 1 and 2 in the block.
  // Only thread 0 receives the valid block-wide sums.
  typedef cub::BlockReduce<double, 1024> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  com1.x = BlockReduce(temp_storage).Sum(com1.x);
  __syncthreads();
  com1.y = BlockReduce(temp_storage).Sum(com1.y);
  __syncthreads();
  com1.z = BlockReduce(temp_storage).Sum(com1.z);
  __syncthreads();
  com2.x = BlockReduce(temp_storage).Sum(com2.x);
  __syncthreads();
  com2.y = BlockReduce(temp_storage).Sum(com2.y);
  __syncthreads();
  com2.z = BlockReduce(temp_storage).Sum(com2.z);
  __syncthreads();

  // Thread 0 finishes the COMs of groups 1 and 2 and broadcasts them
  // through shared memory.
  if(threadIdx.x == 0){
    sh_com1.x = com1.x * inv_group1_mass;
    sh_com1.y = com1.y * inv_group1_mass;
    sh_com1.z = com1.z * inv_group1_mass;
    sh_com2.x = com2.x * inv_group2_mass;
    sh_com2.y = com2.y * inv_group2_mass;
    sh_com2.z = com2.z * inv_group2_mass;
  }
  __syncthreads();

  if(tIdx < totalNumRestrained) {
    // For consistency with distanceZ, measure the distance from com1 to
    // com2 along the selected restraint dimensions, so the computed force
    // acts on group 2.
    diffCOM.x = (sh_com2.x - sh_com1.x) * resDirection.x;
    diffCOM.y = (sh_com2.y - sh_com1.y) * resDirection.y;
    diffCOM.z = (sh_com2.z - sh_com1.z) * resDirection.z;
    // Minimum image distance
    diffCOM = lat.delta_from_diff(diffCOM);

    if (T_USEMAGNITUDE) {
      // Restrain the scalar COM distance |diffCOM| to |resCenterVec|.
      double comVal = sqrt(diffCOM.x*diffCOM.x + diffCOM.y*diffCOM.y + diffCOM.z*diffCOM.z);
      double centerVal = sqrt(resCenterVec.x*resCenterVec.x + resCenterVec.y*resCenterVec.y +
        resCenterVec.z*resCenterVec.z);

      double distDiff = (comVal - centerVal);
      double distSqDiff = distDiff * distDiff;

      if(distSqDiff > 0.0) { // To avoid numerical error
        // Energy = k * (r - r_eq)^n
        energy = restraintK * distSqDiff;
        for (int n = 2; n < restraintExp; n += 2) {
          energy *= distSqDiff;
        }
        // The force points along diffCOM / |diffCOM|, which is undefined when
        // the two COMs coincide. Skip the force in that case rather than
        // producing 0 * inf = NaN on every restrained atom.
        if (comVal > 0.0) {
          double invCOMVal = 1.0 / comVal;
          // Force = -k * n * (r - r_eq)^(n-1)
          double force = -energy * restraintExp / distDiff;
          // Project the force onto the COM difference direction
          group_f.x = force * diffCOM.x * invCOMVal;
          group_f.y = force * diffCOM.y * invCOMVal;
          group_f.z = force * diffCOM.z * invCOMVal;
        }
      }
    } else {
      // Restrain the COM difference vector to resCenterVec along the
      // selected restraint dimensions.
      double3 resDist;
      resDist.x = (diffCOM.x - resCenterVec.x) * resDirection.x;
      resDist.y = (diffCOM.y - resCenterVec.y) * resDirection.y;
      resDist.z = (diffCOM.z - resCenterVec.z) * resDirection.z;
      // Wrap the distance difference (diffCOM - resCenterVec)
      resDist = lat.delta_from_diff(resDist);

      double distSqDiff = resDist.x*resDist.x + resDist.y*resDist.y + resDist.z*resDist.z;

      if(distSqDiff > 0.0) { // To avoid numerical error
        // Energy = k * (r - r_eq)^n
        energy = restraintK * distSqDiff;
        for (int n = 2; n < restraintExp; n += 2) {
          energy *= distSqDiff;
        }
        // Force = -k * n * (r - r_eq)^(n-1) x (r - r_eq)/|r - r_eq|
        double force = -energy * restraintExp / distSqDiff;
        group_f.x = force * resDist.x;
        group_f.y = force * resDist.y;
        group_f.z = force * resDist.z;
      }
    }

    // Distribute the group force over atoms in proportion to their mass.
    if (tIdx < numRestrainedGroup1) {
      // Threads [0, numRestrainedGroup1) handle group 1. Use the opposite sign,
      // because group_f is the force on group 2.
      f.x = -group_f.x * m * inv_group1_mass;
      f.y = -group_f.y * m * inv_group1_mass;
      f.z = -group_f.z * m * inv_group1_mass;
    } else {
      // Threads [numRestrainedGroup1, totalNumRestrained) handle group 2
      f.x = group_f.x * m * inv_group2_mass;
      f.y = group_f.y * m * inv_group2_mass;
      f.z = group_f.z * m * inv_group2_mass;
    }

    // Apply the bias to this atom
    f_normal_x[SOAindex] += f.x;
    f_normal_y[SOAindex] += f.y;
    f_normal_z[SOAindex] += f.z;
    // The virial is based on the force applied to each atom. pos was
    // already unwrapped above for the COM calculation.
    if(T_DOVIRIAL){
      r_virial.xx = f.x * pos.x;
      r_virial.xy = f.x * pos.y;
      r_virial.xz = f.x * pos.z;
      r_virial.yx = f.y * pos.x;
      r_virial.yy = f.y * pos.y;
      r_virial.yz = f.y * pos.z;
      r_virial.zx = f.z * pos.x;
      r_virial.zy = f.z * pos.y;
      r_virial.zz = f.z * pos.z;
    }
  }
  __syncthreads();

  if(T_DOENERGY || T_DOVIRIAL) {
    if(T_DOVIRIAL){
      // Reduce the virial values within the (single) thread block
      r_virial.xx = BlockReduce(temp_storage).Sum(r_virial.xx);
      __syncthreads();
      r_virial.xy = BlockReduce(temp_storage).Sum(r_virial.xy);
      __syncthreads();
      r_virial.xz = BlockReduce(temp_storage).Sum(r_virial.xz);
      __syncthreads();

      r_virial.yx = BlockReduce(temp_storage).Sum(r_virial.yx);
      __syncthreads();
      r_virial.yy = BlockReduce(temp_storage).Sum(r_virial.yy);
      __syncthreads();
      r_virial.yz = BlockReduce(temp_storage).Sum(r_virial.yz);
      __syncthreads();

      r_virial.zx = BlockReduce(temp_storage).Sum(r_virial.zx);
      __syncthreads();
      r_virial.zy = BlockReduce(temp_storage).Sum(r_virial.zy);
      __syncthreads();
      r_virial.zz = BlockReduce(temp_storage).Sum(r_virial.zz);
      __syncthreads();
    }

    // Thread 0 publishes the restraint energy, force, and virial
    if(threadIdx.x == 0){
      if(T_DOVIRIAL){
        // Accumulate (not overwrite): several restraint groups may
        // contribute to the same external virial.
        h_extVirial->xx += r_virial.xx;
        h_extVirial->xy += r_virial.xy;
        h_extVirial->xz += r_virial.xz;
        h_extVirial->yx += r_virial.yx;
        h_extVirial->yy += r_virial.yy;
        h_extVirial->yz += r_virial.yz;
        h_extVirial->zx += r_virial.zx;
        h_extVirial->zy += r_virial.zy;
        h_extVirial->zz += r_virial.zz;
      }
      if (T_DOENERGY) {
        h_resEnergy[0] = energy;   // restraint energy, needed for output
        h_diffCOM->x = diffCOM.x;  // distance between COMs of the two restrained groups
        h_diffCOM->y = diffCOM.y;
        h_diffCOM->z = diffCOM.z;
        h_resForce->x = group_f.x; // restraint force on group 2
        h_resForce->y = group_f.y;
        h_resForce->z = group_f.z;
      }
    }
  }
}
521 
/*! Compute restraint force, energy, and virial applied to group 2 (and the
 opposite reaction on group 1), due to restraining the COM of group 2
 relative to the COM of group 1.

 useMagnitude, doEnergy, and doVirial are treated as booleans: any
 non-zero value means "on".

 Launch strategy:
  - totalNumRestrained > 1024: compute2COMKernel<128> first computes both
    COMs into h_group1COM / h_group2COM. Then
    computeLargeGroupRestraint2GroupsKernel applies the restraint, using
    128-thread blocks.
  - otherwise: a single 1024-thread block recomputes the COMs and applies
    the restraint in one kernel (computeSmallGroupRestraint2GroupsKernel).

 All work is enqueued on `stream` and nothing here synchronizes. The
 host-side outputs (h_resEnergy, h_resForce, h_diffCOM, h_extVirial, and
 h_group1COM / h_group2COM in the large path) are therefore only valid
 after the caller synchronizes the stream. */
void computeGroupRestraint_2Group(
  const int useMagnitude,
  const int doEnergy,
  const int doVirial,
  const int numRestrainedGroup1,
  const int totalNumRestrained,
  const int restraintExp,
  const double restraintK,
  const double3 resCenterVec,
  const double3 resDirection,
  const double inv_group1_mass,
  const double inv_group2_mass,
  const int* d_groupAtomsSOAIndex,
  const Lattice &lat,
  const char3* d_transform,
  const float* d_mass,
  const double* d_pos_x,
  const double* d_pos_y,
  const double* d_pos_z,
  double* d_f_normal_x,
  double* d_f_normal_y,
  double* d_f_normal_z,
  cudaTensor* d_virial,
  cudaTensor* h_extVirial,
  double* h_resEnergy,
  double3* h_resForce,
  double3* h_group1COM,
  double3* h_group2COM,
  double3* h_diffCOM,
  double3* d_group1COM,
  double3* d_group2COM,
  unsigned int* d_tbcatomic,
  cudaStream_t stream)
{
  // Pack the three flags into a 3-bit selector for the template dispatch
  // below. Normalize each flag to 0/1 first. Otherwise a non-1 "true" value
  // (e.g. doEnergy == 2) would alias another variant, or land outside 0..7
  // and silently launch no restraint kernel at all.
  const int options = (doEnergy     != 0 ? 1 : 0)
                    | (doVirial     != 0 ? 2 : 0)
                    | (useMagnitude != 0 ? 4 : 0);

  if (totalNumRestrained > 1024) {
    // Must match the cub::BlockReduce size in
    // computeLargeGroupRestraint2GroupsKernel and the compute2COMKernel
    // template argument.
    const int blocks = 128;
    const int grid = (totalNumRestrained + blocks - 1) / blocks;
    // First calculate the COMs of the restraint groups and store them in
    // h_group1COM and h_group2COM.
    compute2COMKernel<128><<<grid, blocks, 0, stream>>>(
      numRestrainedGroup1,
      totalNumRestrained,
      inv_group1_mass,
      inv_group2_mass,
      lat,
      d_mass,
      d_pos_x,
      d_pos_y,
      d_pos_z,
      d_transform,
      d_groupAtomsSOAIndex,
      d_group1COM,
      d_group2COM,
      h_group1COM,
      h_group2COM,
      d_tbcatomic);

    #define CALL_LARGE_GROUP_RES(DOENERGY, DOVIRIAL, USEMAGNITUDE) \
      computeLargeGroupRestraint2GroupsKernel<DOENERGY, DOVIRIAL, USEMAGNITUDE>\
      <<<grid, blocks, 0, stream>>>( \
        numRestrainedGroup1, totalNumRestrained, \
        restraintExp, restraintK, resCenterVec, resDirection, \
        inv_group1_mass, inv_group2_mass, d_groupAtomsSOAIndex, \
        lat, d_transform, d_mass, d_pos_x, d_pos_y, d_pos_z, \
        d_f_normal_x, d_f_normal_y, d_f_normal_z, d_virial, \
        h_extVirial, h_resEnergy, h_resForce, h_group1COM, \
        h_group2COM, h_diffCOM, d_tbcatomic);

    switch(options) {
      case 0: CALL_LARGE_GROUP_RES(0, 0, 0); break;
      case 1: CALL_LARGE_GROUP_RES(1, 0, 0); break;
      case 2: CALL_LARGE_GROUP_RES(0, 1, 0); break;
      case 3: CALL_LARGE_GROUP_RES(1, 1, 0); break;
      case 4: CALL_LARGE_GROUP_RES(0, 0, 1); break;
      case 5: CALL_LARGE_GROUP_RES(1, 0, 1); break;
      case 6: CALL_LARGE_GROUP_RES(0, 1, 1); break;
      case 7: CALL_LARGE_GROUP_RES(1, 1, 1); break;
    }

    #undef CALL_LARGE_GROUP_RES

  } else {
    // A small set of restrained atoms fits in a single thread block. Must
    // match the cub::BlockReduce size in
    // computeSmallGroupRestraint2GroupsKernel.
    const int blocks = 1024;
    const int grid = 1;

    #define CALL_SMALL_GROUP_RES(DOENERGY, DOVIRIAL, USEMAGNITUDE) \
      computeSmallGroupRestraint2GroupsKernel<DOENERGY, DOVIRIAL, USEMAGNITUDE>\
      <<<grid, blocks, 0, stream>>>( \
        numRestrainedGroup1, totalNumRestrained, \
        restraintExp, restraintK, resCenterVec, resDirection, \
        inv_group1_mass, inv_group2_mass, d_groupAtomsSOAIndex, \
        lat, d_transform, d_mass, d_pos_x, d_pos_y, d_pos_z, \
        d_f_normal_x, d_f_normal_y, d_f_normal_z, \
        h_extVirial, h_resEnergy, h_resForce, h_diffCOM);

    switch(options) {
      case 0: CALL_SMALL_GROUP_RES(0, 0, 0); break;
      case 1: CALL_SMALL_GROUP_RES(1, 0, 0); break;
      case 2: CALL_SMALL_GROUP_RES(0, 1, 0); break;
      case 3: CALL_SMALL_GROUP_RES(1, 1, 0); break;
      case 4: CALL_SMALL_GROUP_RES(0, 0, 1); break;
      case 5: CALL_SMALL_GROUP_RES(1, 0, 1); break;
      case 6: CALL_SMALL_GROUP_RES(0, 1, 1); break;
      case 7: CALL_SMALL_GROUP_RES(1, 1, 1); break;
    }

    #undef CALL_SMALL_GROUP_RES
  }
}
639 
640 #endif // NODEGROUP_FORCE_REGISTER