2 #if __CUDACC_VER_MAJOR__ >= 11
5 #include <namd_cub/cub.cuh>
10 #include <hip/hip_runtime.h>
11 #include <hipcub/hipcub.hpp>
15 #include "ComputeGroupRes2GroupCUDAKernel.h"
16 #include "ComputeCOMCudaKernel.h"
17 #include "HipDefines.h"
19 #ifdef NODEGROUP_FORCE_REGISTER
21 /*! Compute restraint force, virial, and energy for the two-group
    center-of-mass (COM) restraint -- multi-block variant.

    computeGroupRestraint_2Group selects this kernel when
    totalNumRestrained > 1024 and launches it with 128-thread blocks and
    ceil(totalNumRestrained / 128) blocks. The cub::BlockReduce<double, 128>
    below requires exactly that block size. A total of exactly 1024
    restrained atoms goes to the single-block kernel instead.

    The restrained quantity is the distance from the COM of group 1
    (group1COM) to the COM of group 2 (group2COM), masked component-wise
    by resDirection and wrapped with lat.delta_from_diff. The force
    (group_f) is computed for group 2. Group 1 receives the opposite
    force. Each atom receives a share proportional to m / M_group, so the
    net restraint force is zero.

    Precondition: group1COM / group2COM must be filled before this kernel
    runs. The launcher calls compute2COMKernel when !mGpuOn and
    computeDistCOMKernelMgpu otherwise.

    Thread layout: one thread per restrained atom, looked up through
    groupAtomsSOAIndex. Threads [0, numRestrainedGroup1) handle group-1
    atoms. Threads [numRestrainedGroup1, totalNumRestrained) handle
    group-2 atoms.

    Template flags:
      T_USEMAGNITUDE - restrain |diffCOM| to |resCenterVec| instead of
                       restraining the vector diffCOM to resCenterVec.
      T_DOENERGY / T_DOVIRIAL - gate the block-level virial reduction.
      T_MGPUON       - multi-GPU path (group1COM/group2COM on device).

    Energy is k * (r - r_eq)^restraintExp only for even restraintExp,
    because the power loop advances in steps of 2.
    NOTE(review): confirm input parsing rejects odd exponents. With an odd
    exponent, energy and force are inconsistent.

    Outputs: per-atom forces are added into f_normal_{x,y,z}. Each block
    atomically adds its reduced virial into d_virial. Thread 0 of the last
    block to finish (detected with atomicInc on d_tbcatomic[0]) writes
    h_resEnergy[0], h_diffCOM and h_resForce, and adds d_virial into
    h_extVirial.

    NOTE(review): this copy of the kernel is missing source lines. The
    names lat, SOAindex, m, energy, r_virial and resDist are used but not
    declared here; the launcher passes `lat` between
    d_groupAtomsSOAIndex and d_transform. Several branch/brace lines are
    also absent. */
29 template<int T_DOENERGY, int T_DOVIRIAL, int T_USEMAGNITUDE, int T_MGPUON>
30 __global__ void computeLargeGroupRestraint2GroupsKernel(
31 const int numRestrainedGroup1,
32 const int totalNumRestrained,
33 const int restraintExp,
34 const double restraintK,
35 const double3 resCenterVec,
36 const double3 resDirection,
37 const double inv_group1_mass,
38 const double inv_group2_mass,
39 const int* __restrict groupAtomsSOAIndex,
41 const char3* __restrict transform,
42 const float* __restrict mass,
43 const double* __restrict pos_x,
44 const double* __restrict pos_y,
45 const double* __restrict pos_z,
46 double* __restrict f_normal_x,
47 double* __restrict f_normal_y,
48 double* __restrict f_normal_z,
49 cudaTensor* __restrict d_virial,
50 cudaTensor* __restrict h_extVirial,
51 double* __restrict h_resEnergy,
52 double3* __restrict h_resForce,
53 double3* __restrict group1COM, // on device if mgpu
54 double3* __restrict group2COM, // on device if mgpu
55 double3* __restrict h_diffCOM,
56 unsigned int* __restrict d_tbcatomic)
// One thread per restrained atom. totaltb is used below to detect the last
// block to finish, via atomicInc on d_tbcatomic[0].
58 int tIdx = threadIdx.x + blockIdx.x * blockDim.x;
59 int totaltb = gridDim.x;
60 bool isLastBlockDone = false;
65 double3 diffCOM = {0, 0, 0};
66 double3 group_f = {0, 0, 0};
67 double3 pos = {0, 0, 0};
68 double3 f = {0, 0, 0};
70 r_virial.xx = 0.0; r_virial.xy = 0.0; r_virial.xz = 0.0;
71 r_virial.yx = 0.0; r_virial.yy = 0.0; r_virial.yz = 0.0;
72 r_virial.zx = 0.0; r_virial.zy = 0.0; r_virial.zz = 0.0;
// r_virial starts at zero. Threads that skip the branch below keep those
// zeros, so they do not perturb the block reductions.
74 if(tIdx < totalNumRestrained) {
75 SOAindex = groupAtomsSOAIndex[tIdx];
77 // Here for consistency with ComputeGroupRes1, we calculate
78 // distance from com1 to com2 along specific restraint dimension,
79 // so force is acting on group 2
// NOTE(review): two diffCOM variants follow. The line selecting between
// them is not visible in this copy. The first rescales group1COM/group2COM
// by the inverse group masses; presumably this is the T_MGPUON path, where
// the pointers hold mass-weighted position sums. The second uses the
// pointers directly as COMs. Confirm.
81 diffCOM.x = (group2COM->x*inv_group2_mass - group1COM->x*inv_group1_mass) * resDirection.x;
82 diffCOM.y = (group2COM->y*inv_group2_mass - group1COM->y*inv_group1_mass) * resDirection.y;
83 diffCOM.z = (group2COM->z*inv_group2_mass - group1COM->z*inv_group1_mass) * resDirection.z;
87 diffCOM.x = (group2COM->x - group1COM->x) * resDirection.x;
88 diffCOM.y = (group2COM->y - group1COM->y) * resDirection.y;
89 diffCOM.z = (group2COM->z - group1COM->z) * resDirection.z;
91 // Calculate the minimum image distance
92 diffCOM = lat.delta_from_diff(diffCOM);
95 // Calculate the difference from equilibrium restraint distance
96 double comVal = sqrt(diffCOM.x*diffCOM.x + diffCOM.y*diffCOM.y + diffCOM.z*diffCOM.z);
97 double centerVal = sqrt(resCenterVec.x*resCenterVec.x + resCenterVec.y*resCenterVec.y +
98 resCenterVec.z*resCenterVec.z);
100 double distDiff = (comVal - centerVal);
101 double distSqDiff = distDiff * distDiff;
// NOTE(review): comVal == 0 (coincident COMs) makes invCOMVal infinite. The
// group_f components below then become 0 * inf = NaN whenever
// centerVal != 0. Consider guarding this case.
102 double invCOMVal = 1.0 / comVal;
104 // Calculate energy and force on group of atoms
105 if(distSqDiff > 0.0f) { // To avoid numerical error
106 // Energy = k * (r - r_eq)^n
107 energy = restraintK * distSqDiff;
// Multiply by distSqDiff for each remaining power of 2 (assumes an even exponent).
108 for (int n = 2; n < restraintExp; n += 2) {
109 energy *= distSqDiff;
111 // Force = -k * n * (r - r_eq)^(n-1)
112 double force = -energy * restraintExp / distDiff;
113 // calculate force along COM difference
114 group_f.x = force * diffCOM.x * invCOMVal;
115 group_f.y = force * diffCOM.y * invCOMVal;
116 group_f.z = force * diffCOM.z * invCOMVal;
119 // Calculate the difference from equilibrium restraint distance vector
120 // along specific restraint dimension
122 resDist.x = (diffCOM.x - resCenterVec.x) * resDirection.x;
123 resDist.y = (diffCOM.y - resCenterVec.y) * resDirection.y;
124 resDist.z = (diffCOM.z - resCenterVec.z) * resDirection.z;
125 // Wrap the distance difference (diffCOM - resCenterVec)
126 resDist = lat.delta_from_diff(resDist);
128 double distSqDiff = resDist.x*resDist.x + resDist.y*resDist.y + resDist.z*resDist.z;
130 // Calculate energy and force on group of atoms
131 if(distSqDiff > 0.0f) { // To avoid numerical error
132 // Energy = k * (r - r_eq)^n
133 energy = restraintK * distSqDiff;
134 for (int n = 2; n < restraintExp; n += 2) {
135 energy *= distSqDiff;
137 // Force = -k * n * (r - r_eq)^(n-1) x (r - r_eq)/|r - r_eq|
// energy * restraintExp / distSqDiff scaled by resDist has magnitude
// k * n * |resDist|^(n-1) and points along -resDist.
138 double force = -energy * restraintExp / distSqDiff;
139 group_f.x = force * resDist.x;
140 group_f.y = force * resDist.y;
141 group_f.z = force * resDist.z;
145 // calculate the force on each atom of the group
// Each atom receives group_f * m / M_group. Group 1 gets the opposite sign,
// so the summed restraint force over both groups is zero.
146 if (tIdx < numRestrainedGroup1) {
147 // threads [0 , numGroup1Atoms) calculate force for group 1
148 // We use negative because force is calculated for group 2
149 f.x = -group_f.x * m * inv_group1_mass;
150 f.y = -group_f.y * m * inv_group1_mass;
151 f.z = -group_f.z * m * inv_group1_mass;
153 // threads [numGroup1Atoms , totalNumRestrained) calculate force for group 2
154 f.x = group_f.x * m * inv_group2_mass;
155 f.y = group_f.y * m * inv_group2_mass;
156 f.z = group_f.z * m * inv_group2_mass;
158 // apply the bias to each atom in group
159 f_normal_x[SOAindex] += f.x;
160 f_normal_y[SOAindex] += f.y;
161 f_normal_z[SOAindex] += f.z;
162 // Virial is based on applied force on each atom
164 // positions must be unwrapped (lat.reverse_transform) for the virial calculation
165 pos.x = pos_x[SOAindex];
166 pos.y = pos_y[SOAindex];
167 pos.z = pos_z[SOAindex];
168 char3 tr = transform[SOAindex];
169 pos = lat.reverse_transform(pos, tr);
// Per-atom virial contribution: outer product of force and unwrapped position.
170 r_virial.xx = f.x * pos.x;
171 r_virial.xy = f.x * pos.y;
172 r_virial.xz = f.x * pos.z;
173 r_virial.yx = f.y * pos.x;
174 r_virial.yy = f.y * pos.y;
175 r_virial.yz = f.y * pos.z;
176 r_virial.zx = f.z * pos.x;
177 r_virial.zy = f.z * pos.y;
178 r_virial.zz = f.z * pos.z;
183 if(T_DOENERGY || T_DOVIRIAL) {
185 // Reduce virial values in the thread block
// The template block size (128) must equal blockDim.x; the launcher uses 128.
186 typedef cub::BlockReduce<double, 128> BlockReduce;
187 __shared__ typename BlockReduce::TempStorage temp_storage;
// NOTE(review): consecutive BlockReduce calls that share temp_storage need a
// __syncthreads() between them (per the CUB documentation). Those lines are
// not visible in this copy; confirm they are present.
// Only thread 0 receives the valid block sum from each Sum() call.
189 r_virial.xx = BlockReduce(temp_storage).Sum(r_virial.xx);
191 r_virial.xy = BlockReduce(temp_storage).Sum(r_virial.xy);
193 r_virial.xz = BlockReduce(temp_storage).Sum(r_virial.xz);
196 r_virial.yx = BlockReduce(temp_storage).Sum(r_virial.yx);
198 r_virial.yy = BlockReduce(temp_storage).Sum(r_virial.yy);
200 r_virial.yz = BlockReduce(temp_storage).Sum(r_virial.yz);
203 r_virial.zx = BlockReduce(temp_storage).Sum(r_virial.zx);
205 r_virial.zy = BlockReduce(temp_storage).Sum(r_virial.zy);
207 r_virial.zz = BlockReduce(temp_storage).Sum(r_virial.zz);
211 if(threadIdx.x == 0) {
213 // thread 0 adds the reduced virial values into device memory
214 atomicAdd(&(d_virial->xx), r_virial.xx);
215 atomicAdd(&(d_virial->xy), r_virial.xy);
216 atomicAdd(&(d_virial->xz), r_virial.xz);
218 atomicAdd(&(d_virial->yx), r_virial.yx);
219 atomicAdd(&(d_virial->yy), r_virial.yy);
220 atomicAdd(&(d_virial->yz), r_virial.yz);
222 atomicAdd(&(d_virial->zx), r_virial.zx);
223 atomicAdd(&(d_virial->zy), r_virial.zy);
224 atomicAdd(&(d_virial->zz), r_virial.zz);
// Count finished blocks. The block that reads back totaltb - 1 is the last one.
// NOTE(review): the last block reads d_virial below, so every block's
// atomicAdds must be visible first. The usual pattern puts a
// __threadfence() before this atomicInc; it is not visible in this copy.
227 unsigned int value = atomicInc(&d_tbcatomic[0], totaltb);
228 isLastBlockDone = (value == (totaltb -1));
233 if(isLastBlockDone) {
234 // Thread 0 of the last block will set the host values
235 if(threadIdx.x == 0) {
237 h_resEnergy[0] = energy; // restraint energy for each group, needed for output
238 h_diffCOM->x = diffCOM.x; // distance between COM of two restrained groups
239 h_diffCOM->y = diffCOM.y; // distance between COM of two restrained groups
240 h_diffCOM->z = diffCOM.z; // distance between COM of two restrained groups
241 h_resForce->x = group_f.x; // restraint force on group 2
242 h_resForce->y = group_f.y; // restraint force on group 2
243 h_resForce->z = group_f.z; // restraint force on group 2
246 // Add virial values to host memory.
247 // Use += because several restraint groups accumulate into the same h_extVirial
248 h_extVirial->xx += d_virial->xx;
249 h_extVirial->xy += d_virial->xy;
250 h_extVirial->xz += d_virial->xz;
251 h_extVirial->yx += d_virial->yx;
252 h_extVirial->yy += d_virial->yy;
253 h_extVirial->yz += d_virial->yz;
254 h_extVirial->zx += d_virial->zx;
255 h_extVirial->zy += d_virial->zy;
256 h_extVirial->zz += d_virial->zz;
258 //reset the device virial value
271 //resets atomic counter
278 {// need lastBlockDone for MGPU
279 if(threadIdx.x == 0){
281 unsigned int value = atomicInc(&d_tbcatomic[0], totaltb);
282 isLastBlockDone = (value == (totaltb -1));
288 if(threadIdx.x == 0){
289 // zero out for next iteration
296 //resets atomic counter
305 /*! Compute restraint force, virial, and energy for the two-group
    center-of-mass (COM) restraint -- single-block variant.

    computeGroupRestraint_2Group selects this kernel when
    totalNumRestrained <= 1024 and launches one block of 1024 threads.
    The cub::BlockReduce<double, 1024> below requires exactly that block
    size.

    Each thread loads one restrained atom through groupAtomsSOAIndex and
    unwraps its position with lat.reverse_transform. The block then
    reduces the mass-weighted positions of:
      - group 1: threads [0, numRestrainedGroup1)
      - group 2: threads [numRestrainedGroup1, totalNumRestrained)
    Thread 0 converts the sums to COMs in shared memory (sh_com1/sh_com2)
    using inv_group1_mass / inv_group2_mass. Every thread then computes
    the same restraint force on group 2 (group_f) and applies its
    mass-weighted share to its own atom (negated for group 1). The
    unwrapped position is reused for the virial.

    T_USEMAGNITUDE restrains |diffCOM| to |resCenterVec| instead of
    restraining the vector diffCOM to resCenterVec.

    Energy is k * (r - r_eq)^restraintExp only for even restraintExp,
    because the power loop advances in steps of 2.
    NOTE(review): confirm odd exponents are rejected upstream.

    Outputs: thread 0 adds the block-reduced virial to h_extVirial and
    writes h_resEnergy[0], h_diffCOM and h_resForce. d_tbcatomic[0] drives
    last-block detection in the MGPU path.

    NOTE(review): this copy is missing source lines. The names lat,
    SOAindex, m, energy, r_virial and resDist are used but not declared
    here, and several branch/brace/barrier lines are absent. */
310 template<int T_DOENERGY, int T_DOVIRIAL, int T_USEMAGNITUDE, int T_MGPUON>
311 __global__ void computeSmallGroupRestraint2GroupsKernel(
312 const int numRestrainedGroup1,
313 const int totalNumRestrained,
314 const int restraintExp,
315 const double restraintK,
316 const double3 resCenterVec,
317 const double3 resDirection,
318 const double inv_group1_mass,
319 const double inv_group2_mass,
320 const int* __restrict groupAtomsSOAIndex,
322 const char3* __restrict transform,
323 const float* __restrict mass,
324 const double* __restrict pos_x,
325 const double* __restrict pos_y,
326 const double* __restrict pos_z,
327 double* __restrict f_normal_x,
328 double* __restrict f_normal_y,
329 double* __restrict f_normal_z,
330 cudaTensor* __restrict h_extVirial,
331 double* __restrict h_resEnergy,
332 double3* __restrict h_resForce,
333 double3* __restrict h_diffCOM,
334 double3* __restrict group1COM, // on device in Multi GPU
335 double3* __restrict group2COM,
336 unsigned int* __restrict d_tbcatomic)
338 int tIdx = threadIdx.x + blockIdx.x * blockDim.x;
// Block-wide COMs written by thread 0 and read by every thread.
339 __shared__ double3 sh_com1;
340 __shared__ double3 sh_com2;
341 bool isLastBlockDone = false;
342 int totaltb = gridDim.x;
345 double3 com1 = {0, 0, 0};
346 double3 com2 = {0, 0, 0};
347 double3 diffCOM = {0, 0, 0};
348 double3 group_f = {0, 0, 0};
349 double3 pos = {0, 0, 0};
350 double3 f = {0, 0, 0};
352 r_virial.xx = 0.0; r_virial.xy = 0.0; r_virial.xz = 0.0;
353 r_virial.yx = 0.0; r_virial.yy = 0.0; r_virial.yz = 0.0;
354 r_virial.zx = 0.0; r_virial.zy = 0.0; r_virial.zz = 0.0;
// NOTE(review): com1/com2 are loaded from group1COM/group2COM here, yet the
// accumulation branch below says they start at zero. The condition that
// separates the two paths is not visible in this copy. Presumably these
// loads are the T_MGPUON path, where the pointers are on the device. Confirm.
358 com1.x = group1COM->x;
359 com1.y = group1COM->y;
360 com1.z = group1COM->z;
361 com2.x = group2COM->x;
362 com2.y = group2COM->y;
363 com2.z = group2COM->z;
365 if(tIdx < totalNumRestrained){
366 // First -> recalculate center of mass.
367 SOAindex = groupAtomsSOAIndex[tIdx];
369 m = mass[SOAindex]; // Cast from float to double here
370 pos.x = pos_x[SOAindex];
371 pos.y = pos_y[SOAindex];
372 pos.z = pos_z[SOAindex];
374 // unwrap the coordinate to calculate COM
375 char3 tr = transform[SOAindex];
376 pos = lat.reverse_transform(pos, tr);
// Accumulate m * pos into com1 (group-1 threads) or com2 (group-2 threads).
// The accumulation statements are not visible in this copy.
378 if (tIdx < numRestrainedGroup1) {
379 // we initialized the com2 to zero
384 // we initialized the com1 to zero
392 // reduce the (mass * position) values for group 1 and 2 in the thread block
// The template block size (1024) must equal blockDim.x; the launcher uses 1024.
393 typedef cub::BlockReduce<double, 1024> BlockReduce;
394 __shared__ typename BlockReduce::TempStorage temp_storage;
// NOTE(review): reusing temp_storage across consecutive BlockReduce calls
// requires a __syncthreads() between them (CUB documentation). Those lines
// are not visible in this copy; confirm they are present.
396 com1.x = BlockReduce(temp_storage).Sum(com1.x);
398 com1.y = BlockReduce(temp_storage).Sum(com1.y);
400 com1.z = BlockReduce(temp_storage).Sum(com1.z);
402 com2.x = BlockReduce(temp_storage).Sum(com2.x);
404 com2.y = BlockReduce(temp_storage).Sum(com2.y);
406 com2.z = BlockReduce(temp_storage).Sum(com2.z);
409 // Thread 0 calculates the COM of group 1 and 2
// Only thread 0 holds valid block sums after Sum(), hence the single writer.
410 if(threadIdx.x == 0){
411 sh_com1.x = com1.x * inv_group1_mass; // calculates the COM of group 1
412 sh_com1.y = com1.y * inv_group1_mass; // calculates the COM of group 1
413 sh_com1.z = com1.z * inv_group1_mass; // calculates the COM of group 1
414 sh_com2.x = com2.x * inv_group2_mass; // calculates the COM of group 2
415 sh_com2.y = com2.y * inv_group2_mass; // calculates the COM of group 2
416 sh_com2.z = com2.z * inv_group2_mass; // calculates the COM of group 2
// NOTE(review): every thread reads sh_com1/sh_com2 below, so a
// __syncthreads() must separate thread 0's writes from those reads. It is
// not visible in this copy; confirm it is present.
420 if(tIdx < totalNumRestrained) {
421 // Here for consistency with distanceZ, we calculate
422 // distance from com1 to com2 along specific restraint dimension,
423 // so force is acting on group 2
424 diffCOM.x = (sh_com2.x - sh_com1.x) * resDirection.x;
425 diffCOM.y = (sh_com2.y - sh_com1.y) * resDirection.y;
426 diffCOM.z = (sh_com2.z - sh_com1.z) * resDirection.z;
427 // Calculate the minimum image distance
428 diffCOM = lat.delta_from_diff(diffCOM);
430 if (T_USEMAGNITUDE) {
431 // Calculate the difference from equilibrium restraint distance
432 double comVal = sqrt(diffCOM.x*diffCOM.x + diffCOM.y*diffCOM.y + diffCOM.z*diffCOM.z);
433 double centerVal = sqrt(resCenterVec.x*resCenterVec.x + resCenterVec.y*resCenterVec.y +
434 resCenterVec.z*resCenterVec.z);
436 double distDiff = (comVal - centerVal);
437 double distSqDiff = distDiff * distDiff;
// NOTE(review): comVal == 0 (coincident COMs) makes invCOMVal infinite. The
// group_f components below then become NaN whenever centerVal != 0.
// Consider guarding this case.
438 double invCOMVal = 1.0 / comVal;
440 // Calculate energy and force on group of atoms
441 if(distSqDiff > 0.0f) { // To avoid numerical error
442 // Energy = k * (r - r_eq)^n
443 energy = restraintK * distSqDiff;
// Multiply by distSqDiff for each remaining power of 2 (assumes an even exponent).
444 for (int n = 2; n < restraintExp; n += 2) {
445 energy *= distSqDiff;
447 // Force = -k * n * (r - r_eq)^(n-1)
448 double force = -energy * restraintExp / distDiff;
449 // calculate force along COM difference
450 group_f.x = force * diffCOM.x * invCOMVal;
451 group_f.y = force * diffCOM.y * invCOMVal;
452 group_f.z = force * diffCOM.z * invCOMVal;
455 // Calculate the difference from equilibrium restraint distance vector
456 // along specific restraint dimension
458 resDist.x = (diffCOM.x - resCenterVec.x) * resDirection.x;
459 resDist.y = (diffCOM.y - resCenterVec.y) * resDirection.y;
460 resDist.z = (diffCOM.z - resCenterVec.z) * resDirection.z;
461 // Wrap the distance difference (diffCOM - resCenterVec)
462 resDist = lat.delta_from_diff(resDist);
464 double distSqDiff = resDist.x*resDist.x + resDist.y*resDist.y + resDist.z*resDist.z;
466 // Calculate energy and force on group of atoms
467 if(distSqDiff > 0.0f) { // To avoid numerical error
468 // Energy = k * (r - r_eq)^n
469 energy = restraintK * distSqDiff;
470 for (int n = 2; n < restraintExp; n += 2) {
471 energy *= distSqDiff;
473 // Force = -k * n * (r - r_eq)^(n-1) x (r - r_eq)/|r - r_eq|
474 double force = -energy * restraintExp / distSqDiff;
475 group_f.x = force * resDist.x;
476 group_f.y = force * resDist.y;
477 group_f.z = force * resDist.z;
481 // calculate the force on each atom of the group
// Each atom receives group_f * m / M_group. Group 1 gets the opposite sign,
// so the summed restraint force over both groups is zero.
482 if (tIdx < numRestrainedGroup1) {
483 // threads [0 , numGroup1Atoms) calculate force for group 1
484 // We use negative because force is calculated for group 2
485 f.x = -group_f.x * m * inv_group1_mass;
486 f.y = -group_f.y * m * inv_group1_mass;
487 f.z = -group_f.z * m * inv_group1_mass;
489 // threads [numGroup1Atoms , totalNumRestrained) calculate force for group 2
490 f.x = group_f.x * m * inv_group2_mass;
491 f.y = group_f.y * m * inv_group2_mass;
492 f.z = group_f.z * m * inv_group2_mass;
495 // apply the bias to each atom in group
496 f_normal_x[SOAindex] += f.x ;
497 f_normal_y[SOAindex] += f.y ;
498 f_normal_z[SOAindex] += f.z ;
499 // Virial is based on applied force on each atom
501 // pos was already unwrapped with lat.reverse_transform above, as the virial requires
502 r_virial.xx = f.x * pos.x;
503 r_virial.xy = f.x * pos.y;
504 r_virial.xz = f.x * pos.z;
505 r_virial.yx = f.y * pos.x;
506 r_virial.yy = f.y * pos.y;
507 r_virial.yz = f.y * pos.z;
508 r_virial.zx = f.z * pos.x;
509 r_virial.zy = f.z * pos.y;
510 r_virial.zz = f.z * pos.z;
515 if(T_DOENERGY || T_DOVIRIAL) {
517 // Reduce virial values in the thread block
518 r_virial.xx = BlockReduce(temp_storage).Sum(r_virial.xx);
520 r_virial.xy = BlockReduce(temp_storage).Sum(r_virial.xy);
522 r_virial.xz = BlockReduce(temp_storage).Sum(r_virial.xz);
525 r_virial.yx = BlockReduce(temp_storage).Sum(r_virial.yx);
527 r_virial.yy = BlockReduce(temp_storage).Sum(r_virial.yy);
529 r_virial.yz = BlockReduce(temp_storage).Sum(r_virial.yz);
532 r_virial.zx = BlockReduce(temp_storage).Sum(r_virial.zx);
534 r_virial.zy = BlockReduce(temp_storage).Sum(r_virial.zy);
536 r_virial.zz = BlockReduce(temp_storage).Sum(r_virial.zz);
540 // thread zero updates the restraints energy and force
541 if(threadIdx.x == 0){
543 // Add virial values to host memory.
544 // Use += because several restraint groups accumulate into the same h_extVirial
545 h_extVirial->xx += r_virial.xx;
546 h_extVirial->xy += r_virial.xy;
547 h_extVirial->xz += r_virial.xz;
548 h_extVirial->yx += r_virial.yx;
549 h_extVirial->yy += r_virial.yy;
550 h_extVirial->yz += r_virial.yz;
551 h_extVirial->zx += r_virial.zx;
552 h_extVirial->zy += r_virial.zy;
553 h_extVirial->zz += r_virial.zz;
556 h_resEnergy[0] = energy; // restraint energy for each group, needed for output
557 h_diffCOM->x = diffCOM.x; // distance between two COM of restrained groups
558 h_diffCOM->y = diffCOM.y; // distance between two COM of restrained groups
559 h_diffCOM->z = diffCOM.z; // distance between two COM of restrained groups
560 h_resForce->x = group_f.x; // restraint force on group
561 h_resForce->y = group_f.y; // restraint force on group
562 h_resForce->z = group_f.z; // restraint force on group
567 {// need lastBlockDone for T_MGPUON
568 if(threadIdx.x == 0){
// With a single-block launch (totaltb == 1), value is 0 and this block is the last one.
570 unsigned int value = atomicInc(&d_tbcatomic[0], totaltb);
571 isLastBlockDone = (value == (totaltb -1));
578 { // zero out for next iteration
586 //resets atomic counter
593 /*! Host launcher for the two-group center-of-mass (COM) restraint.

    Computes restraint force, energy, and virial on both restrained groups.
    The force is defined for group 2, and group 1 receives the opposite
    force. All launches are enqueued on `stream`.

    Kernel selection:
      - totalNumRestrained > 1024:
          * The COMs are computed first: compute2COMKernel<128> when
            !mGpuOn, otherwise computeDistCOMKernelMgpu over the
            numDevices peer COM buffers.
          * computeLargeGroupRestraint2GroupsKernel then runs with
            128-thread blocks and ceil(totalNumRestrained / 128) blocks.
      - otherwise (including exactly 1024 atoms):
          * One 1024-thread block of
            computeSmallGroupRestraint2GroupsKernel runs; it recalculates
            the COMs inside the block.

    The template variant comes from `options`:
      - bit 0 = doEnergy
      - bit 1 = doVirial
      - bit 2 = useMagnitude
      - cases 8-15 select MGPUON = 1
    COM1Ptr/COM2Ptr point to d_group{1,2}COM when mGpuOn and to
    h_group{1,2}COM otherwise.

    NOTE(review):
      - No launch-error check (e.g. cudaGetLastError()) is visible after
        these launches.
      - This copy is missing several lines. These include parameter
        declarations for names used below (doEnergy, doVirial, mGpuOn,
        lat, d_mass, stream, h_resEnergy, h_resForce, h_diffCOM), the
        term adding mGpuOn to `options`, and the `switch (options)`
        headers. */
596 void computeGroupRestraint_2Group(
598 const int useMagnitude,
601 const int numRestrainedGroup1,
602 const int totalNumRestrained,
603 const int restraintExp,
604 const double restraintK,
605 const double3 resCenterVec,
606 const double3 resDirection,
607 const double inv_group1_mass,
608 const double inv_group2_mass,
609 const int* d_groupAtomsSOAIndex,
611 const char3* d_transform,
613 const double* d_pos_x,
614 const double* d_pos_y,
615 const double* d_pos_z,
616 double* d_f_normal_x,
617 double* d_f_normal_y,
618 double* d_f_normal_z,
619 cudaTensor* d_virial,
620 cudaTensor* h_extVirial,
623 double3* h_group1COM,
624 double3* h_group2COM,
626 double3* d_group1COM,
627 double3* d_group2COM,
628 double3** d_peer1COM,
629 double3** d_peer2COM,
630 unsigned int* d_tbcatomic,
631 const int numDevices,
// Pack the runtime flags into a case index for the template dispatch below:
// bit 0 = doEnergy, bit 1 = doVirial, bit 2 = useMagnitude. Cases 8-15
// select MGPUON = 1.
634 int options = doEnergy + (doVirial << 1) + (useMagnitude << 2)
636 double3* COM1Ptr=(mGpuOn) ? d_group1COM: h_group1COM;
637 double3* COM2Ptr=(mGpuOn) ? d_group2COM: h_group2COM;
// More than 1024 restrained atoms: many 128-thread blocks, matching
// BlockReduce<double, 128> in the large kernel. Otherwise: a single
// 1024-thread block, matching BlockReduce<double, 1024> in the small kernel.
638 const int blocks = (totalNumRestrained > 1024) ? 128 : 1024;
639 const int grid = (totalNumRestrained > 1024) ? (totalNumRestrained + blocks - 1) / blocks : 1;
640 if (totalNumRestrained > 1024) {
641 // first calculate the COM for restraint groups and store it in
642 // h_group1COM and h_group2COM
643 if(!mGpuOn) // if we don't have distributed COM
645 compute2COMKernel<128><<<grid, blocks, 0, stream>>>(
656 d_groupAtomsSOAIndex,
// Multi-GPU: gather the per-device COM partials from the peer buffers.
664 computeDistCOMKernelMgpu<<<1, numDevices, 0, stream>>>(d_peer1COM,
667 computeDistCOMKernelMgpu<<<1, numDevices, 0, stream>>>(d_peer2COM,
// Expands to one launch of the large-group kernel for a given template variant.
671 #define CALL_LARGE_GROUP_RES(DOENERGY, DOVIRIAL, USEMAGNITUDE, MGPUON) \
672 computeLargeGroupRestraint2GroupsKernel<DOENERGY, DOVIRIAL, USEMAGNITUDE, MGPUON> \
673 <<<grid, blocks, 0, stream>>>( \
674 numRestrainedGroup1, totalNumRestrained, \
675 restraintExp, restraintK, resCenterVec, resDirection, \
676 inv_group1_mass, inv_group2_mass, d_groupAtomsSOAIndex, \
677 lat, d_transform, d_mass, d_pos_x, d_pos_y, d_pos_z, \
678 d_f_normal_x, d_f_normal_y, d_f_normal_z, d_virial, \
679 h_extVirial, h_resEnergy, h_resForce, COM1Ptr, \
680 COM2Ptr, h_diffCOM, d_tbcatomic);
682 case 0: CALL_LARGE_GROUP_RES(0, 0, 0, 0); break;
683 case 1: CALL_LARGE_GROUP_RES(1, 0, 0, 0); break;
684 case 2: CALL_LARGE_GROUP_RES(0, 1, 0, 0); break;
685 case 3: CALL_LARGE_GROUP_RES(1, 1, 0, 0); break;
686 case 4: CALL_LARGE_GROUP_RES(0, 0, 1, 0); break;
687 case 5: CALL_LARGE_GROUP_RES(1, 0, 1, 0); break;
688 case 6: CALL_LARGE_GROUP_RES(0, 1, 1, 0); break;
689 case 7: CALL_LARGE_GROUP_RES(1, 1, 1, 0); break;
690 case 8: CALL_LARGE_GROUP_RES(0, 0, 0, 1); break;
691 case 9: CALL_LARGE_GROUP_RES(1, 0, 0, 1); break;
692 case 10: CALL_LARGE_GROUP_RES(0, 1, 0, 1); break;
693 case 11: CALL_LARGE_GROUP_RES(1, 1, 0, 1); break;
694 case 12: CALL_LARGE_GROUP_RES(0, 0, 1, 1); break;
695 case 13: CALL_LARGE_GROUP_RES(1, 0, 1, 1); break;
696 case 14: CALL_LARGE_GROUP_RES(0, 1, 1, 1); break;
697 case 15: CALL_LARGE_GROUP_RES(1, 1, 1, 1); break;
699 #undef CALL_LARGE_GROUP_RES
701 // For small group of restrained atom, we can just launch
702 // a single threadblock
// NOTE(review): the guard for these peer-COM gathers is not visible in this
// copy; presumably they run only when mGpuOn.
704 computeDistCOMKernelMgpu<<<1, numDevices, 0, stream>>>(d_peer1COM,
707 computeDistCOMKernelMgpu<<<1, numDevices, 0, stream>>>(d_peer2COM,
// Expands to one launch of the small-group kernel for a given template variant.
711 #define CALL_SMALL_GROUP_RES(DOENERGY, DOVIRIAL, USEMAGNITUDE, MGPUON) \
712 computeSmallGroupRestraint2GroupsKernel<DOENERGY, DOVIRIAL, USEMAGNITUDE, MGPUON> \
713 <<<grid, blocks, 0, stream>>>( \
714 numRestrainedGroup1, totalNumRestrained, \
715 restraintExp, restraintK, resCenterVec, resDirection, \
716 inv_group1_mass, inv_group2_mass, d_groupAtomsSOAIndex, \
717 lat, d_transform, d_mass, d_pos_x, d_pos_y, d_pos_z, \
718 d_f_normal_x, d_f_normal_y, d_f_normal_z, \
719 h_extVirial, h_resEnergy, h_resForce, h_diffCOM, \
720 COM1Ptr, COM2Ptr, d_tbcatomic \
723 case 0: CALL_SMALL_GROUP_RES(0, 0, 0, 0); break;
724 case 1: CALL_SMALL_GROUP_RES(1, 0, 0, 0); break;
725 case 2: CALL_SMALL_GROUP_RES(0, 1, 0, 0); break;
726 case 3: CALL_SMALL_GROUP_RES(1, 1, 0, 0); break;
727 case 4: CALL_SMALL_GROUP_RES(0, 0, 1, 0); break;
728 case 5: CALL_SMALL_GROUP_RES(1, 0, 1, 0); break;
729 case 6: CALL_SMALL_GROUP_RES(0, 1, 1, 0); break;
730 case 7: CALL_SMALL_GROUP_RES(1, 1, 1, 0); break;
731 case 8: CALL_SMALL_GROUP_RES(0, 0, 0, 1); break;
732 case 9: CALL_SMALL_GROUP_RES(1, 0, 0, 1); break;
733 case 10: CALL_SMALL_GROUP_RES(0, 1, 0, 1); break;
734 case 11: CALL_SMALL_GROUP_RES(1, 1, 0, 1); break;
735 case 12: CALL_SMALL_GROUP_RES(0, 0, 1, 1); break;
736 case 13: CALL_SMALL_GROUP_RES(1, 0, 1, 1); break;
737 case 14: CALL_SMALL_GROUP_RES(0, 1, 1, 1); break;
738 case 15: CALL_SMALL_GROUP_RES(1, 1, 1, 1); break;
740 #undef CALL_SMALL_GROUP_RES
745 #endif // NODEGROUP_FORCE_REGISTER