NAMD
ComputeGroupRes1GroupCUDAKernel.cu
Go to the documentation of this file.
1 #ifdef NAMD_CUDA
2 #if __CUDACC_VER_MAJOR__ >= 11
3 #include <cub/cub.cuh>
4 #else
5 #include <namd_cub/cub.cuh>
6 #endif
7 #endif
8 
9 #ifdef NAMD_HIP
10 #include <hip/hip_runtime.h>
11 #include <hipcub/hipcub.hpp>
12 #define cub hipcub
13 #endif
14 
15 #include "ComputeGroupRes1GroupCUDAKernel.h"
16 #include "ComputeCOMCudaKernel.h"
17 #include "HipDefines.h"
18 
19 #ifdef NODEGROUP_FORCE_REGISTER
20 
/*! Compute the restraint force, virial, and energy applied to a large
    group 2 (more than 1024 atoms, which is when the host wrapper
    computeGroupRestraint_1Group selects this kernel), due to restraining
    the COM of group 2 (group2COM) to a reference point (h_group1COMRef).

    The COM of group 2 must already have been computed and stored in
    group2COM before this kernel runs. With T_MGPUON the value read from
    group2COM is first scaled by inv_group2_mass, and the last block to
    finish zeroes group2COM for the next iteration.

    The kernel also reports the distance from the reference point to the
    COM of group 2 through h_diffCOM.

    Template flags:
      T_DOENERGY     - write energy, COM difference, and group force to
                       h_resEnergy / h_diffCOM / h_resForce
      T_DOVIRIAL     - accumulate the virial in d_virial and add the total
                       to h_extVirial
      T_USEMAGNITUDE - restrain |COM - ref| to |resCenterVec| instead of
                       restraining the difference vector
      T_MGPUON       - multi-GPU mode (see above)

    Launch notes: one thread per restrained atom. The virial reduction uses
    cub::BlockReduce<double, 128>, so blocks are expected to hold 128
    threads (as launched by the host wrapper in this file). d_tbcatomic[0]
    counts finished blocks and is reset to 0 by the last one.

    The energy loop multiplies by the squared distance in steps of two, so
    it evaluates k * d^restraintExp only for even restraintExp. */
template<int T_DOENERGY, int T_DOVIRIAL, int T_USEMAGNITUDE, int T_MGPUON>
__global__ void computeLargeGroupRestraintKernel_1Group(
  const int                 numRestrainedGroup,
  const int                 restraintExp,
  const double              restraintK,
  const double3             resCenterVec,
  const double3             resDirection,
  const double              inv_group2_mass,
  const int*    __restrict  groupAtomsSOAIndex,
  const Lattice             lat,
  const char3*  __restrict  transform,
  const float*  __restrict  mass,
  const double* __restrict  pos_x,
  const double* __restrict  pos_y,
  const double* __restrict  pos_z,
  double*       __restrict  f_normal_x,
  double*       __restrict  f_normal_y,
  double*       __restrict  f_normal_z,
  cudaTensor*   __restrict  d_virial,
  cudaTensor*   __restrict  h_extVirial,
  double*       __restrict  h_resEnergy,
  double3*      __restrict  h_resForce,
  const double3* __restrict h_group1COMRef,
  double3*      __restrict  group2COM,
  double3*      __restrict  h_diffCOM,
  unsigned int* __restrict  d_tbcatomic)
{
  int tIdx = threadIdx.x + blockIdx.x * blockDim.x;
  int totaltb = gridDim.x;
  bool isLastBlockDone = false;

  int SOAindex;
  double m = 0;
  double energy = 0.0;
  double3 diffCOM = {0, 0, 0};
  double3 group_f = {0, 0, 0};
  double3 pos = {0, 0, 0};
  double3 f = {0, 0, 0};
  // Every thread reads the group COM before any block can zero it below:
  // the zeroing only happens after all blocks have passed the counter.
  double3 com = {group2COM->x, group2COM->y, group2COM->z};
  cudaTensor r_virial;
  r_virial.xx = 0.0; r_virial.xy = 0.0; r_virial.xz = 0.0;
  r_virial.yx = 0.0; r_virial.yy = 0.0; r_virial.yz = 0.0;
  r_virial.zx = 0.0; r_virial.zy = 0.0; r_virial.zz = 0.0;

  if (T_MGPUON) {
    // In multi-GPU mode the stored value still has to be scaled by 1/mass
    com.x *= inv_group2_mass;
    com.y *= inv_group2_mass;
    com.z *= inv_group2_mass;
  }

  if (tIdx < numRestrainedGroup) {
    SOAindex = groupAtomsSOAIndex[tIdx];
    m = mass[SOAindex];
    // Calculate distance from ref to com2 along specific restraint dimension
    diffCOM.x = (com.x - h_group1COMRef->x) * resDirection.x;
    diffCOM.y = (com.y - h_group1COMRef->y) * resDirection.y;
    diffCOM.z = (com.z - h_group1COMRef->z) * resDirection.z;
    // Calculate the minimum image distance
    diffCOM = lat.delta_from_diff(diffCOM);

    if (T_USEMAGNITUDE) {
      // Calculate the difference from equilibrium restraint distance
      double comVal = sqrt(diffCOM.x*diffCOM.x + diffCOM.y*diffCOM.y + diffCOM.z*diffCOM.z);
      double centerVal = sqrt(resCenterVec.x*resCenterVec.x + resCenterVec.y*resCenterVec.y +
                              resCenterVec.z*resCenterVec.z);

      double distDiff = (comVal - centerVal);
      double distSqDiff = distDiff * distDiff;
      // Guard against a zero-length COM difference: 1/0 would yield inf and
      // force * 0 * inf would inject NaN into the force of every atom.
      // With no defined direction the group force is left at zero.
      double invCOMVal = (comVal > 0.0) ? 1.0 / comVal : 0.0;

      // Calculate energy and force on group of atoms
      if (distSqDiff > 0.0f) { // To avoid numerical error
        // Energy = k * (r - r_eq)^n
        energy = restraintK * distSqDiff;
        for (int n = 2; n < restraintExp; n += 2) {
          energy *= distSqDiff;
        }
        // Force = -k * n * (r - r_eq)^(n-1)
        double force = -energy * restraintExp / distDiff;
        // calculate force along COM difference
        group_f.x = force * diffCOM.x * invCOMVal;
        group_f.y = force * diffCOM.y * invCOMVal;
        group_f.z = force * diffCOM.z * invCOMVal;
      }
    } else {
      // Calculate the difference from equilibrium restraint distance vector
      // along specific restraint dimension
      double3 resDist;
      resDist.x = (diffCOM.x - resCenterVec.x) * resDirection.x;
      resDist.y = (diffCOM.y - resCenterVec.y) * resDirection.y;
      resDist.z = (diffCOM.z - resCenterVec.z) * resDirection.z;
      // Wrap the distance difference (diffCOM - resCenterVec)
      resDist = lat.delta_from_diff(resDist);

      double distSqDiff = resDist.x*resDist.x + resDist.y*resDist.y + resDist.z*resDist.z;

      // Calculate energy and force on group of atoms
      if (distSqDiff > 0.0f) { // To avoid numerical error
        // Energy = k * (r - r_eq)^n
        energy = restraintK * distSqDiff;
        for (int n = 2; n < restraintExp; n += 2) {
          energy *= distSqDiff;
        }
        // Force = -k * n * (r - r_eq)^(n-1) x (r - r_eq)/|r - r_eq|
        double force = -energy * restraintExp / distSqDiff;
        group_f.x = force * resDist.x;
        group_f.y = force * resDist.y;
        group_f.z = force * resDist.z;
      }
    }

    // calculate the force on each atom (mass-weighted share of the group force)
    f.x = group_f.x * m * inv_group2_mass;
    f.y = group_f.y * m * inv_group2_mass;
    f.z = group_f.z * m * inv_group2_mass;
    // apply the bias to each atom in group
    f_normal_x[SOAindex] += f.x;
    f_normal_y[SOAindex] += f.y;
    f_normal_z[SOAindex] += f.z;
    // Virial is based on applied force on each atom
    if (T_DOVIRIAL) {
      // positions must be unwrapped for virial calculation
      pos.x = pos_x[SOAindex];
      pos.y = pos_y[SOAindex];
      pos.z = pos_z[SOAindex];
      char3 tr = transform[SOAindex];
      pos = lat.reverse_transform(pos, tr);
      r_virial.xx = f.x * pos.x;
      r_virial.xy = f.x * pos.y;
      r_virial.xz = f.x * pos.z;
      r_virial.yx = f.y * pos.x;
      r_virial.yy = f.y * pos.y;
      r_virial.yz = f.y * pos.z;
      r_virial.zx = f.z * pos.x;
      r_virial.zy = f.z * pos.y;
      r_virial.zz = f.z * pos.z;
    }
  }
  __syncthreads();

  if (T_DOENERGY || T_DOVIRIAL) {
    if (T_DOVIRIAL) {
      // Reduce virial values in the thread block. temp_storage is reused
      // for every component, hence the barrier after each reduction.
      typedef cub::BlockReduce<double, 128> BlockReduce;
      __shared__ typename BlockReduce::TempStorage temp_storage;
      r_virial.xx = BlockReduce(temp_storage).Sum(r_virial.xx);
      __syncthreads();
      r_virial.xy = BlockReduce(temp_storage).Sum(r_virial.xy);
      __syncthreads();
      r_virial.xz = BlockReduce(temp_storage).Sum(r_virial.xz);
      __syncthreads();

      r_virial.yx = BlockReduce(temp_storage).Sum(r_virial.yx);
      __syncthreads();
      r_virial.yy = BlockReduce(temp_storage).Sum(r_virial.yy);
      __syncthreads();
      r_virial.yz = BlockReduce(temp_storage).Sum(r_virial.yz);
      __syncthreads();

      r_virial.zx = BlockReduce(temp_storage).Sum(r_virial.zx);
      __syncthreads();
      r_virial.zy = BlockReduce(temp_storage).Sum(r_virial.zy);
      __syncthreads();
      r_virial.zz = BlockReduce(temp_storage).Sum(r_virial.zz);
      __syncthreads();
    }

    if (threadIdx.x == 0) {
      if (T_DOVIRIAL) {
        // thread 0 adds the reduced virial values into device memory
        atomicAdd(&(d_virial->xx), r_virial.xx);
        atomicAdd(&(d_virial->xy), r_virial.xy);
        atomicAdd(&(d_virial->xz), r_virial.xz);

        atomicAdd(&(d_virial->yx), r_virial.yx);
        atomicAdd(&(d_virial->yy), r_virial.yy);
        atomicAdd(&(d_virial->yz), r_virial.yz);

        atomicAdd(&(d_virial->zx), r_virial.zx);
        atomicAdd(&(d_virial->zy), r_virial.zy);
        atomicAdd(&(d_virial->zz), r_virial.zz);
      }
      // Publish this block's contribution before announcing completion
      __threadfence();
      unsigned int value = atomicInc(&d_tbcatomic[0], totaltb);
      isLastBlockDone = (value == (totaltb - 1));
    }

    __syncthreads();

    // Only thread 0 ever sets isLastBlockDone, so only it can enter here
    if (isLastBlockDone) {
      // Thread 0 of the last block will set the host values
      if (threadIdx.x == 0) {
        if (T_DOENERGY) {
          h_resEnergy[0] = energy;    // restraint energy
          h_diffCOM->x = diffCOM.x;   // distance from ref position to COM of group 2
          h_diffCOM->y = diffCOM.y;
          h_diffCOM->z = diffCOM.z;
          h_resForce->x = group_f.x;  // restraint force on group 2
          h_resForce->y = group_f.y;
          h_resForce->z = group_f.z;
        }
        if (T_DOVIRIAL) {
          // Add virial values to host memory.
          // We add because there can be multiple restraint groups.
          h_extVirial->xx += d_virial->xx;
          h_extVirial->xy += d_virial->xy;
          h_extVirial->xz += d_virial->xz;
          h_extVirial->yx += d_virial->yx;
          h_extVirial->yy += d_virial->yy;
          h_extVirial->yz += d_virial->yz;
          h_extVirial->zx += d_virial->zx;
          h_extVirial->zy += d_virial->zy;
          h_extVirial->zz += d_virial->zz;

          // reset the device virial values
          d_virial->xx = 0;
          d_virial->xy = 0;
          d_virial->xz = 0;

          d_virial->yx = 0;
          d_virial->yy = 0;
          d_virial->yz = 0;

          d_virial->zx = 0;
          d_virial->zy = 0;
          d_virial->zz = 0;
        }
        // resets atomic counter
        d_tbcatomic[0] = 0;
        __threadfence();
      }
    }
  }
  else if (T_MGPUON)
  { // the last-block detection is still needed for the T_MGPUON cleanup
    if (threadIdx.x == 0)
    {
      __threadfence();
      unsigned int value = atomicInc(&d_tbcatomic[0], totaltb);
      isLastBlockDone = (value == (totaltb - 1));
    }
  }

  // last block cleans up
  if (T_MGPUON) {
    if (isLastBlockDone) {
      if (threadIdx.x == 0) {
        // zero out for next iteration
        group2COM->x = 0.0;
        group2COM->y = 0.0;
        group2COM->z = 0.0;
        h_diffCOM->x = diffCOM.x;
        h_diffCOM->y = diffCOM.y;
        h_diffCOM->z = diffCOM.z;
        // resets atomic counter
        d_tbcatomic[0] = 0;
        __threadfence();
      }
    }
  }
}
287 
288 
/*! Compute the restraint force, virial, and energy applied to a small
    group 2 (at most 1024 atoms, which is when the host wrapper
    computeGroupRestraint_1Group selects this kernel), due to restraining
    the COM of group 2 to a reference point (h_group1COMRef).

    Without T_MGPUON the kernel computes the COM of group 2 itself by
    reducing mass * unwrapped position over the thread block and scaling
    by inv_group2_mass. With T_MGPUON the COM is instead taken from
    group2COM (scaled by inv_group2_mass), and group2COM is zeroed for the
    next iteration once the block has finished.

    The kernel also reports the distance from the reference point to the
    COM of group 2 through h_diffCOM (when T_DOENERGY is set).

    Template flags:
      T_DOENERGY     - write energy, COM difference, and group force to
                       h_resEnergy / h_diffCOM / h_resForce
      T_DOVIRIAL     - reduce the virial over the block and add it to
                       h_extVirial
      T_USEMAGNITUDE - restrain |COM - ref| to |resCenterVec| instead of
                       restraining the difference vector
      T_MGPUON       - multi-GPU mode (see above)

    Launch notes: one thread per restrained atom. The block-wide
    reductions use cub::BlockReduce<double, 1024>, so the kernel is
    expected to run as a single block of 1024 threads (as launched by the
    host wrapper in this file). d_tbcatomic[0] is only touched in the
    T_MGPUON path, where it is reset to 0 before the kernel ends.

    The energy loop multiplies by the squared distance in steps of two, so
    it evaluates k * d^restraintExp only for even restraintExp. */
template<int T_DOENERGY, int T_DOVIRIAL, int T_USEMAGNITUDE, int T_MGPUON>
__global__ void computeSmallGroupRestraintKernel_1Group(
  const int                 numRestrainedGroup,
  const int                 restraintExp,
  const double              restraintK,
  const double3             resCenterVec,
  const double3             resDirection,
  const double              inv_group2_mass,
  const int*    __restrict  groupAtomsSOAIndex,
  const Lattice             lat,
  const char3*  __restrict  transform,
  const float*  __restrict  mass,
  const double* __restrict  pos_x,
  const double* __restrict  pos_y,
  const double* __restrict  pos_z,
  double*       __restrict  f_normal_x,
  double*       __restrict  f_normal_y,
  double*       __restrict  f_normal_z,
  cudaTensor*   __restrict  h_extVirial,
  double*       __restrict  h_resEnergy,
  double3*      __restrict  h_resForce,
  const double3* __restrict h_group1COMRef,
  double3*      __restrict  group2COM, // device in T_MGPUON case
  double3*      __restrict  h_diffCOM,
  unsigned int* __restrict  d_tbcatomic)
{
  int tIdx = threadIdx.x + blockIdx.x * blockDim.x;
  __shared__ double3 sh_com2;
  int totaltb = gridDim.x;
  bool isLastBlockDone = false;
  double m = 0;
  double energy = 0.0;
  double3 com2 = {0, 0, 0};
  double3 diffCOM = {0, 0, 0};
  double3 group_f = {0, 0, 0};
  double3 pos = {0, 0, 0};
  double3 f = {0, 0, 0};
  cudaTensor r_virial;
  r_virial.xx = 0.0; r_virial.xy = 0.0; r_virial.xz = 0.0;
  r_virial.yx = 0.0; r_virial.yy = 0.0; r_virial.yz = 0.0;
  r_virial.zx = 0.0; r_virial.zy = 0.0; r_virial.zz = 0.0;
  int SOAindex;

  if (tIdx < numRestrainedGroup) {
    // First -> recalculate center of mass.
    SOAindex = groupAtomsSOAIndex[tIdx];

    m = mass[SOAindex]; // Cast from float to double here
    pos.x = pos_x[SOAindex];
    pos.y = pos_y[SOAindex];
    pos.z = pos_z[SOAindex];

    // unwrap the coordinate to calculate COM
    char3 tr = transform[SOAindex];
    pos = lat.reverse_transform(pos, tr);
    if (!T_MGPUON)
    {
      com2.x = pos.x * m;
      com2.y = pos.y * m;
      com2.z = pos.z * m;
    }
  }

  // reduce the (mass * position) values for the thread block;
  // temp_storage is reused, hence the barrier after each reduction
  typedef cub::BlockReduce<double, 1024> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  if (!T_MGPUON) {
    com2.x = BlockReduce(temp_storage).Sum(com2.x);
    __syncthreads();
    com2.y = BlockReduce(temp_storage).Sum(com2.y);
    __syncthreads();
    com2.z = BlockReduce(temp_storage).Sum(com2.z);
  }
  __syncthreads();

  // Thread 0 calculates the COM and shares it through shared memory
  if (threadIdx.x == 0) {
    if (T_MGPUON) {
      sh_com2.x = group2COM->x * inv_group2_mass;
      sh_com2.y = group2COM->y * inv_group2_mass;
      sh_com2.z = group2COM->z * inv_group2_mass;
    }
    else {
      // calculates the current center of mass
      sh_com2.x = com2.x * inv_group2_mass;
      sh_com2.y = com2.y * inv_group2_mass;
      sh_com2.z = com2.z * inv_group2_mass;
    }
  }
  __syncthreads();

  if (tIdx < numRestrainedGroup) {
    // calculate the distance from ref to com2 along specific restraint dimension
    diffCOM.x = (sh_com2.x - h_group1COMRef->x) * resDirection.x;
    diffCOM.y = (sh_com2.y - h_group1COMRef->y) * resDirection.y;
    diffCOM.z = (sh_com2.z - h_group1COMRef->z) * resDirection.z;
    // Calculate the minimum image distance
    diffCOM = lat.delta_from_diff(diffCOM);

    if (T_USEMAGNITUDE) {
      // Calculate the difference from equilibrium restraint distance
      double comVal = sqrt(diffCOM.x*diffCOM.x + diffCOM.y*diffCOM.y + diffCOM.z*diffCOM.z);
      double centerVal = sqrt(resCenterVec.x*resCenterVec.x + resCenterVec.y*resCenterVec.y +
                              resCenterVec.z*resCenterVec.z);

      double distDiff = (comVal - centerVal);
      double distSqDiff = distDiff * distDiff;
      // Guard against a zero-length COM difference: 1/0 would yield inf and
      // force * 0 * inf would inject NaN into the force of every atom.
      // With no defined direction the group force is left at zero.
      double invCOMVal = (comVal > 0.0) ? 1.0 / comVal : 0.0;

      // Calculate energy and force on group of atoms
      if (distSqDiff > 0.0f) { // To avoid numerical error
        // Energy = k * (r - r_eq)^n
        energy = restraintK * distSqDiff;
        for (int n = 2; n < restraintExp; n += 2) {
          energy *= distSqDiff;
        }
        // Force = -k * n * (r - r_eq)^(n-1)
        double force = -energy * restraintExp / distDiff;
        // calculate force along COM difference
        group_f.x = force * diffCOM.x * invCOMVal;
        group_f.y = force * diffCOM.y * invCOMVal;
        group_f.z = force * diffCOM.z * invCOMVal;
      }
    } else {
      // Calculate the difference from equilibrium restraint distance vector
      // along specific restraint dimension
      double3 resDist;
      resDist.x = (diffCOM.x - resCenterVec.x) * resDirection.x;
      resDist.y = (diffCOM.y - resCenterVec.y) * resDirection.y;
      resDist.z = (diffCOM.z - resCenterVec.z) * resDirection.z;
      // Wrap the distance difference (diffCOM - resCenterVec)
      resDist = lat.delta_from_diff(resDist);

      double distSqDiff = resDist.x*resDist.x + resDist.y*resDist.y + resDist.z*resDist.z;

      // Calculate energy and force on group of atoms
      if (distSqDiff > 0.0f) { // To avoid numerical error
        // Energy = k * (r - r_eq)^n
        energy = restraintK * distSqDiff;
        for (int n = 2; n < restraintExp; n += 2) {
          energy *= distSqDiff;
        }
        // Force = -k * n * (r - r_eq)^(n-1) x (r - r_eq)/|r - r_eq|
        double force = -energy * restraintExp / distSqDiff;
        group_f.x = force * resDist.x;
        group_f.y = force * resDist.y;
        group_f.z = force * resDist.z;
      }
    }

    // calculate the force on each atom (mass-weighted share of the group force)
    f.x = group_f.x * m * inv_group2_mass;
    f.y = group_f.y * m * inv_group2_mass;
    f.z = group_f.z * m * inv_group2_mass;
    // apply the bias to each atom in group
    f_normal_x[SOAindex] += f.x;
    f_normal_y[SOAindex] += f.y;
    f_normal_z[SOAindex] += f.z;
    // Virial is based on applied force on each atom
    if (T_DOVIRIAL) {
      // pos was already unwrapped above, as required for the virial
      r_virial.xx = f.x * pos.x;
      r_virial.xy = f.x * pos.y;
      r_virial.xz = f.x * pos.z;
      r_virial.yx = f.y * pos.x;
      r_virial.yy = f.y * pos.y;
      r_virial.yz = f.y * pos.z;
      r_virial.zx = f.z * pos.x;
      r_virial.zy = f.z * pos.y;
      r_virial.zz = f.z * pos.z;
    }
  }
  __syncthreads();

  if (T_DOENERGY || T_DOVIRIAL) {
    if (T_DOVIRIAL) {
      // Reduce virial values in the thread block
      r_virial.xx = BlockReduce(temp_storage).Sum(r_virial.xx);
      __syncthreads();
      r_virial.xy = BlockReduce(temp_storage).Sum(r_virial.xy);
      __syncthreads();
      r_virial.xz = BlockReduce(temp_storage).Sum(r_virial.xz);
      __syncthreads();

      r_virial.yx = BlockReduce(temp_storage).Sum(r_virial.yx);
      __syncthreads();
      r_virial.yy = BlockReduce(temp_storage).Sum(r_virial.yy);
      __syncthreads();
      r_virial.yz = BlockReduce(temp_storage).Sum(r_virial.yz);
      __syncthreads();

      r_virial.zx = BlockReduce(temp_storage).Sum(r_virial.zx);
      __syncthreads();
      r_virial.zy = BlockReduce(temp_storage).Sum(r_virial.zy);
      __syncthreads();
      r_virial.zz = BlockReduce(temp_storage).Sum(r_virial.zz);
      __syncthreads();
    }

    // thread zero updates the restraints energy and force
    if (threadIdx.x == 0) {
      if (T_DOVIRIAL) {
        // Add virial values to host memory.
        // We add because there can be multiple restraint groups.
        h_extVirial->xx += r_virial.xx;
        h_extVirial->xy += r_virial.xy;
        h_extVirial->xz += r_virial.xz;
        h_extVirial->yx += r_virial.yx;
        h_extVirial->yy += r_virial.yy;
        h_extVirial->yz += r_virial.yz;
        h_extVirial->zx += r_virial.zx;
        h_extVirial->zy += r_virial.zy;
        h_extVirial->zz += r_virial.zz;
      }
      if (T_DOENERGY) {
        h_resEnergy[0] = energy;    // restraint energy
        h_diffCOM->x = diffCOM.x;   // distance from ref position to COM of group 2
        h_diffCOM->y = diffCOM.y;
        h_diffCOM->z = diffCOM.z;
        h_resForce->x = group_f.x;  // restraint force on group 2
        h_resForce->y = group_f.y;
        h_resForce->z = group_f.z;
      }
    }
  }

  // The block counter is only needed to decide who zeroes group2COM in
  // multi-GPU mode. Touching it only there keeps d_tbcatomic[0] at 0 on
  // exit in every configuration, instead of leaving a stale count behind
  // in the single-GPU path.
  if (T_MGPUON) {
    if (threadIdx.x == 0) {
      __threadfence();
      unsigned int value = atomicInc(&d_tbcatomic[0], totaltb);
      isLastBlockDone = (value == (totaltb - 1));
      if (isLastBlockDone) {
        // zero out for next iteration
        group2COM->x = 0.0;
        group2COM->y = 0.0;
        group2COM->z = 0.0;
        // resets atomic counter
        d_tbcatomic[0] = 0;
        __threadfence();
      }
    }
  }
}
535 
// Standalone launcher for computeDistCOMKernelMgpu, kept for debugging.
// The launch shape follows the same rule as the restraint launchers in
// this file: 128-thread blocks for groups above 1024 atoms, otherwise
// 1024-thread blocks, with enough blocks to cover numRestrainedGroup.
void computeDistCOM(
  const int    numRestrainedGroup,
  double3*     d_curCM,
  double3**    d_peerCOM,
  const int    numDevices,
  cudaStream_t stream)
{
  int threadsPerBlock = 1024;
  if (numRestrainedGroup > 1024) {
    threadsPerBlock = 128;
  }
  const int numBlocks = (numRestrainedGroup + threadsPerBlock - 1) / threadsPerBlock;

  computeDistCOMKernelMgpu<<<numBlocks, threadsPerBlock, 0, stream>>>(
    d_peerCOM, d_curCM, numDevices);
}
551 
// Debugging aid: launches copyDistCOMKernel with the peer COM table, the
// output buffer h_peerCOM, and this device's index, so the values can be
// inspected. Launch shape: 128-thread blocks for groups above 1024 atoms,
// otherwise 1024-thread blocks.
void copyDistCOM(
  const int    numRestrainedGroup,
  double3**    d_peerCOM,
  double3*     h_peerCOM,
  const int    deviceIndex,
  cudaStream_t stream)
{
  int threadsPerBlock = 1024;
  if (numRestrainedGroup > 1024) {
    threadsPerBlock = 128;
  }
  const int numBlocks = (numRestrainedGroup + threadsPerBlock - 1) / threadsPerBlock;

  copyDistCOMKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
    d_peerCOM, h_peerCOM, deviceIndex);
}
567 
// Standalone launcher for computeCOMKernel, kept for debugging the
// single-GPU version.
//
// Launch shape: 128-thread blocks for groups above 1024 atoms, otherwise
// 1024-thread blocks. The kernel's template argument is chosen to match the
// launched block size, mirroring the <128>/<1024> dispatch that
// computeCOMMgpu uses for computeCOMKernelMgpu; previously <128> was always
// instantiated even when 1024-thread blocks were launched.
// NOTE(review): this assumes the template argument of computeCOMKernel is
// its block size, as the sibling launchers suggest - confirm against
// ComputeCOMCudaKernel.h.
void computeLocalCOM(
  const int           numRestrainedGroup,
  const double        inv_group2_mass,
  const int*          d_groupAtomsSOAIndex,
  const Lattice      &lat,
  const char3*        d_transform,
  const float*        d_mass,
  const double*       d_pos_x,
  const double*       d_pos_y,
  const double*       d_pos_z,
  double3*            d_group2COM,
  double3*            h_group2COM,
  unsigned int*       d_tbcatomic,
  cudaStream_t        stream)
{
  const int blocks = (numRestrainedGroup > 1024) ? 128 : 1024;
  const int grid = (numRestrainedGroup + blocks - 1) / blocks;

  if (numRestrainedGroup > 1024) {
    computeCOMKernel<128><<<grid, blocks, 0, stream>>>(
      numRestrainedGroup,
      inv_group2_mass,
      lat,
      d_mass,
      d_pos_x,
      d_pos_y,
      d_pos_z,
      d_transform,
      d_groupAtomsSOAIndex,
      d_group2COM,
      h_group2COM,
      d_tbcatomic);
  } else {
    computeCOMKernel<1024><<<grid, blocks, 0, stream>>>(
      numRestrainedGroup,
      inv_group2_mass,
      lat,
      d_mass,
      d_pos_x,
      d_pos_y,
      d_pos_z,
      d_transform,
      d_groupAtomsSOAIndex,
      d_group2COM,
      h_group2COM,
      d_tbcatomic);
  }
}
601 
/*! Host entry point: compute the restraint force, energy, and virial
    applied to group 2, due to restraining the COM of group 2 to a
    reference COM position of group 1 (h_group1COMRef).

    Groups with more than 1024 atoms first get their COM computed by a
    separate kernel (computeCOMKernel, or computeDistCOMKernelMgpu when
    mGpuOn) and are then handled by the large-group kernel with 128-thread
    blocks. Smaller groups are handled by the small-group kernel with
    1024-thread blocks; when mGpuOn, computeDistCOMKernelMgpu is launched
    before it.

    All launches are queued on `stream`. The kernel template instantiation
    is selected from doEnergy, doVirial, useMagnitude, and mGpuOn. */
void computeGroupRestraint_1Group(
  const bool          mGpuOn,
  const int           useMagnitude,
  const int           doEnergy,
  const int           doVirial,
  const int           numRestrainedGroup,
  const int           restraintExp,
  const double        restraintK,
  const double3       resCenterVec,
  const double3       resDirection,
  const double        inv_group2_mass,
  const int*          d_groupAtomsSOAIndex,
  const Lattice      &lat,
  const char3*        d_transform,
  const float*        d_mass,
  const double*       d_pos_x,
  const double*       d_pos_y,
  const double*       d_pos_z,
  double*             d_f_normal_x,
  double*             d_f_normal_y,
  double*             d_f_normal_z,
  cudaTensor*         d_virial,
  cudaTensor*         h_extVirial,
  double*             h_resEnergy,
  double3*            h_resForce,
  double3*            h_group1COMRef,
  double3*            h_group2COM,
  double3*            h_diffCOM,
  double3**           d_peer2COM,
  double3*            d_group2COM,
  unsigned int*       d_tbcatomic,
  int                 numDevices,
  cudaStream_t        stream)
{
  // Selector for the template instantiation:
  // bit 0 = energy, bit 1 = virial, bit 2 = magnitude, bit 3 = multi-GPU
  const int kernelSelect =
    doEnergy + (doVirial << 1) + (useMagnitude << 2) + (mGpuOn << 3);

  const bool isLargeGroup = (numRestrainedGroup > 1024);
  const int threadsPerBlock = isLargeGroup ? 128 : 1024;
  const int numBlocks = (numRestrainedGroup + threadsPerBlock - 1) / threadsPerBlock;

  // multi-GPU keeps the COM on the device, not the host
  double3* comBuffer = h_group2COM;
  if (mGpuOn) {
    comBuffer = d_group2COM;
  }

  if (isLargeGroup) {
    // The large-group kernel expects the COM of group 2 to be ready,
    // so compute it first.
    if (mGpuOn) {
      // combine the per-device contributions; only numDevices threads are needed
      computeDistCOMKernelMgpu<<<1, numDevices, 0, stream>>>(
        d_peer2COM, d_group2COM, numDevices);
    } else {
      // single GPU: compute the COM and store it in h_group2COM
      computeCOMKernel<128><<<numBlocks, threadsPerBlock, 0, stream>>>(
        numRestrainedGroup,
        inv_group2_mass,
        lat,
        d_mass,
        d_pos_x,
        d_pos_y,
        d_pos_z,
        d_transform,
        d_groupAtomsSOAIndex,
        d_group2COM,
        h_group2COM,
        d_tbcatomic);
    }

#define LAUNCH_LARGE_GROUP_RES(DOENERGY, DOVIRIAL, USEMAGNITUDE, MGPUON)               \
    computeLargeGroupRestraintKernel_1Group<DOENERGY, DOVIRIAL, USEMAGNITUDE, MGPUON>  \
      <<<numBlocks, threadsPerBlock, 0, stream>>>(                                      \
        numRestrainedGroup, restraintExp, restraintK, resCenterVec,                     \
        resDirection, inv_group2_mass, d_groupAtomsSOAIndex, lat,                       \
        d_transform, d_mass, d_pos_x, d_pos_y, d_pos_z,                                 \
        d_f_normal_x, d_f_normal_y, d_f_normal_z, d_virial,                             \
        h_extVirial, h_resEnergy, h_resForce, h_group1COMRef,                           \
        comBuffer, h_diffCOM, d_tbcatomic)

    // Template arguments are listed lowest bit first (energy, virial,
    // magnitude, multi-GPU).
    switch (kernelSelect) {
      case  0: LAUNCH_LARGE_GROUP_RES(0, 0, 0, 0); break;
      case  1: LAUNCH_LARGE_GROUP_RES(1, 0, 0, 0); break;
      case  2: LAUNCH_LARGE_GROUP_RES(0, 1, 0, 0); break;
      case  3: LAUNCH_LARGE_GROUP_RES(1, 1, 0, 0); break;
      case  4: LAUNCH_LARGE_GROUP_RES(0, 0, 1, 0); break;
      case  5: LAUNCH_LARGE_GROUP_RES(1, 0, 1, 0); break;
      case  6: LAUNCH_LARGE_GROUP_RES(0, 1, 1, 0); break;
      case  7: LAUNCH_LARGE_GROUP_RES(1, 1, 1, 0); break;
      case  8: LAUNCH_LARGE_GROUP_RES(0, 0, 0, 1); break;
      case  9: LAUNCH_LARGE_GROUP_RES(1, 0, 0, 1); break;
      case 10: LAUNCH_LARGE_GROUP_RES(0, 1, 0, 1); break;
      case 11: LAUNCH_LARGE_GROUP_RES(1, 1, 0, 1); break;
      case 12: LAUNCH_LARGE_GROUP_RES(0, 0, 1, 1); break;
      case 13: LAUNCH_LARGE_GROUP_RES(1, 0, 1, 1); break;
      case 14: LAUNCH_LARGE_GROUP_RES(0, 1, 1, 1); break;
      case 15: LAUNCH_LARGE_GROUP_RES(1, 1, 1, 1); break;
    }
#undef LAUNCH_LARGE_GROUP_RES
    return;
  }

  // For a small group of restrained atoms, we can just launch
  // a single thread block.

#define LAUNCH_SMALL_GROUP_RES(DOENERGY, DOVIRIAL, USEMAGNITUDE, MGPUON)               \
  computeSmallGroupRestraintKernel_1Group<DOENERGY, DOVIRIAL, USEMAGNITUDE, MGPUON>    \
    <<<numBlocks, threadsPerBlock, 0, stream>>>(                                        \
      numRestrainedGroup, restraintExp, restraintK, resCenterVec,                       \
      resDirection, inv_group2_mass, d_groupAtomsSOAIndex, lat,                         \
      d_transform, d_mass, d_pos_x, d_pos_y, d_pos_z,                                   \
      d_f_normal_x, d_f_normal_y, d_f_normal_z,                                         \
      h_extVirial, h_resEnergy, h_resForce,                                             \
      h_group1COMRef, comBuffer, h_diffCOM, d_tbcatomic)

  if (mGpuOn) {
    // combine the per-device COM contributions for group 2;
    // only numDevices threads are needed
    computeDistCOMKernelMgpu<<<1, numDevices, 0, stream>>>(
      d_peer2COM, d_group2COM, numDevices);
  }

  // Same ordering of template arguments as for the large-group kernel.
  switch (kernelSelect) {
    case  0: LAUNCH_SMALL_GROUP_RES(0, 0, 0, 0); break;
    case  1: LAUNCH_SMALL_GROUP_RES(1, 0, 0, 0); break;
    case  2: LAUNCH_SMALL_GROUP_RES(0, 1, 0, 0); break;
    case  3: LAUNCH_SMALL_GROUP_RES(1, 1, 0, 0); break;
    case  4: LAUNCH_SMALL_GROUP_RES(0, 0, 1, 0); break;
    case  5: LAUNCH_SMALL_GROUP_RES(1, 0, 1, 0); break;
    case  6: LAUNCH_SMALL_GROUP_RES(0, 1, 1, 0); break;
    case  7: LAUNCH_SMALL_GROUP_RES(1, 1, 1, 0); break;
    case  8: LAUNCH_SMALL_GROUP_RES(0, 0, 0, 1); break;
    case  9: LAUNCH_SMALL_GROUP_RES(1, 0, 0, 1); break;
    case 10: LAUNCH_SMALL_GROUP_RES(0, 1, 0, 1); break;
    case 11: LAUNCH_SMALL_GROUP_RES(1, 1, 0, 1); break;
    case 12: LAUNCH_SMALL_GROUP_RES(0, 0, 1, 1); break;
    case 13: LAUNCH_SMALL_GROUP_RES(1, 0, 1, 1); break;
    case 14: LAUNCH_SMALL_GROUP_RES(0, 1, 1, 1); break;
    case 15: LAUNCH_SMALL_GROUP_RES(1, 1, 1, 1); break;
  }
#undef LAUNCH_SMALL_GROUP_RES
}
// Launches initPeerCOMKernel on `stream` as a single block with one thread
// per device, forwarding the device count, this device's index, the peer
// pointer pool, and this device's COM buffer.
void initPeerCOMmgpuG(
  const int    numDevices,
  const int    deviceIndex,
  double3**    d_peerPool,
  double3*     d_peerCOM,
  cudaStream_t stream)
{
  initPeerCOMKernel<<<1, numDevices, 0, stream>>>(
    numDevices, deviceIndex, d_peerPool, d_peerCOM);
}
756 
/* Called in an earlier phase to handle the multi-device COM.
   Zeroes d_peerCOM, then launches computeCOMKernelMgpu on `stream` with a
   template argument matching the launched block size. */
void computeCOMMgpu(
  const int           numAtoms,
  const Lattice      &lat,
  const float*        d_mass,
  const double*       d_pos_x,
  const double*       d_pos_y,
  const double*       d_pos_z,
  const char3*        d_transform,
  const int*          d_AtomsSOAIndex,
  double3*            d_peerCOM,
  double3**           d_peer_curCM,
  unsigned int*       d_tbcatomic,
  const int           numDevices,
  const int           deviceIndex,
  cudaStream_t        stream)
{
  // Many small blocks for a large group, otherwise everything in one go
  const bool useSmallBlocks = (numAtoms > 1024);
  const int threadsPerBlock = useSmallBlocks ? 128 : 1024;
  const int numBlocks = (numAtoms + threadsPerBlock - 1) / threadsPerBlock;

  // Initialize the device memory to zero here.
  // NOTE(review): the original author left the question "why do we need
  // this here?" at this point; this memset is not issued on `stream`.
  cudaCheck(cudaMemset(d_peerCOM, 0, sizeof(double3)));

  if (!useSmallBlocks) {
    computeCOMKernelMgpu<1024><<<numBlocks, threadsPerBlock, 0, stream>>>(
      numAtoms,
      lat, d_mass,
      d_pos_x, d_pos_y, d_pos_z,
      d_transform,
      d_AtomsSOAIndex,
      d_peer_curCM,
      numDevices,
      deviceIndex,
      d_tbcatomic);
  } else {
    computeCOMKernelMgpu<128><<<numBlocks, threadsPerBlock, 0, stream>>>(
      numAtoms,
      lat, d_mass,
      d_pos_x, d_pos_y, d_pos_z,
      d_transform,
      d_AtomsSOAIndex,
      d_peer_curCM,
      numDevices,
      deviceIndex,
      d_tbcatomic);
  }
}
801 
802 #endif // NODEGROUP_FORCE_REGISTER