2 #if __CUDACC_VER_MAJOR__ >= 11
5 #include <namd_cub/cub.cuh>
10 #include <hip/hip_runtime.h>
11 #include <hipcub/hipcub.hpp>
15 #include "ComputeGroupRes1GroupCUDAKernel.h"
16 #include "ComputeCOMCudaKernel.h"
17 #include "HipDefines.h"
19 #ifdef NODEGROUP_FORCE_REGISTER
21 /*! Compute restraint force, virial, and energy applied to a large
22 group 2 (the launcher in this file appears to use it when numRestrainedGroup > 1024),
23 due to restraining the COM of group 2 (group2COM) to a reference point (h_group1COMRef).
24 To use this function, the COM data of group 2 must be calculated beforehand
25 and passed to this function as group2COM. Reductions use cub::BlockReduce<double, 128>.
26 This function also calculates the distance from the ref point to the
27 COM of group 2 and writes it to h_diffCOM. */
28 template<int T_DOENERGY, int T_DOVIRIAL, int T_USEMAGNITUDE, int T_MGPUON>
29 __global__ void computeLargeGroupRestraintKernel_1Group(
30 const int numRestrainedGroup, // number of entries of groupAtomsSOAIndex to process
31 const int restraintExp,
32 const double restraintK,
33 const double3 resCenterVec,
34 const double3 resDirection,
35 const double inv_group2_mass,
36 const int* __restrict groupAtomsSOAIndex,
38 const char3* __restrict transform,
39 const float* __restrict mass,
40 const double* __restrict pos_x,
41 const double* __restrict pos_y,
42 const double* __restrict pos_z,
43 double* __restrict f_normal_x,
44 double* __restrict f_normal_y,
45 double* __restrict f_normal_z,
46 cudaTensor* __restrict d_virial,
47 cudaTensor* __restrict h_extVirial,
48 double* __restrict h_resEnergy,
49 double3* __restrict h_resForce,
50 const double3* __restrict h_group1COMRef,
51 double3* __restrict group2COM,
52 double3* __restrict h_diffCOM,
53 unsigned int* __restrict d_tbcatomic) // counter bumped with atomicInc to detect the last block
55 int tIdx = threadIdx.x + blockIdx.x * blockDim.x;
56 int totaltb = gridDim.x;
57 bool isLastBlockDone = false;
62 double3 diffCOM = {0, 0, 0};
63 double3 group_f = {0, 0, 0};
64 double3 pos = {0, 0, 0};
65 double3 f = {0, 0, 0};
66 double3 com = {group2COM->x,group2COM->y,group2COM->z};
68 r_virial.xx = 0.0; r_virial.xy = 0.0; r_virial.xz = 0.0;
69 r_virial.yx = 0.0; r_virial.yy = 0.0; r_virial.yz = 0.0;
70 r_virial.zx = 0.0; r_virial.zy = 0.0; r_virial.zz = 0.0;
73 com.x *= inv_group2_mass;
74 com.y *= inv_group2_mass;
75 com.z *= inv_group2_mass;
77 if(tIdx < numRestrainedGroup) {
78 SOAindex = groupAtomsSOAIndex[tIdx];
80 // Calculate distance from ref to com2 along the specific restraint dimension
81 diffCOM.x = (com.x - h_group1COMRef->x) * resDirection.x;
82 diffCOM.y = (com.y - h_group1COMRef->y) * resDirection.y;
83 diffCOM.z = (com.z - h_group1COMRef->z) * resDirection.z;
84 // Calculate the minimum image distance
85 diffCOM = lat.delta_from_diff(diffCOM);
88 // Calculate the difference from equilibrium restraint distance
89 double comVal = sqrt(diffCOM.x*diffCOM.x + diffCOM.y*diffCOM.y + diffCOM.z*diffCOM.z);
90 double centerVal = sqrt(resCenterVec.x*resCenterVec.x + resCenterVec.y*resCenterVec.y +
91 resCenterVec.z*resCenterVec.z);
93 double distDiff = (comVal - centerVal);
94 double distSqDiff = distDiff * distDiff;
95 double invCOMVal = 1.0 / comVal;
97 // Calculate energy and force on group of atoms
98 if(distSqDiff > 0.0f) { // To avoid numerical error: the force below divides by distDiff
99 // Energy = k * (r - r_eq)^n; accumulated in squared-distance factors, so exact for even n
100 energy = restraintK * distSqDiff;
101 for (int n = 2; n < restraintExp; n += 2) {
102 energy *= distSqDiff;
104 // Force = -k * n * (r - r_eq)^(n-1)
105 double force = -energy * restraintExp / distDiff;
106 // calculate force along COM difference (unit vector diffCOM / |diffCOM|)
107 group_f.x = force * diffCOM.x * invCOMVal;
108 group_f.y = force * diffCOM.y * invCOMVal;
109 group_f.z = force * diffCOM.z * invCOMVal;
112 // Calculate the difference from equilibrium restraint distance vector
113 // along the specific restraint dimension
115 resDist.x = (diffCOM.x - resCenterVec.x) * resDirection.x;
116 resDist.y = (diffCOM.y - resCenterVec.y) * resDirection.y;
117 resDist.z = (diffCOM.z - resCenterVec.z) * resDirection.z;
118 // Wrap the distance difference (diffCOM - resCenterVec)
119 resDist = lat.delta_from_diff(resDist);
121 double distSqDiff = resDist.x*resDist.x + resDist.y*resDist.y + resDist.z*resDist.z;
123 // Calculate energy and force on group of atoms
124 if(distSqDiff > 0.0f) { // To avoid numerical error: the force below divides by distSqDiff
125 // Energy = k * (r - r_eq)^n; accumulated in squared-distance factors, so exact for even n
126 energy = restraintK * distSqDiff;
127 for (int n = 2; n < restraintExp; n += 2) {
128 energy *= distSqDiff;
130 // Force = -k * n * (r - r_eq)^(n-1) x (r - r_eq)/|r - r_eq|
131 double force = -energy * restraintExp / distSqDiff;
132 group_f.x = force * resDist.x;
133 group_f.y = force * resDist.y;
134 group_f.z = force * resDist.z;
138 // calculate the force on each atom: group force scaled by the atom's mass fraction
139 f.x = group_f.x * m * inv_group2_mass;
140 f.y = group_f.y * m * inv_group2_mass;
141 f.z = group_f.z * m * inv_group2_mass;
142 // apply the bias to each atom in group
143 f_normal_x[SOAindex] += f.x;
144 f_normal_y[SOAindex] += f.y;
145 f_normal_z[SOAindex] += f.z;
146 // Virial is based on applied force on each atom
148 // positions must be unwrapped for virial calculation
149 pos.x = pos_x[SOAindex];
150 pos.y = pos_y[SOAindex];
151 pos.z = pos_z[SOAindex];
152 char3 tr = transform[SOAindex];
153 pos = lat.reverse_transform(pos, tr);
154 r_virial.xx = f.x * pos.x;
155 r_virial.xy = f.x * pos.y;
156 r_virial.xz = f.x * pos.z;
157 r_virial.yx = f.y * pos.x;
158 r_virial.yy = f.y * pos.y;
159 r_virial.yz = f.y * pos.z;
160 r_virial.zx = f.z * pos.x;
161 r_virial.zy = f.z * pos.y;
162 r_virial.zz = f.z * pos.z;
167 if(T_DOENERGY || T_DOVIRIAL) {
169 // Reduce virial values in the thread block
170 typedef cub::BlockReduce<double, 128> BlockReduce; // sized for 128 threads per block
171 __shared__ typename BlockReduce::TempStorage temp_storage;
172 r_virial.xx = BlockReduce(temp_storage).Sum(r_virial.xx);
174 r_virial.xy = BlockReduce(temp_storage).Sum(r_virial.xy);
176 r_virial.xz = BlockReduce(temp_storage).Sum(r_virial.xz);
179 r_virial.yx = BlockReduce(temp_storage).Sum(r_virial.yx);
181 r_virial.yy = BlockReduce(temp_storage).Sum(r_virial.yy);
183 r_virial.yz = BlockReduce(temp_storage).Sum(r_virial.yz);
186 r_virial.zx = BlockReduce(temp_storage).Sum(r_virial.zx);
188 r_virial.zy = BlockReduce(temp_storage).Sum(r_virial.zy);
190 r_virial.zz = BlockReduce(temp_storage).Sum(r_virial.zz);
194 if(threadIdx.x == 0) {
196 // thread 0 adds the block's reduced virial values into device memory
197 atomicAdd(&(d_virial->xx), r_virial.xx);
198 atomicAdd(&(d_virial->xy), r_virial.xy);
199 atomicAdd(&(d_virial->xz), r_virial.xz);
201 atomicAdd(&(d_virial->yx), r_virial.yx);
202 atomicAdd(&(d_virial->yy), r_virial.yy);
203 atomicAdd(&(d_virial->yz), r_virial.yz);
205 atomicAdd(&(d_virial->zx), r_virial.zx);
206 atomicAdd(&(d_virial->zy), r_virial.zy);
207 atomicAdd(&(d_virial->zz), r_virial.zz);
210 unsigned int value = atomicInc(&d_tbcatomic[0], totaltb); // returns the previous counter value
211 isLastBlockDone = (value == (totaltb -1)); // true only for the last block to arrive
216 if(isLastBlockDone) {
217 // Thread 0 of the last block will set the host values
218 if(threadIdx.x == 0) {
220 h_resEnergy[0] = energy; // restraint energy
221 h_diffCOM->x = diffCOM.x; // distance from ref position to COM of group 2
222 h_diffCOM->y = diffCOM.y; // distance from ref position to COM of group 2
223 h_diffCOM->z = diffCOM.z; // distance from ref position to COM of group 2
224 h_resForce->x = group_f.x; // restraint force on group 2
225 h_resForce->y = group_f.y; // restraint force on group 2
226 h_resForce->z = group_f.z; // restraint force on group 2
229 // Add virial values to host memory.
230 // We accumulate (+=) rather than assign, since there are multiple restraint groups
231 h_extVirial->xx += d_virial->xx;
232 h_extVirial->xy += d_virial->xy;
233 h_extVirial->xz += d_virial->xz;
234 h_extVirial->yx += d_virial->yx;
235 h_extVirial->yy += d_virial->yy;
236 h_extVirial->yz += d_virial->yz;
237 h_extVirial->zx += d_virial->zx;
238 h_extVirial->zy += d_virial->zy;
239 h_extVirial->zz += d_virial->zz;
241 // reset the device virial values
254 // reset the atomic counter
261 {// need isLastBlockDone for T_MGPUON
265 unsigned int value = atomicInc(&d_tbcatomic[0], totaltb);
266 isLastBlockDone = (value == (totaltb -1));
269 // last block cleans up
272 if(threadIdx.x == 0){
273 // zero out for next iteration
277 h_diffCOM->x = diffCOM.x;
278 h_diffCOM->y = diffCOM.y;
279 h_diffCOM->z = diffCOM.z;
280 // reset the atomic counter
289 /*! Compute restraint force, virial, and energy applied to a small
290 group 2 (the launcher in this file appears to use it when numRestrainedGroup <= 1024),
291 due to restraining the COM of group 2 (group2COM) to a reference point (h_group1COMRef).
292 Reductions use cub::BlockReduce<double, 1024>. This function also calculates
293 the distance from the ref point to the COM of group 2 and writes it to h_diffCOM. */
294 template<int T_DOENERGY, int T_DOVIRIAL, int T_USEMAGNITUDE, int T_MGPUON>
295 __global__ void computeSmallGroupRestraintKernel_1Group(
296 const int numRestrainedGroup, // number of entries of groupAtomsSOAIndex to process
297 const int restraintExp,
298 const double restraintK,
299 const double3 resCenterVec,
300 const double3 resDirection,
301 const double inv_group2_mass,
302 const int* __restrict groupAtomsSOAIndex,
304 const char3* __restrict transform,
305 const float* __restrict mass,
306 const double* __restrict pos_x,
307 const double* __restrict pos_y,
308 const double* __restrict pos_z,
309 double* __restrict f_normal_x,
310 double* __restrict f_normal_y,
311 double* __restrict f_normal_z,
312 cudaTensor* __restrict h_extVirial,
313 double* __restrict h_resEnergy,
314 double3* __restrict h_resForce,
315 const double3* __restrict h_group1COMRef,
316 double3* __restrict group2COM, // device in T_MGPUON case
317 double3* __restrict h_diffCOM,
318 unsigned int* __restrict d_tbcatomic) // counter bumped with atomicInc to detect the last block
320 int tIdx = threadIdx.x + blockIdx.x * blockDim.x;
321 __shared__ double3 sh_com2; // COM of group 2, written by thread 0 and read by the block
322 int totaltb = gridDim.x;
323 bool isLastBlockDone = false;
326 double3 com2 = {0, 0, 0};
327 double3 diffCOM = {0, 0, 0};
328 double3 group_f = {0, 0, 0};
329 double3 pos = {0, 0, 0};
330 double3 f = {0, 0, 0};
332 r_virial.xx = 0.0; r_virial.xy = 0.0; r_virial.xz = 0.0;
333 r_virial.yx = 0.0; r_virial.yy = 0.0; r_virial.yz = 0.0;
334 r_virial.zx = 0.0; r_virial.zy = 0.0; r_virial.zz = 0.0;
337 if(tIdx < numRestrainedGroup){
338 // First -> recalculate center of mass.
339 SOAindex = groupAtomsSOAIndex[tIdx];
341 m = mass[SOAindex]; // Cast from float to double here
342 pos.x = pos_x[SOAindex];
343 pos.y = pos_y[SOAindex];
344 pos.z = pos_z[SOAindex];
346 // unwrap the coordinate to calculate COM
347 char3 tr = transform[SOAindex];
348 pos = lat.reverse_transform(pos, tr);
356 // reduce the (mass * position) values for the thread block
357 typedef cub::BlockReduce<double, 1024> BlockReduce; // sized for 1024 threads per block
358 __shared__ typename BlockReduce::TempStorage temp_storage;
360 com2.x = BlockReduce(temp_storage).Sum(com2.x);
362 com2.y = BlockReduce(temp_storage).Sum(com2.y);
364 com2.z = BlockReduce(temp_storage).Sum(com2.z);
367 // Thread 0 calculates the COM and publishes it through shared memory
368 if(threadIdx.x == 0){
370 sh_com2.x = group2COM->x * inv_group2_mass;
371 sh_com2.y = group2COM->y * inv_group2_mass;
372 sh_com2.z = group2COM->z * inv_group2_mass;
375 sh_com2.x = com2.x * inv_group2_mass; // calculates the current center of mass
376 sh_com2.y = com2.y * inv_group2_mass; // calculates the current center of mass
377 sh_com2.z = com2.z * inv_group2_mass; // calculates the current center of mass
382 if(tIdx < numRestrainedGroup){
383 // calculate the distance from ref to com2 along the specific restraint dimension
384 diffCOM.x = (sh_com2.x - h_group1COMRef->x) * resDirection.x;
385 diffCOM.y = (sh_com2.y - h_group1COMRef->y) * resDirection.y;
386 diffCOM.z = (sh_com2.z - h_group1COMRef->z) * resDirection.z;
387 // Calculate the minimum image distance
388 diffCOM = lat.delta_from_diff(diffCOM);
390 if (T_USEMAGNITUDE) {
391 // Calculate the difference from equilibrium restraint distance
392 double comVal = sqrt(diffCOM.x*diffCOM.x + diffCOM.y*diffCOM.y + diffCOM.z*diffCOM.z);
393 double centerVal = sqrt(resCenterVec.x*resCenterVec.x + resCenterVec.y*resCenterVec.y +
394 resCenterVec.z*resCenterVec.z);
396 double distDiff = (comVal - centerVal);
397 double distSqDiff = distDiff * distDiff;
398 double invCOMVal = 1.0 / comVal;
400 // Calculate energy and force on group of atoms
401 if(distSqDiff > 0.0f) { // To avoid numerical error: the force below divides by distDiff
402 // Energy = k * (r - r_eq)^n; accumulated in squared-distance factors, so exact for even n
403 energy = restraintK * distSqDiff;
404 for (int n = 2; n < restraintExp; n += 2) {
405 energy *= distSqDiff;
407 // Force = -k * n * (r - r_eq)^(n-1)
408 double force = -energy * restraintExp / distDiff;
409 // calculate force along COM difference (unit vector diffCOM / |diffCOM|)
410 group_f.x = force * diffCOM.x * invCOMVal;
411 group_f.y = force * diffCOM.y * invCOMVal;
412 group_f.z = force * diffCOM.z * invCOMVal;
415 // Calculate the difference from equilibrium restraint distance vector
416 // along the specific restraint dimension
418 resDist.x = (diffCOM.x - resCenterVec.x) * resDirection.x;
419 resDist.y = (diffCOM.y - resCenterVec.y) * resDirection.y;
420 resDist.z = (diffCOM.z - resCenterVec.z) * resDirection.z;
421 // Wrap the distance difference (diffCOM - resCenterVec)
422 resDist = lat.delta_from_diff(resDist);
424 double distSqDiff = resDist.x*resDist.x + resDist.y*resDist.y + resDist.z*resDist.z;
426 // Calculate energy and force on group of atoms
427 if(distSqDiff > 0.0f) { // To avoid numerical error: the force below divides by distSqDiff
428 // Energy = k * (r - r_eq)^n; accumulated in squared-distance factors, so exact for even n
429 energy = restraintK * distSqDiff;
430 for (int n = 2; n < restraintExp; n += 2) {
431 energy *= distSqDiff;
433 // Force = -k * n * (r - r_eq)^(n-1) x (r - r_eq)/|r - r_eq|
434 double force = -energy * restraintExp / distSqDiff;
435 group_f.x = force * resDist.x;
436 group_f.y = force * resDist.y;
437 group_f.z = force * resDist.z;
441 // calculate the force on each atom: group force scaled by the atom's mass fraction
442 f.x = group_f.x * m * inv_group2_mass;
443 f.y = group_f.y * m * inv_group2_mass;
444 f.z = group_f.z * m * inv_group2_mass;
445 // apply the bias to each atom in group
446 f_normal_x[SOAindex] += f.x ;
447 f_normal_y[SOAindex] += f.y ;
448 f_normal_z[SOAindex] += f.z ;
449 // Virial is based on applied force on each atom
451 // positions must be unwrapped for virial calculation (pos was unwrapped above)
452 r_virial.xx = f.x * pos.x;
453 r_virial.xy = f.x * pos.y;
454 r_virial.xz = f.x * pos.z;
455 r_virial.yx = f.y * pos.x;
456 r_virial.yy = f.y * pos.y;
457 r_virial.yz = f.y * pos.z;
458 r_virial.zx = f.z * pos.x;
459 r_virial.zy = f.z * pos.y;
460 r_virial.zz = f.z * pos.z;
465 if(T_DOENERGY || T_DOVIRIAL) {
467 // Reduce virial values in the thread block
468 r_virial.xx = BlockReduce(temp_storage).Sum(r_virial.xx);
470 r_virial.xy = BlockReduce(temp_storage).Sum(r_virial.xy);
472 r_virial.xz = BlockReduce(temp_storage).Sum(r_virial.xz);
475 r_virial.yx = BlockReduce(temp_storage).Sum(r_virial.yx);
477 r_virial.yy = BlockReduce(temp_storage).Sum(r_virial.yy);
479 r_virial.yz = BlockReduce(temp_storage).Sum(r_virial.yz);
482 r_virial.zx = BlockReduce(temp_storage).Sum(r_virial.zx);
484 r_virial.zy = BlockReduce(temp_storage).Sum(r_virial.zy);
486 r_virial.zz = BlockReduce(temp_storage).Sum(r_virial.zz);
490 // thread zero updates the restraint energy and force
491 if(threadIdx.x == 0){
493 // Add virial values to host memory.
494 // We accumulate (+=) rather than assign, since there are multiple restraint groups
495 h_extVirial->xx += r_virial.xx;
496 h_extVirial->xy += r_virial.xy;
497 h_extVirial->xz += r_virial.xz;
498 h_extVirial->yx += r_virial.yx;
499 h_extVirial->yy += r_virial.yy;
500 h_extVirial->yz += r_virial.yz;
501 h_extVirial->zx += r_virial.zx;
502 h_extVirial->zy += r_virial.zy;
503 h_extVirial->zz += r_virial.zz;
506 h_resEnergy[0] = energy; // restraint energy
507 h_diffCOM->x = diffCOM.x; // distance from ref position to COM of group 2
508 h_diffCOM->y = diffCOM.y; // distance from ref position to COM of group 2
509 h_diffCOM->z = diffCOM.z; // distance from ref position to COM of group 2
510 h_resForce->x = group_f.x; // restraint force on group 2
511 h_resForce->y = group_f.y; // restraint force on group 2
512 h_resForce->z = group_f.z; // restraint force on group 2
516 if(threadIdx.x == 0){
518 unsigned int value = atomicInc(&d_tbcatomic[0], totaltb); // returns the previous counter value
519 isLastBlockDone = (value == (totaltb -1)); // true only for the last block to arrive
523 if(threadIdx.x == 0){
524 // zero out for next iteration
528 // reset the atomic counter
536 //standalone kernel caller for debugging
539 const int numRestrainedGroup,
542 const int numDevices,
545 const int blocks = (numRestrainedGroup > 1024) ? 128 : 1024;
546 const int grid = (numRestrainedGroup + blocks - 1) / blocks;
547 computeDistCOMKernelMgpu<<<grid, blocks, 0, stream>>>(d_peerCOM,
552 // for debugging, copy from device for output
555 const int numRestrainedGroup,
558 const int deviceIndex,
561 const int blocks = (numRestrainedGroup > 1024) ? 128 : 1024;
562 const int grid = (numRestrainedGroup + blocks - 1) / blocks;
563 copyDistCOMKernel<<<grid, blocks, 0, stream>>>(d_peerCOM,
568 // standalone kernel caller for debugging single gpu version
571 const int numRestrainedGroup,
572 const double inv_group2_mass,
573 const int* d_groupAtomsSOAIndex,
575 const char3* d_transform,
577 const double* d_pos_x,
578 const double* d_pos_y,
579 const double* d_pos_z,
580 double3* d_group2COM,
581 double3* h_group2COM,
582 unsigned int* d_tbcatomic,
585 const int blocks = (numRestrainedGroup > 1024) ? 128 : 1024;
586 const int grid = (numRestrainedGroup + blocks - 1) / blocks;
587 computeCOMKernel<128><<<grid, blocks, 0, stream>>>(
596 d_groupAtomsSOAIndex,
602 /*! Compute restraint force, energy, and virial
603 applied to group 2, due to restraining COM of
604 group 2 to a reference COM position of group 1
606 void computeGroupRestraint_1Group(
608 const int useMagnitude,
611 const int numRestrainedGroup,
612 const int restraintExp,
613 const double restraintK,
614 const double3 resCenterVec,
615 const double3 resDirection,
616 const double inv_group2_mass,
617 const int* d_groupAtomsSOAIndex,
619 const char3* d_transform,
621 const double* d_pos_x,
622 const double* d_pos_y,
623 const double* d_pos_z,
624 double* d_f_normal_x,
625 double* d_f_normal_y,
626 double* d_f_normal_z,
627 cudaTensor* d_virial,
628 cudaTensor* h_extVirial,
631 double3* h_group1COMRef,
632 double3* h_group2COM,
634 double3** d_peer2COM,
635 double3* d_group2COM,
636 unsigned int* d_tbcatomic,
640 int options = doEnergy + (doVirial << 1) + (useMagnitude << 2) + (mGpuOn <<3); // bit0 = doEnergy, bit1 = doVirial, bit2 = useMagnitude, bit3 = mGpuOn (assumes each flag is 0 or 1)
641 const int blocks = (numRestrainedGroup > 1024) ? 128 : 1024; // threads per block
642 const int grid = (numRestrainedGroup + blocks - 1) / blocks; // ceil-div: enough blocks to cover every atom
643 // multi-GPU keeps the COM on the device, not the host
644 double3* COMPtr=(mGpuOn) ? d_group2COM: h_group2COM;
646 if (numRestrainedGroup > 1024) {
647 // first calculate the COM for restraint group and store it in h_group2COM
649 computeCOMKernel<128><<<grid, blocks, 0, stream>>>(
658 d_groupAtomsSOAIndex,
664 // need to compute COM of group2
665 // we only need numDevices threads
666 computeDistCOMKernelMgpu<<<1, numDevices, 0, stream>>>(d_peer2COM,
670 #define CALL_LARGE_GROUP_RES(DOENERGY, DOVIRIAL, USEMAGNITUDE, MGPUON) \
671 computeLargeGroupRestraintKernel_1Group<DOENERGY, DOVIRIAL, USEMAGNITUDE, MGPUON> \
672 <<<grid, blocks, 0, stream>>>( \
673 numRestrainedGroup, restraintExp, restraintK, resCenterVec,\
674 resDirection, inv_group2_mass, d_groupAtomsSOAIndex, lat, \
675 d_transform, d_mass, d_pos_x, d_pos_y, d_pos_z, \
676 d_f_normal_x, d_f_normal_y, d_f_normal_z, d_virial, \
677 h_extVirial, h_resEnergy, h_resForce, h_group1COMRef, \
678 COMPtr, h_diffCOM, d_tbcatomic);
680 // Case label = options; template arguments are listed least-significant bit first
681 case 0: CALL_LARGE_GROUP_RES(0, 0, 0, 0); break;
682 case 1: CALL_LARGE_GROUP_RES(1, 0, 0, 0); break;
683 case 2: CALL_LARGE_GROUP_RES(0, 1, 0, 0); break;
684 case 3: CALL_LARGE_GROUP_RES(1, 1, 0, 0); break;
685 case 4: CALL_LARGE_GROUP_RES(0, 0, 1, 0); break;
686 case 5: CALL_LARGE_GROUP_RES(1, 0, 1, 0); break;
687 case 6: CALL_LARGE_GROUP_RES(0, 1, 1, 0); break;
688 case 7: CALL_LARGE_GROUP_RES(1, 1, 1, 0); break;
689 case 8: CALL_LARGE_GROUP_RES(0, 0, 0, 1); break;
690 case 9: CALL_LARGE_GROUP_RES(1, 0, 0, 1); break;
691 case 10: CALL_LARGE_GROUP_RES(0, 1, 0, 1); break;
692 case 11: CALL_LARGE_GROUP_RES(1, 1, 0, 1); break;
693 case 12: CALL_LARGE_GROUP_RES(0, 0, 1, 1); break;
694 case 13: CALL_LARGE_GROUP_RES(1, 0, 1, 1); break;
695 case 14: CALL_LARGE_GROUP_RES(0, 1, 1, 1); break;
696 case 15: CALL_LARGE_GROUP_RES(1, 1, 1, 1); break;
698 #undef CALL_LARGE_GROUP_RES
700 // For a small group of restrained atoms, we can just launch
701 // a single thread block
703 #define CALL_SMALL_GROUP_RES(DOENERGY, DOVIRIAL, USEMAGNITUDE, MGPUON) \
704 computeSmallGroupRestraintKernel_1Group<DOENERGY, DOVIRIAL, USEMAGNITUDE, MGPUON> \
705 <<<grid, blocks, 0, stream>>>( \
706 numRestrainedGroup, restraintExp, restraintK, resCenterVec,\
707 resDirection, inv_group2_mass, d_groupAtomsSOAIndex, lat, \
708 d_transform, d_mass, d_pos_x, d_pos_y, d_pos_z, \
709 d_f_normal_x, d_f_normal_y, d_f_normal_z, \
710 h_extVirial, h_resEnergy, h_resForce, \
711 h_group1COMRef, COMPtr, h_diffCOM, d_tbcatomic);
714 // need to compute COM of group2
715 // we only need numDevices threads
716 computeDistCOMKernelMgpu<<<1, numDevices, 0, stream>>>(d_peer2COM,
721 // Same encoding as the large-group switch: case label = options, arguments in bit order
722 case 0: CALL_SMALL_GROUP_RES(0, 0, 0, 0); break;
723 case 1: CALL_SMALL_GROUP_RES(1, 0, 0, 0); break;
724 case 2: CALL_SMALL_GROUP_RES(0, 1, 0, 0); break;
725 case 3: CALL_SMALL_GROUP_RES(1, 1, 0, 0); break;
726 case 4: CALL_SMALL_GROUP_RES(0, 0, 1, 0); break;
727 case 5: CALL_SMALL_GROUP_RES(1, 0, 1, 0); break;
728 case 6: CALL_SMALL_GROUP_RES(0, 1, 1, 0); break;
729 case 7: CALL_SMALL_GROUP_RES(1, 1, 1, 0); break;
730 case 8: CALL_SMALL_GROUP_RES(0, 0, 0, 1); break;
731 case 9: CALL_SMALL_GROUP_RES(1, 0, 0, 1); break;
732 case 10: CALL_SMALL_GROUP_RES(0, 1, 0, 1); break;
733 case 11: CALL_SMALL_GROUP_RES(1, 1, 0, 1); break;
734 case 12: CALL_SMALL_GROUP_RES(0, 0, 1, 1); break;
735 case 13: CALL_SMALL_GROUP_RES(1, 0, 1, 1); break;
736 case 14: CALL_SMALL_GROUP_RES(0, 1, 1, 1); break;
737 case 15: CALL_SMALL_GROUP_RES(1, 1, 1, 1); break;
739 #undef CALL_SMALL_GROUP_RES
742 void initPeerCOMmgpuG(
743 const int numDevices,
744 const int deviceIndex,
745 double3** d_peerPool,
749 const int blocks = numDevices; // threads per block for the launch below: one per device
751 initPeerCOMKernel<<<grid, blocks, 0, stream>>>( numDevices,
757 /* called in earlier phase to handle multi device COM */
762 const double* d_pos_x,
763 const double* d_pos_y,
764 const double* d_pos_z,
765 const char3* d_transform,
766 const int* d_AtomsSOAIndex,
768 double3** d_peer_curCM,
769 unsigned int* d_tbcatomic,
770 const int numDevices,
771 const int deviceIndex,
774 // Use 128-thread blocks when numAtoms > 1024; otherwise one 1024-thread block covers all atoms
775 const int blocks = (numAtoms > 1024) ? 128 : 1024;
776 const int grid = (numAtoms + blocks - 1) / blocks;
777 //initialize the device memory to zero here
778 cudaCheck(cudaMemset(d_peerCOM, 0, sizeof(double3)));
779 // why do we need this here?
781 computeCOMKernelMgpu<128><<<grid, blocks, 0, stream>>>(numAtoms,
783 d_pos_x, d_pos_y, d_pos_z,
791 computeCOMKernelMgpu<1024><<<grid, blocks, 0, stream>>>(numAtoms,
793 d_pos_x, d_pos_y, d_pos_z,
802 #endif // NODEGROUP_FORCE_REGISTER