NAMD
ComputeGroupRes1GroupCUDAKernel.cu
Go to the documentation of this file.
1 #ifdef NAMD_CUDA
2 #if __CUDACC_VER_MAJOR__ >= 11
3 #include <cub/cub.cuh>
4 #else
5 #include <namd_cub/cub.cuh>
6 #endif
7 #endif
8 
9 #ifdef NAMD_HIP
10 #include <hip/hip_runtime.h>
11 #include <hipcub/hipcub.hpp>
12 #define cub hipcub
13 #endif
14 
15 #include "ComputeGroupRes1GroupCUDAKernel.h"
16 #include "ComputeCOMCudaKernel.h"
17 #include "HipDefines.h"
18 
19 #ifdef NODEGROUP_FORCE_REGISTER
20 
/*! Compute restraint force, virial, and energy applied to a large
    group 2, due to restraining the COM of group 2 (h_group2COM) to a
    reference point (h_group1COMRef).  The host wrapper in this file
    selects this kernel when the group has more than 1024 atoms.

    The COM of group 2 must already have been calculated and must be
    readable through h_group2COM before this kernel runs.
    This kernel also calculates the distance from the reference point
    to the COM of group 2.

    Launch requirements visible in the code:
      - one thread per restrained atom (threads with
        tIdx >= numRestrainedGroup only take part in the reductions);
      - blockDim.x must be 128, because the virial reduction uses
        cub::BlockReduce<double, 128>;
      - d_tbcatomic[0] must be 0 on entry; it counts finished blocks and
        is reset to 0 by the last block.

    Template flags:
      T_DOENERGY      write energy, COM distance and group force through
                      h_resEnergy / h_diffCOM / h_resForce
      T_DOVIRIAL      accumulate the virial into d_virial and add the
                      total to h_extVirial
      T_USEMAGNITUDE  restrain |COM2 - ref| to |resCenterVec| instead of
                      restraining the vector (COM2 - ref) to resCenterVec

    Note: the energy is built by repeated multiplication with the squared
    distance, so it equals k * |d|^restraintExp only for even restraintExp.
    NOTE(review): the h_* pointers are described as host memory by the
    original comments; they must be dereferenceable from the device. */
template<int T_DOENERGY, int T_DOVIRIAL, int T_USEMAGNITUDE>
__global__ void computeLargeGroupRestraintKernel_1Group(
  const int numRestrainedGroup,
  const int restraintExp,
  const double restraintK,
  const double3 resCenterVec,
  const double3 resDirection,
  const double inv_group2_mass,
  const int* __restrict groupAtomsSOAIndex,
  const Lattice lat,
  const char3* __restrict transform,
  const float* __restrict mass,
  const double* __restrict pos_x,
  const double* __restrict pos_y,
  const double* __restrict pos_z,
  double* __restrict f_normal_x,
  double* __restrict f_normal_y,
  double* __restrict f_normal_z,
  cudaTensor* __restrict d_virial,
  cudaTensor* __restrict h_extVirial,
  double* __restrict h_resEnergy,
  double3* __restrict h_resForce,
  const double3* __restrict h_group1COMRef,
  const double3* __restrict h_group2COM,
  double3* __restrict h_diffCOM,
  unsigned int* __restrict d_tbcatomic)
{
  int tIdx = threadIdx.x + blockIdx.x * blockDim.x;
  int totaltb = gridDim.x;
  // Only meaningful in thread 0 of each block (set after the atomic counter)
  bool isLastBlockDone = false;

  int SOAindex;
  double m = 0;
  double energy = 0.0;
  double3 diffCOM = {0, 0, 0};
  double3 group_f = {0, 0, 0};
  double3 pos = {0, 0, 0};
  double3 f = {0, 0, 0};
  // Per-thread virial contribution; stays zero for out-of-range threads so
  // they do not disturb the block reduction below.
  cudaTensor r_virial;
  r_virial.xx = 0.0; r_virial.xy = 0.0; r_virial.xz = 0.0;
  r_virial.yx = 0.0; r_virial.yy = 0.0; r_virial.yz = 0.0;
  r_virial.zx = 0.0; r_virial.zy = 0.0; r_virial.zz = 0.0;

  if (tIdx < numRestrainedGroup) {
    SOAindex = groupAtomsSOAIndex[tIdx];
    m = mass[SOAindex];
    // Calculate distance from ref to com2 along the specific restraint dimension
    diffCOM.x = (h_group2COM->x - h_group1COMRef->x) * resDirection.x;
    diffCOM.y = (h_group2COM->y - h_group1COMRef->y) * resDirection.y;
    diffCOM.z = (h_group2COM->z - h_group1COMRef->z) * resDirection.z;
    // Calculate the minimum image distance
    diffCOM = lat.delta_from_diff(diffCOM);

    if (T_USEMAGNITUDE) {
      // Calculate the difference from equilibrium restraint distance
      double comVal = sqrt(diffCOM.x*diffCOM.x + diffCOM.y*diffCOM.y + diffCOM.z*diffCOM.z);
      double centerVal = sqrt(resCenterVec.x*resCenterVec.x + resCenterVec.y*resCenterVec.y +
                              resCenterVec.z*resCenterVec.z);

      double distDiff = (comVal - centerVal);
      double distSqDiff = distDiff * distDiff;
      // When the COM coincides with the reference point the unit vector
      // diffCOM / |diffCOM| is undefined and 1/comVal would be inf, which
      // turns the force into NaN (0 * inf).  Apply a zero force instead.
      double invCOMVal = (comVal > 0.0) ? (1.0 / comVal) : 0.0;

      // Calculate energy and force on group of atoms
      if (distSqDiff > 0.0) { // To avoid numerical error (division by zero below)
        // Energy = k * (r - r_eq)^n
        energy = restraintK * distSqDiff;
        for (int n = 2; n < restraintExp; n += 2) {
          energy *= distSqDiff;
        }
        // Force = -k * n * (r - r_eq)^(n-1)
        double force = -energy * restraintExp / distDiff;
        // calculate force along COM difference
        group_f.x = force * diffCOM.x * invCOMVal;
        group_f.y = force * diffCOM.y * invCOMVal;
        group_f.z = force * diffCOM.z * invCOMVal;
      }
    } else {
      // Calculate the difference from equilibrium restraint distance vector
      // along the specific restraint dimension
      double3 resDist;
      resDist.x = (diffCOM.x - resCenterVec.x) * resDirection.x;
      resDist.y = (diffCOM.y - resCenterVec.y) * resDirection.y;
      resDist.z = (diffCOM.z - resCenterVec.z) * resDirection.z;
      // Wrap the distance difference (diffCOM - resCenterVec)
      resDist = lat.delta_from_diff(resDist);

      double distSqDiff = resDist.x*resDist.x + resDist.y*resDist.y + resDist.z*resDist.z;

      // Calculate energy and force on group of atoms
      if (distSqDiff > 0.0) { // To avoid numerical error (division by zero below)
        // Energy = k * (r - r_eq)^n
        energy = restraintK * distSqDiff;
        for (int n = 2; n < restraintExp; n += 2) {
          energy *= distSqDiff;
        }
        // Force = -k * n * (r - r_eq)^(n-1) x (r - r_eq)/|r - r_eq|
        double force = -energy * restraintExp / distSqDiff;
        group_f.x = force * resDist.x;
        group_f.y = force * resDist.y;
        group_f.z = force * resDist.z;
      }
    }

    // calculate the force on each atom (mass-weighted share of the group force)
    f.x = group_f.x * m * inv_group2_mass;
    f.y = group_f.y * m * inv_group2_mass;
    f.z = group_f.z * m * inv_group2_mass;
    // apply the bias to each atom in group
    f_normal_x[SOAindex] += f.x;
    f_normal_y[SOAindex] += f.y;
    f_normal_z[SOAindex] += f.z;
    // Virial is based on applied force on each atom
    if (T_DOVIRIAL) {
      // positions must be unwrapped for virial calculation
      pos.x = pos_x[SOAindex];
      pos.y = pos_y[SOAindex];
      pos.z = pos_z[SOAindex];
      char3 tr = transform[SOAindex];
      pos = lat.reverse_transform(pos, tr);
      r_virial.xx = f.x * pos.x;
      r_virial.xy = f.x * pos.y;
      r_virial.xz = f.x * pos.z;
      r_virial.yx = f.y * pos.x;
      r_virial.yy = f.y * pos.y;
      r_virial.yz = f.y * pos.z;
      r_virial.zx = f.z * pos.x;
      r_virial.zy = f.z * pos.y;
      r_virial.zz = f.z * pos.z;
    }
  }
  __syncthreads();

  if (T_DOENERGY || T_DOVIRIAL) {
    if (T_DOVIRIAL) {
      // Reduce virial values in the thread block.  The temp storage is
      // reused for every component, hence the barrier after each Sum().
      typedef cub::BlockReduce<double, 128> BlockReduce;
      __shared__ typename BlockReduce::TempStorage temp_storage;
      r_virial.xx = BlockReduce(temp_storage).Sum(r_virial.xx);
      __syncthreads();
      r_virial.xy = BlockReduce(temp_storage).Sum(r_virial.xy);
      __syncthreads();
      r_virial.xz = BlockReduce(temp_storage).Sum(r_virial.xz);
      __syncthreads();

      r_virial.yx = BlockReduce(temp_storage).Sum(r_virial.yx);
      __syncthreads();
      r_virial.yy = BlockReduce(temp_storage).Sum(r_virial.yy);
      __syncthreads();
      r_virial.yz = BlockReduce(temp_storage).Sum(r_virial.yz);
      __syncthreads();

      r_virial.zx = BlockReduce(temp_storage).Sum(r_virial.zx);
      __syncthreads();
      r_virial.zy = BlockReduce(temp_storage).Sum(r_virial.zy);
      __syncthreads();
      r_virial.zz = BlockReduce(temp_storage).Sum(r_virial.zz);
      __syncthreads();
    }

    if (threadIdx.x == 0) {
      if (T_DOVIRIAL) {
        // thread 0 adds the reduced virial values into device memory
        atomicAdd(&(d_virial->xx), r_virial.xx);
        atomicAdd(&(d_virial->xy), r_virial.xy);
        atomicAdd(&(d_virial->xz), r_virial.xz);

        atomicAdd(&(d_virial->yx), r_virial.yx);
        atomicAdd(&(d_virial->yy), r_virial.yy);
        atomicAdd(&(d_virial->yz), r_virial.yz);

        atomicAdd(&(d_virial->zx), r_virial.zx);
        atomicAdd(&(d_virial->zy), r_virial.zy);
        atomicAdd(&(d_virial->zz), r_virial.zz);
      }
      // Make this block's virial contribution visible before announcing
      // that the block is done.
      __threadfence();
      // Count finished blocks; the block that observes totaltb - 1 is the last
      unsigned int value = atomicInc(&d_tbcatomic[0], totaltb);
      isLastBlockDone = (value == (totaltb -1));
    }

    __syncthreads();

    if (isLastBlockDone) {
      // Thread 0 of the last block will set the host values
      if (threadIdx.x == 0) {
        if (T_DOENERGY) {
          h_resEnergy[0] = energy;    // restraint energy
          h_diffCOM->x = diffCOM.x;   // distance from ref position to COM of group 2
          h_diffCOM->y = diffCOM.y;   // distance from ref position to COM of group 2
          h_diffCOM->z = diffCOM.z;   // distance from ref position to COM of group 2
          h_resForce->x = group_f.x;  // restraint force on group 2
          h_resForce->y = group_f.y;  // restraint force on group 2
          h_resForce->z = group_f.z;  // restraint force on group 2
        }
        if (T_DOVIRIAL) {
          // Add virial values to host memory.
          // We use add, since there may be multiple restraint groups
          h_extVirial->xx += d_virial->xx;
          h_extVirial->xy += d_virial->xy;
          h_extVirial->xz += d_virial->xz;
          h_extVirial->yx += d_virial->yx;
          h_extVirial->yy += d_virial->yy;
          h_extVirial->yz += d_virial->yz;
          h_extVirial->zx += d_virial->zx;
          h_extVirial->zy += d_virial->zy;
          h_extVirial->zz += d_virial->zz;

          // reset the device virial values for the next launch
          d_virial->xx = 0;
          d_virial->xy = 0;
          d_virial->xz = 0;

          d_virial->yx = 0;
          d_virial->yy = 0;
          d_virial->yz = 0;

          d_virial->zx = 0;
          d_virial->zy = 0;
          d_virial->zz = 0;
        }
        // resets atomic counter
        d_tbcatomic[0] = 0;
        __threadfence();
      }
    }
  }
}
255 
256 
/*! Compute restraint force, virial, and energy applied to a small
    group 2, due to restraining the COM of group 2 to a reference
    point (h_group1COMRef).  The host wrapper in this file selects this
    kernel when the group has at most 1024 atoms.

    Unlike the large-group kernel, the COM of group 2 is recomputed
    here from the unwrapped atom positions.
    This kernel also calculates the distance from the reference point
    to the COM of group 2.

    Launch requirements visible in the code:
      - a single thread block (the COM and virial are only reduced
        within a block, and thread 0 of the block writes the results);
      - blockDim.x must be 1024, because the reductions use
        cub::BlockReduce<double, 1024>.

    Template flags:
      T_DOENERGY      write energy, COM distance and group force through
                      h_resEnergy / h_diffCOM / h_resForce
      T_DOVIRIAL      add the reduced virial to h_extVirial
      T_USEMAGNITUDE  restrain |COM2 - ref| to |resCenterVec| instead of
                      restraining the vector (COM2 - ref) to resCenterVec

    Note: the energy is built by repeated multiplication with the squared
    distance, so it equals k * |d|^restraintExp only for even restraintExp.
    NOTE(review): the h_* pointers are described as host memory by the
    original comments; they must be dereferenceable from the device. */
template<int T_DOENERGY, int T_DOVIRIAL, int T_USEMAGNITUDE>
__global__ void computeSmallGroupRestraintKernel_1Group(
  const int numRestrainedGroup,
  const int restraintExp,
  const double restraintK,
  const double3 resCenterVec,
  const double3 resDirection,
  const double inv_group2_mass,
  const int* __restrict groupAtomsSOAIndex,
  const Lattice lat,
  const char3* __restrict transform,
  const float* __restrict mass,
  const double* __restrict pos_x,
  const double* __restrict pos_y,
  const double* __restrict pos_z,
  double* __restrict f_normal_x,
  double* __restrict f_normal_y,
  double* __restrict f_normal_z,
  cudaTensor* __restrict h_extVirial,
  double* __restrict h_resEnergy,
  double3* __restrict h_resForce,
  const double3* __restrict h_group1COMRef,
  double3* __restrict h_diffCOM)
{
  int tIdx = threadIdx.x + blockIdx.x * blockDim.x;
  // COM of group 2, published by thread 0 to the whole block
  __shared__ double3 sh_com2;

  double m = 0;
  double energy = 0.0;
  double3 com2 = {0, 0, 0};
  double3 diffCOM = {0, 0, 0};
  double3 group_f = {0, 0, 0};
  double3 pos = {0, 0, 0};
  double3 f = {0, 0, 0};
  // Per-thread virial contribution; stays zero for out-of-range threads so
  // they do not disturb the block reduction below.
  cudaTensor r_virial;
  r_virial.xx = 0.0; r_virial.xy = 0.0; r_virial.xz = 0.0;
  r_virial.yx = 0.0; r_virial.yy = 0.0; r_virial.yz = 0.0;
  r_virial.zx = 0.0; r_virial.zy = 0.0; r_virial.zz = 0.0;
  int SOAindex;

  if (tIdx < numRestrainedGroup) {
    // First -> recalculate center of mass.
    SOAindex = groupAtomsSOAIndex[tIdx];

    m = mass[SOAindex]; // Cast from float to double here
    pos.x = pos_x[SOAindex];
    pos.y = pos_y[SOAindex];
    pos.z = pos_z[SOAindex];

    // unwrap the coordinate to calculate COM
    char3 tr = transform[SOAindex];
    pos = lat.reverse_transform(pos, tr);

    com2.x = pos.x * m;
    com2.y = pos.y * m;
    com2.z = pos.z * m;
  }

  // reduce the (mass * position) values for the thread block.  The temp
  // storage is reused for every component, hence the barrier after each Sum().
  typedef cub::BlockReduce<double, 1024> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  com2.x = BlockReduce(temp_storage).Sum(com2.x);
  __syncthreads();
  com2.y = BlockReduce(temp_storage).Sum(com2.y);
  __syncthreads();
  com2.z = BlockReduce(temp_storage).Sum(com2.z);
  __syncthreads();

  // Thread 0 calculates the COM
  if (threadIdx.x == 0) {
    sh_com2.x = com2.x * inv_group2_mass; // calculates the current center of mass
    sh_com2.y = com2.y * inv_group2_mass; // calculates the current center of mass
    sh_com2.z = com2.z * inv_group2_mass; // calculates the current center of mass
  }
  __syncthreads();

  if (tIdx < numRestrainedGroup) {
    // calculate the distance from ref to com2 along the specific restraint dimension
    diffCOM.x = (sh_com2.x - h_group1COMRef->x) * resDirection.x;
    diffCOM.y = (sh_com2.y - h_group1COMRef->y) * resDirection.y;
    diffCOM.z = (sh_com2.z - h_group1COMRef->z) * resDirection.z;
    // Calculate the minimum image distance
    diffCOM = lat.delta_from_diff(diffCOM);

    if (T_USEMAGNITUDE) {
      // Calculate the difference from equilibrium restraint distance
      double comVal = sqrt(diffCOM.x*diffCOM.x + diffCOM.y*diffCOM.y + diffCOM.z*diffCOM.z);
      double centerVal = sqrt(resCenterVec.x*resCenterVec.x + resCenterVec.y*resCenterVec.y +
                              resCenterVec.z*resCenterVec.z);

      double distDiff = (comVal - centerVal);
      double distSqDiff = distDiff * distDiff;
      // When the COM coincides with the reference point the unit vector
      // diffCOM / |diffCOM| is undefined and 1/comVal would be inf, which
      // turns the force into NaN (0 * inf).  Apply a zero force instead.
      double invCOMVal = (comVal > 0.0) ? (1.0 / comVal) : 0.0;

      // Calculate energy and force on group of atoms
      if (distSqDiff > 0.0) { // To avoid numerical error (division by zero below)
        // Energy = k * (r - r_eq)^n
        energy = restraintK * distSqDiff;
        for (int n = 2; n < restraintExp; n += 2) {
          energy *= distSqDiff;
        }
        // Force = -k * n * (r - r_eq)^(n-1)
        double force = -energy * restraintExp / distDiff;
        // calculate force along COM difference
        group_f.x = force * diffCOM.x * invCOMVal;
        group_f.y = force * diffCOM.y * invCOMVal;
        group_f.z = force * diffCOM.z * invCOMVal;
      }
    } else {
      // Calculate the difference from equilibrium restraint distance vector
      // along the specific restraint dimension
      double3 resDist;
      resDist.x = (diffCOM.x - resCenterVec.x) * resDirection.x;
      resDist.y = (diffCOM.y - resCenterVec.y) * resDirection.y;
      resDist.z = (diffCOM.z - resCenterVec.z) * resDirection.z;
      // Wrap the distance difference (diffCOM - resCenterVec)
      resDist = lat.delta_from_diff(resDist);

      double distSqDiff = resDist.x*resDist.x + resDist.y*resDist.y + resDist.z*resDist.z;

      // Calculate energy and force on group of atoms
      if (distSqDiff > 0.0) { // To avoid numerical error (division by zero below)
        // Energy = k * (r - r_eq)^n
        energy = restraintK * distSqDiff;
        for (int n = 2; n < restraintExp; n += 2) {
          energy *= distSqDiff;
        }
        // Force = -k * n * (r - r_eq)^(n-1) x (r - r_eq)/|r - r_eq|
        double force = -energy * restraintExp / distSqDiff;
        group_f.x = force * resDist.x;
        group_f.y = force * resDist.y;
        group_f.z = force * resDist.z;
      }
    }

    // calculate the force on each atom (mass-weighted share of the group force)
    f.x = group_f.x * m * inv_group2_mass;
    f.y = group_f.y * m * inv_group2_mass;
    f.z = group_f.z * m * inv_group2_mass;
    // apply the bias to each atom in group
    f_normal_x[SOAindex] += f.x;
    f_normal_y[SOAindex] += f.y;
    f_normal_z[SOAindex] += f.z;
    // Virial is based on applied force on each atom
    if (T_DOVIRIAL) {
      // positions must be unwrapped for virial calculation;
      // pos was already unwrapped above for the COM calculation
      r_virial.xx = f.x * pos.x;
      r_virial.xy = f.x * pos.y;
      r_virial.xz = f.x * pos.z;
      r_virial.yx = f.y * pos.x;
      r_virial.yy = f.y * pos.y;
      r_virial.yz = f.y * pos.z;
      r_virial.zx = f.z * pos.x;
      r_virial.zy = f.z * pos.y;
      r_virial.zz = f.z * pos.z;
    }
  }
  __syncthreads();

  if (T_DOENERGY || T_DOVIRIAL) {
    if (T_DOVIRIAL) {
      // Reduce virial values in the thread block
      r_virial.xx = BlockReduce(temp_storage).Sum(r_virial.xx);
      __syncthreads();
      r_virial.xy = BlockReduce(temp_storage).Sum(r_virial.xy);
      __syncthreads();
      r_virial.xz = BlockReduce(temp_storage).Sum(r_virial.xz);
      __syncthreads();

      r_virial.yx = BlockReduce(temp_storage).Sum(r_virial.yx);
      __syncthreads();
      r_virial.yy = BlockReduce(temp_storage).Sum(r_virial.yy);
      __syncthreads();
      r_virial.yz = BlockReduce(temp_storage).Sum(r_virial.yz);
      __syncthreads();

      r_virial.zx = BlockReduce(temp_storage).Sum(r_virial.zx);
      __syncthreads();
      r_virial.zy = BlockReduce(temp_storage).Sum(r_virial.zy);
      __syncthreads();
      r_virial.zz = BlockReduce(temp_storage).Sum(r_virial.zz);
      __syncthreads();
    }

    // thread zero updates the restraints energy and force
    if (threadIdx.x == 0) {
      if (T_DOVIRIAL) {
        // Add virial values to host memory.
        // We use add, since there may be multiple restraint groups
        h_extVirial->xx += r_virial.xx;
        h_extVirial->xy += r_virial.xy;
        h_extVirial->xz += r_virial.xz;
        h_extVirial->yx += r_virial.yx;
        h_extVirial->yy += r_virial.yy;
        h_extVirial->yz += r_virial.yz;
        h_extVirial->zx += r_virial.zx;
        h_extVirial->zy += r_virial.zy;
        h_extVirial->zz += r_virial.zz;
      }
      if (T_DOENERGY) {
        h_resEnergy[0] = energy;    // restraint energy
        h_diffCOM->x = diffCOM.x;   // distance from ref position to COM of group 2
        h_diffCOM->y = diffCOM.y;   // distance from ref position to COM of group 2
        h_diffCOM->z = diffCOM.z;   // distance from ref position to COM of group 2
        h_resForce->x = group_f.x;  // restraint force on group 2
        h_resForce->y = group_f.y;  // restraint force on group 2
        h_resForce->z = group_f.z;  // restraint force on group 2
      }
    }
  }
}
474 
/*! Compute restraint force, energy, and virial
    applied to group 2, due to restraining COM of
    group 2 to a reference COM position of group 1
    (h_group1COMRef).

    Dispatches on the group size:
      - numRestrainedGroup > 1024: launches computeCOMKernel<128> to
        compute the COM of group 2, then the large-group restraint
        kernel, both with 128-thread blocks on the same stream;
      - otherwise: launches the small-group restraint kernel as a single
        block of 1024 threads.  In this path d_virial, h_group2COM,
        d_group2COM and d_tbcatomic are not used.

    useMagnitude, doEnergy and doVirial are treated as booleans (any
    non-zero value enables the option) and select the kernel template
    instantiation.  All launches are asynchronous on `stream`; this
    function does not synchronize or check for launch errors. */
void computeGroupRestraint_1Group(
  const int useMagnitude,
  const int doEnergy,
  const int doVirial,
  const int numRestrainedGroup,
  const int restraintExp,
  const double restraintK,
  const double3 resCenterVec,
  const double3 resDirection,
  const double inv_group2_mass,
  const int* d_groupAtomsSOAIndex,
  const Lattice &lat,
  const char3* d_transform,
  const float* d_mass,
  const double* d_pos_x,
  const double* d_pos_y,
  const double* d_pos_z,
  double* d_f_normal_x,
  double* d_f_normal_y,
  double* d_f_normal_z,
  cudaTensor* d_virial,
  cudaTensor* h_extVirial,
  double* h_resEnergy,
  double3* h_resForce,
  double3* h_group1COMRef,
  double3* h_group2COM,
  double3* h_diffCOM,
  double3* d_group2COM,
  unsigned int* d_tbcatomic,
  cudaStream_t stream)
{
  // Pack the three flags into a 3-bit selector.  Each flag is normalized to
  // 0/1 first so that any non-zero value maps to the intended instantiation
  // instead of spilling into a neighbouring bit (or outside the switch).
  const int options = (doEnergy ? 1 : 0)
                    | ((doVirial ? 1 : 0) << 1)
                    | ((useMagnitude ? 1 : 0) << 2);

  if (numRestrainedGroup > 1024) {
    // Must stay 128: it matches computeCOMKernel<128> and the
    // cub::BlockReduce<double, 128> used by the large-group kernel.
    const int blocks = 128;
    const int grid = (numRestrainedGroup + blocks - 1) / blocks;
    // first calculate the COM for restraint group and store it in h_group2COM
    computeCOMKernel<128><<<grid, blocks, 0, stream>>>(
      numRestrainedGroup,
      inv_group2_mass,
      lat,
      d_mass,
      d_pos_x,
      d_pos_y,
      d_pos_z,
      d_transform,
      d_groupAtomsSOAIndex,
      d_group2COM,
      h_group2COM,
      d_tbcatomic);

    #define CALL_LARGE_GROUP_RES(DOENERGY, DOVIRIAL, USEMAGNITUDE) \
      computeLargeGroupRestraintKernel_1Group<DOENERGY, DOVIRIAL, USEMAGNITUDE>\
      <<<grid, blocks, 0, stream>>>( \
        numRestrainedGroup, restraintExp, restraintK, resCenterVec, \
        resDirection, inv_group2_mass, d_groupAtomsSOAIndex, lat, \
        d_transform, d_mass, d_pos_x, d_pos_y, d_pos_z, \
        d_f_normal_x, d_f_normal_y, d_f_normal_z, d_virial, \
        h_extVirial, h_resEnergy, h_resForce, h_group1COMRef, \
        h_group2COM, h_diffCOM, d_tbcatomic);

    // bit 0 = energy, bit 1 = virial, bit 2 = magnitude
    switch(options) {
      case 0: CALL_LARGE_GROUP_RES(0, 0, 0); break;
      case 1: CALL_LARGE_GROUP_RES(1, 0, 0); break;
      case 2: CALL_LARGE_GROUP_RES(0, 1, 0); break;
      case 3: CALL_LARGE_GROUP_RES(1, 1, 0); break;
      case 4: CALL_LARGE_GROUP_RES(0, 0, 1); break;
      case 5: CALL_LARGE_GROUP_RES(1, 0, 1); break;
      case 6: CALL_LARGE_GROUP_RES(0, 1, 1); break;
      case 7: CALL_LARGE_GROUP_RES(1, 1, 1); break;
    }

    #undef CALL_LARGE_GROUP_RES

  } else {
    // For small group of restrained atom, we can just launch
    // a single threadblock.  Must stay 1024 to match the
    // cub::BlockReduce<double, 1024> used by the small-group kernel.
    const int blocks = 1024;
    const int grid = 1;

    #define CALL_SMALL_GROUP_RES(DOENERGY, DOVIRIAL, USEMAGNITUDE) \
      computeSmallGroupRestraintKernel_1Group<DOENERGY, DOVIRIAL, USEMAGNITUDE>\
      <<<grid, blocks, 0, stream>>>( \
        numRestrainedGroup, restraintExp, restraintK, resCenterVec, \
        resDirection, inv_group2_mass, d_groupAtomsSOAIndex, lat, \
        d_transform, d_mass, d_pos_x, d_pos_y, d_pos_z, \
        d_f_normal_x, d_f_normal_y, d_f_normal_z, \
        h_extVirial, h_resEnergy, h_resForce, \
        h_group1COMRef, h_diffCOM);

    // bit 0 = energy, bit 1 = virial, bit 2 = magnitude
    switch(options) {
      case 0: CALL_SMALL_GROUP_RES(0, 0, 0); break;
      case 1: CALL_SMALL_GROUP_RES(1, 0, 0); break;
      case 2: CALL_SMALL_GROUP_RES(0, 1, 0); break;
      case 3: CALL_SMALL_GROUP_RES(1, 1, 0); break;
      case 4: CALL_SMALL_GROUP_RES(0, 0, 1); break;
      case 5: CALL_SMALL_GROUP_RES(1, 0, 1); break;
      case 6: CALL_SMALL_GROUP_RES(0, 1, 1); break;
      case 7: CALL_SMALL_GROUP_RES(1, 1, 1); break;
    }

    #undef CALL_SMALL_GROUP_RES
  }

}
585 
586 #endif // NODEGROUP_FORCE_REGISTER