2 #if __CUDACC_VER_MAJOR__ >= 11
5 #include <namd_cub/cub.cuh>
10 #include <hip/hip_runtime.h>
11 #include <hipcub/hipcub.hpp>
15 #include "ComputeGroupRes1GroupCUDAKernel.h"
16 #include "ComputeCOMCudaKernel.h"
17 #include "HipDefines.h"
19 #ifdef NODEGROUP_FORCE_REGISTER
21 /*! Compute restraint force, virial, and energy applied to a large
22 group 2 (more than 1024 atoms), due to restraining the COM of group 2
23 (h_group2COM) to a reference point (h_group1COMRef).
24 To use this function, the COM of group 2 must be calculated beforehand
25 and passed to this function as h_group2COM.
26 This function also calculates the distance from the reference point to
27 the COM of group 2. */
// Multi-block restraint kernel: one thread per restrained atom of group 2.
// Template parameters:
//   T_DOENERGY, T_DOVIRIAL - when either is nonzero the per-atom virial is
//     block-reduced and accumulated into d_virial, and the last thread block
//     to finish copies energy, diffCOM, group force and virial to the h_*
//     outputs.
//   T_USEMAGNITUDE - chooses between the two restraint forms below:
//     restraining the scalar distance |diffCOM| toward |resCenterVec|, or
//     restraining the vector diffCOM toward resCenterVec.
// Launch requirements: exactly 128 threads per block (must match
//   cub::BlockReduce<double, 128> below) and enough blocks to cover
//   numRestrainedGroup threads.
// Parameters (selected):
//   numRestrainedGroup - number of entries in groupAtomsSOAIndex; threads with
//                        tIdx >= numRestrainedGroup do no per-atom work
//   restraintExp       - exponent n in E = k * (r - r_eq)^n
//   restraintK         - force constant k
//   resCenterVec       - equilibrium displacement from reference to COM
//   resDirection       - per-axis factors multiplied into the displacement
//                        (presumably 0/1 masks selecting restrained axes)
//   inv_group2_mass    - 1 / total mass of group 2
//   h_group2COM        - COM of group 2, computed before this kernel runs
//   d_virial, d_tbcatomic - device-side accumulators shared by all blocks;
//                        d_tbcatomic[0] counts finished blocks
28 template<int T_DOENERGY, int T_DOVIRIAL, int T_USEMAGNITUDE>
29 __global__ void computeLargeGroupRestraintKernel_1Group(
30 const int numRestrainedGroup,
31 const int restraintExp,
32 const double restraintK,
33 const double3 resCenterVec,
34 const double3 resDirection,
35 const double inv_group2_mass,
36 const int* __restrict groupAtomsSOAIndex,
38 const char3* __restrict transform,
39 const float* __restrict mass,
40 const double* __restrict pos_x,
41 const double* __restrict pos_y,
42 const double* __restrict pos_z,
43 double* __restrict f_normal_x,
44 double* __restrict f_normal_y,
45 double* __restrict f_normal_z,
46 cudaTensor* __restrict d_virial,
47 cudaTensor* __restrict h_extVirial,
48 double* __restrict h_resEnergy,
49 double3* __restrict h_resForce,
50 const double3* __restrict h_group1COMRef,
51 const double3* __restrict h_group2COM,
52 double3* __restrict h_diffCOM,
53 unsigned int* __restrict d_tbcatomic)
// Global thread index: one thread per restrained atom.
55 int tIdx = threadIdx.x + blockIdx.x * blockDim.x;
56 int totaltb = gridDim.x;
57 bool isLastBlockDone = false;
62 double3 diffCOM = {0, 0, 0};
63 double3 group_f = {0, 0, 0};
64 double3 pos = {0, 0, 0};
65 double3 f = {0, 0, 0};
// Per-thread virial contribution; stays zero for threads outside the group
// so they contribute nothing to the block reduction.
67 r_virial.xx = 0.0; r_virial.xy = 0.0; r_virial.xz = 0.0;
68 r_virial.yx = 0.0; r_virial.yy = 0.0; r_virial.yz = 0.0;
69 r_virial.zx = 0.0; r_virial.zy = 0.0; r_virial.zz = 0.0;
71 if(tIdx < numRestrainedGroup) {
72 SOAindex = groupAtomsSOAIndex[tIdx];
// Every active thread recomputes the same group-level quantities (diffCOM,
// energy, group_f) from uniform inputs; only the per-atom force and virial
// differ between threads.
74 // Calculate distance from ref to com2 along the restrained dimensions
75 diffCOM.x = (h_group2COM->x - h_group1COMRef->x) * resDirection.x;
76 diffCOM.y = (h_group2COM->y - h_group1COMRef->y) * resDirection.y;
77 diffCOM.z = (h_group2COM->z - h_group1COMRef->z) * resDirection.z;
78 // Calculate the minimum image distance
79 diffCOM = lat.delta_from_diff(diffCOM);
// Magnitude form: E = k * (|diffCOM| - |resCenterVec|)^n, force along diffCOM.
82 // Calculate the difference from equilibrium restraint distance
83 double comVal = sqrt(diffCOM.x*diffCOM.x + diffCOM.y*diffCOM.y + diffCOM.z*diffCOM.z);
84 double centerVal = sqrt(resCenterVec.x*resCenterVec.x + resCenterVec.y*resCenterVec.y +
85 resCenterVec.z*resCenterVec.z);
87 double distDiff = (comVal - centerVal);
88 double distSqDiff = distDiff * distDiff;
// NOTE(review): comVal is 0 when the (masked) COM coincides with the
// reference point; invCOMVal is then infinite and group_f below becomes
// 0 * inf = NaN whenever centerVal > 0. Consider guarding comVal > 0.
89 double invCOMVal = 1.0 / comVal;
91 // Calculate energy and force on group of atoms
92 if(distSqDiff > 0.0f) { // skip exact equilibrium: force divides by distDiff
93 // Energy = k * (r - r_eq)^n
94 energy = restraintK * distSqDiff;
95 for (int n = 2; n < restraintExp; n += 2) {
98 // Force = -k * n * (r - r_eq)^(n-1)
99 double force = -energy * restraintExp / distDiff;
100 // calculate force along COM difference (unit vector diffCOM / |diffCOM|)
101 group_f.x = force * diffCOM.x * invCOMVal;
102 group_f.y = force * diffCOM.y * invCOMVal;
103 group_f.z = force * diffCOM.z * invCOMVal;
// Vector form: E = k * |resDist|^n with resDist = diffCOM - resCenterVec.
106 // Calculate the difference from equilibrium restraint distance vector
107 // along the restrained dimensions
109 resDist.x = (diffCOM.x - resCenterVec.x) * resDirection.x;
110 resDist.y = (diffCOM.y - resCenterVec.y) * resDirection.y;
111 resDist.z = (diffCOM.z - resCenterVec.z) * resDirection.z;
112 // Wrap the distance difference (diffCOM - resCenterVec)
113 resDist = lat.delta_from_diff(resDist);
115 double distSqDiff = resDist.x*resDist.x + resDist.y*resDist.y + resDist.z*resDist.z;
117 // Calculate energy and force on group of atoms
118 if(distSqDiff > 0.0f) { // skip exact equilibrium: force divides by distSqDiff
119 // Energy = k * (r - r_eq)^n
120 energy = restraintK * distSqDiff;
// The loop yields k * distSqDiff^(restraintExp/2), which equals
// k * |r - r_eq|^n only for even restraintExp >= 2 (n advances by 2).
121 for (int n = 2; n < restraintExp; n += 2) {
122 energy *= distSqDiff;
124 // Force = -k * n * (r - r_eq)^(n-1) x (r - r_eq)/|r - r_eq|
125 double force = -energy * restraintExp / distSqDiff;
126 group_f.x = force * resDist.x;
127 group_f.y = force * resDist.y;
128 group_f.z = force * resDist.z;
132 // distribute the group force to each atom in proportion to m_i / M
133 f.x = group_f.x * m * inv_group2_mass;
134 f.y = group_f.y * m * inv_group2_mass;
135 f.z = group_f.z * m * inv_group2_mass;
136 // apply the bias to each atom in group
// Plain (non-atomic) read-modify-write: assumes each SOA index appears at
// most once in groupAtomsSOAIndex — confirm.
137 f_normal_x[SOAindex] += f.x;
138 f_normal_y[SOAindex] += f.y;
139 f_normal_z[SOAindex] += f.z;
140 // Virial is based on applied force on each atom (outer product f x pos)
142 // positions must be unwrapped for virial calculation
143 pos.x = pos_x[SOAindex];
144 pos.y = pos_y[SOAindex];
145 pos.z = pos_z[SOAindex];
146 char3 tr = transform[SOAindex];
147 pos = lat.reverse_transform(pos, tr);
148 r_virial.xx = f.x * pos.x;
149 r_virial.xy = f.x * pos.y;
150 r_virial.xz = f.x * pos.z;
151 r_virial.yx = f.y * pos.x;
152 r_virial.yy = f.y * pos.y;
153 r_virial.yz = f.y * pos.z;
154 r_virial.zx = f.z * pos.x;
155 r_virial.zy = f.z * pos.y;
156 r_virial.zz = f.z * pos.z;
// Reduce the virial per block, combine blocks via d_virial, and let the last
// block to finish publish results to the h_* outputs.
161 if(T_DOENERGY || T_DOVIRIAL) {
163 // Reduce virial values in the thread block
// NOTE(review): CUB requires a __syncthreads() before temp_storage is reused
// by the next Sum(); confirm a barrier separates each call below.
164 typedef cub::BlockReduce<double, 128> BlockReduce;
165 __shared__ typename BlockReduce::TempStorage temp_storage;
166 r_virial.xx = BlockReduce(temp_storage).Sum(r_virial.xx);
168 r_virial.xy = BlockReduce(temp_storage).Sum(r_virial.xy);
170 r_virial.xz = BlockReduce(temp_storage).Sum(r_virial.xz);
173 r_virial.yx = BlockReduce(temp_storage).Sum(r_virial.yx);
175 r_virial.yy = BlockReduce(temp_storage).Sum(r_virial.yy);
177 r_virial.yz = BlockReduce(temp_storage).Sum(r_virial.yz);
180 r_virial.zx = BlockReduce(temp_storage).Sum(r_virial.zx);
182 r_virial.zy = BlockReduce(temp_storage).Sum(r_virial.zy);
184 r_virial.zz = BlockReduce(temp_storage).Sum(r_virial.zz);
188 if(threadIdx.x == 0) {
190 // thread 0 adds the reduced virial values into device memory
191 atomicAdd(&(d_virial->xx), r_virial.xx);
192 atomicAdd(&(d_virial->xy), r_virial.xy);
193 atomicAdd(&(d_virial->xz), r_virial.xz);
195 atomicAdd(&(d_virial->yx), r_virial.yx);
196 atomicAdd(&(d_virial->yy), r_virial.yy);
197 atomicAdd(&(d_virial->yz), r_virial.yz);
199 atomicAdd(&(d_virial->zx), r_virial.zx);
200 atomicAdd(&(d_virial->zy), r_virial.zy);
201 atomicAdd(&(d_virial->zz), r_virial.zz);
// atomicInc returns the previous counter value; the block whose increment
// returns totaltb - 1 is the last to finish.
// NOTE(review): the d_virial atomicAdds above must be made visible to other
// blocks (e.g. __threadfence()) before this increment — confirm.
204 unsigned int value = atomicInc(&d_tbcatomic[0], totaltb);
205 isLastBlockDone = (value == (totaltb -1));
210 if(isLastBlockDone) {
211 // Thread 0 of the last block will set the host values
212 if(threadIdx.x == 0) {
214 h_resEnergy[0] = energy; // restraint energy
215 h_diffCOM->x = diffCOM.x; // distance from ref position to COM of group 2
216 h_diffCOM->y = diffCOM.y; // distance from ref position to COM of group 2
217 h_diffCOM->z = diffCOM.z; // distance from ref position to COM of group 2
218 h_resForce->x = group_f.x; // restraint force on group 2
219 h_resForce->y = group_f.y; // restraint force on group 2
220 h_resForce->z = group_f.z; // restraint force on group 2
223 // Add virial values to host memory.
224 // We accumulate (+=) since multiple restraint groups may add into h_extVirial
225 h_extVirial->xx += d_virial->xx;
226 h_extVirial->xy += d_virial->xy;
227 h_extVirial->xz += d_virial->xz;
228 h_extVirial->yx += d_virial->yx;
229 h_extVirial->yy += d_virial->yy;
230 h_extVirial->yz += d_virial->yz;
231 h_extVirial->zx += d_virial->zx;
232 h_extVirial->zy += d_virial->zy;
233 h_extVirial->zz += d_virial->zz;
235 //reset the device virial values so the next launch starts from zero
248 //resets atomic counter
257 /*! Compute restraint force, virial, and energy applied to a small
258 group 2 (at most 1024 atoms), due to restraining the COM of group 2
259 to a reference point (h_group1COMRef). The COM of group 2 is
computed inside this kernel with a single-thread-block reduction.
260 This function also calculates the distance from the reference point to
261 the COM of group 2. */
// Single-block restraint kernel: computes the COM of group 2 itself (a
// block-wide reduction of mass-weighted, unwrapped positions) and then applies
// the same restraint forms as the large-group kernel.
// Template parameters:
//   T_DOENERGY, T_DOVIRIAL - when either is nonzero the per-atom virial is
//     block-reduced and thread 0 writes the results to the h_* outputs.
//   T_USEMAGNITUDE - selects the scalar-distance restraint (|diffCOM| toward
//     |resCenterVec|) instead of the vector restraint (diffCOM toward
//     resCenterVec); see `if (T_USEMAGNITUDE)` below.
// Launch requirements: intended for one block of exactly 1024 threads
//   (matching cub::BlockReduce<double, 1024>), so numRestrainedGroup must be
//   <= 1024. The COM in sh_com2 is only complete if a single block covers the
//   whole group.
// Because only one block runs, results go straight to the h_* outputs; no
//   device-side accumulator or block counter is needed.
262 template<int T_DOENERGY, int T_DOVIRIAL, int T_USEMAGNITUDE>
263 __global__ void computeSmallGroupRestraintKernel_1Group(
264 const int numRestrainedGroup,
265 const int restraintExp,
266 const double restraintK,
267 const double3 resCenterVec,
268 const double3 resDirection,
269 const double inv_group2_mass,
270 const int* __restrict groupAtomsSOAIndex,
272 const char3* __restrict transform,
273 const float* __restrict mass,
274 const double* __restrict pos_x,
275 const double* __restrict pos_y,
276 const double* __restrict pos_z,
277 double* __restrict f_normal_x,
278 double* __restrict f_normal_y,
279 double* __restrict f_normal_z,
280 cudaTensor* __restrict h_extVirial,
281 double* __restrict h_resEnergy,
282 double3* __restrict h_resForce,
283 const double3* __restrict h_group1COMRef,
284 double3* __restrict h_diffCOM)
// Global thread index: one thread per restrained atom.
286 int tIdx = threadIdx.x + blockIdx.x * blockDim.x;
// COM of group 2, written by thread 0 and read by all threads.
287 __shared__ double3 sh_com2;
291 double3 com2 = {0, 0, 0};
292 double3 diffCOM = {0, 0, 0};
293 double3 group_f = {0, 0, 0};
294 double3 pos = {0, 0, 0};
295 double3 f = {0, 0, 0};
// Per-thread virial contribution; stays zero for threads outside the group.
297 r_virial.xx = 0.0; r_virial.xy = 0.0; r_virial.xz = 0.0;
298 r_virial.yx = 0.0; r_virial.yy = 0.0; r_virial.yz = 0.0;
299 r_virial.zx = 0.0; r_virial.zy = 0.0; r_virial.zz = 0.0;
302 if(tIdx < numRestrainedGroup){
303 // First -> recalculate center of mass.
304 SOAindex = groupAtomsSOAIndex[tIdx];
306 m = mass[SOAindex]; // Cast from float to double here
307 pos.x = pos_x[SOAindex];
308 pos.y = pos_y[SOAindex];
309 pos.z = pos_z[SOAindex];
311 // unwrap the coordinate to calculate COM (also reused for the virial below)
312 char3 tr = transform[SOAindex];
313 pos = lat.reverse_transform(pos, tr);
320 // reduce the (mass * position) values for the thread block
// NOTE(review): CUB requires a __syncthreads() before temp_storage is reused
// by the next Sum(); confirm a barrier separates each call below.
321 typedef cub::BlockReduce<double, 1024> BlockReduce;
322 __shared__ typename BlockReduce::TempStorage temp_storage;
324 com2.x = BlockReduce(temp_storage).Sum(com2.x);
326 com2.y = BlockReduce(temp_storage).Sum(com2.y);
328 com2.z = BlockReduce(temp_storage).Sum(com2.z);
331 // Thread 0 calculates the COM (the BlockReduce result is valid only in thread 0)
332 if(threadIdx.x == 0){
333 sh_com2.x = com2.x * inv_group2_mass; // calculates the current center of mass
334 sh_com2.y = com2.y * inv_group2_mass; // calculates the current center of mass
335 sh_com2.z = com2.z * inv_group2_mass; // calculates the current center of mass
// NOTE(review): all threads read sh_com2 below, so a __syncthreads() must
// separate thread 0's write above from these reads — confirm it is present.
339 if(tIdx < numRestrainedGroup){
340 // calculate the distance from ref to com2 along the restrained dimensions
341 diffCOM.x = (sh_com2.x - h_group1COMRef->x) * resDirection.x;
342 diffCOM.y = (sh_com2.y - h_group1COMRef->y) * resDirection.y;
343 diffCOM.z = (sh_com2.z - h_group1COMRef->z) * resDirection.z;
344 // Calculate the minimum image distance
345 diffCOM = lat.delta_from_diff(diffCOM);
// Magnitude form: E = k * (|diffCOM| - |resCenterVec|)^n, force along diffCOM.
347 if (T_USEMAGNITUDE) {
348 // Calculate the difference from equilibrium restraint distance
349 double comVal = sqrt(diffCOM.x*diffCOM.x + diffCOM.y*diffCOM.y + diffCOM.z*diffCOM.z);
350 double centerVal = sqrt(resCenterVec.x*resCenterVec.x + resCenterVec.y*resCenterVec.y +
351 resCenterVec.z*resCenterVec.z);
353 double distDiff = (comVal - centerVal);
354 double distSqDiff = distDiff * distDiff;
// NOTE(review): comVal is 0 when the (masked) COM coincides with the
// reference point; invCOMVal is then infinite and group_f below becomes
// 0 * inf = NaN whenever centerVal > 0. Consider guarding comVal > 0.
355 double invCOMVal = 1.0 / comVal;
357 // Calculate energy and force on group of atoms
358 if(distSqDiff > 0.0f) { // skip exact equilibrium: force divides by distDiff
359 // Energy = k * (r - r_eq)^n
360 energy = restraintK * distSqDiff;
// The loop yields k * distSqDiff^(restraintExp/2), which equals
// k * |r - r_eq|^n only for even restraintExp >= 2 (n advances by 2).
361 for (int n = 2; n < restraintExp; n += 2) {
362 energy *= distSqDiff;
364 // Force = -k * n * (r - r_eq)^(n-1)
365 double force = -energy * restraintExp / distDiff;
366 // calculate force along COM difference (unit vector diffCOM / |diffCOM|)
367 group_f.x = force * diffCOM.x * invCOMVal;
368 group_f.y = force * diffCOM.y * invCOMVal;
369 group_f.z = force * diffCOM.z * invCOMVal;
// Vector form: E = k * |resDist|^n with resDist = diffCOM - resCenterVec.
372 // Calculate the difference from equilibrium restraint distance vector
373 // along the restrained dimensions
375 resDist.x = (diffCOM.x - resCenterVec.x) * resDirection.x;
376 resDist.y = (diffCOM.y - resCenterVec.y) * resDirection.y;
377 resDist.z = (diffCOM.z - resCenterVec.z) * resDirection.z;
378 // Wrap the distance difference (diffCOM - resCenterVec)
379 resDist = lat.delta_from_diff(resDist);
381 double distSqDiff = resDist.x*resDist.x + resDist.y*resDist.y + resDist.z*resDist.z;
383 // Calculate energy and force on group of atoms
384 if(distSqDiff > 0.0f) { // skip exact equilibrium: force divides by distSqDiff
385 // Energy = k * (r - r_eq)^n
386 energy = restraintK * distSqDiff;
387 for (int n = 2; n < restraintExp; n += 2) {
388 energy *= distSqDiff;
390 // Force = -k * n * (r - r_eq)^(n-1) x (r - r_eq)/|r - r_eq|
391 double force = -energy * restraintExp / distSqDiff;
392 group_f.x = force * resDist.x;
393 group_f.y = force * resDist.y;
394 group_f.z = force * resDist.z;
398 // distribute the group force to each atom in proportion to m_i / M
399 f.x = group_f.x * m * inv_group2_mass;
400 f.y = group_f.y * m * inv_group2_mass;
401 f.z = group_f.z * m * inv_group2_mass;
402 // apply the bias to each atom in group
// Plain (non-atomic) read-modify-write: assumes each SOA index appears at
// most once in groupAtomsSOAIndex — confirm.
403 f_normal_x[SOAindex] += f.x ;
404 f_normal_y[SOAindex] += f.y ;
405 f_normal_z[SOAindex] += f.z ;
406 // Virial is based on applied force on each atom (outer product f x pos)
408 // positions were unwrapped above (reverse_transform) for the virial calculation
409 r_virial.xx = f.x * pos.x;
410 r_virial.xy = f.x * pos.y;
411 r_virial.xz = f.x * pos.z;
412 r_virial.yx = f.y * pos.x;
413 r_virial.yy = f.y * pos.y;
414 r_virial.yz = f.y * pos.z;
415 r_virial.zx = f.z * pos.x;
416 r_virial.zy = f.z * pos.y;
417 r_virial.zz = f.z * pos.z;
422 if(T_DOENERGY || T_DOVIRIAL) {
424 // Reduce virial values in the thread block (same temp_storage as the COM reduction)
425 r_virial.xx = BlockReduce(temp_storage).Sum(r_virial.xx);
427 r_virial.xy = BlockReduce(temp_storage).Sum(r_virial.xy);
429 r_virial.xz = BlockReduce(temp_storage).Sum(r_virial.xz);
432 r_virial.yx = BlockReduce(temp_storage).Sum(r_virial.yx);
434 r_virial.yy = BlockReduce(temp_storage).Sum(r_virial.yy);
436 r_virial.yz = BlockReduce(temp_storage).Sum(r_virial.yz);
439 r_virial.zx = BlockReduce(temp_storage).Sum(r_virial.zx);
441 r_virial.zy = BlockReduce(temp_storage).Sum(r_virial.zy);
443 r_virial.zz = BlockReduce(temp_storage).Sum(r_virial.zz);
447 // thread zero updates the restraints energy and force
448 if(threadIdx.x == 0){
450 // Add virial values to host memory.
451 // We accumulate (+=) since multiple restraint groups may add into h_extVirial
452 h_extVirial->xx += r_virial.xx;
453 h_extVirial->xy += r_virial.xy;
454 h_extVirial->xz += r_virial.xz;
455 h_extVirial->yx += r_virial.yx;
456 h_extVirial->yy += r_virial.yy;
457 h_extVirial->yz += r_virial.yz;
458 h_extVirial->zx += r_virial.zx;
459 h_extVirial->zy += r_virial.zy;
460 h_extVirial->zz += r_virial.zz;
463 h_resEnergy[0] = energy; // restraint energy
464 h_diffCOM->x = diffCOM.x; // distance from ref position to COM of group 2
465 h_diffCOM->y = diffCOM.y; // distance from ref position to COM of group 2
466 h_diffCOM->z = diffCOM.z; // distance from ref position to COM of group 2
467 h_resForce->x = group_f.x; // restraint force on group 2
468 h_resForce->y = group_f.y; // restraint force on group 2
469 h_resForce->z = group_f.z; // restraint force on group 2
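475 /*! Compute restraint force, energy, and virial
476 applied to group 2, due to restraining the COM of
477 group 2 to a reference COM position of group 1 (h_group1COMRef).
Groups with more than 1024 atoms use the multi-block large-group kernel
(after computeCOMKernel computes the COM of group 2); groups of at most
1024 atoms use the single-block small-group kernel.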
// Host-side dispatcher. Packs the runtime flags into a 3-bit option code and
// launches the matching template instantiation on `stream`:
//  - numRestrainedGroup > 1024: computeCOMKernel<128> first computes the COM
//    of group 2, then computeLargeGroupRestraintKernel_1Group runs with 128
//    threads per block over ceil(numRestrainedGroup / 128) blocks.
//  - otherwise: computeSmallGroupRestraintKernel_1Group runs with 1024
//    threads in a single block and recomputes the COM itself.
// The h_* pointers are dereferenced inside the kernels, so they must be
// device-accessible (presumably mapped pinned host memory — confirm). Launches
// are asynchronous on `stream`, so callers should synchronize before reading
// the h_* outputs.
479 void computeGroupRestraint_1Group(
480 const int useMagnitude,
483 const int numRestrainedGroup,
484 const int restraintExp,
485 const double restraintK,
486 const double3 resCenterVec,
487 const double3 resDirection,
488 const double inv_group2_mass,
489 const int* d_groupAtomsSOAIndex,
491 const char3* d_transform,
493 const double* d_pos_x,
494 const double* d_pos_y,
495 const double* d_pos_z,
496 double* d_f_normal_x,
497 double* d_f_normal_y,
498 double* d_f_normal_z,
499 cudaTensor* d_virial,
500 cudaTensor* h_extVirial,
503 double3* h_group1COMRef,
504 double3* h_group2COM,
506 double3* d_group2COM,
507 unsigned int* d_tbcatomic,
// Option code: bit 0 = doEnergy, bit 1 = doVirial, bit 2 = useMagnitude.
// It selects one of the 8 template instantiations in the cases below.
510 int options = doEnergy + (doVirial << 1) + (useMagnitude << 2);
// Groups that do not fit in one 1024-thread block take the multi-block path.
512 if (numRestrainedGroup > 1024) {
// Must equal the cub::BlockReduce block size (128) in the large-group kernel.
513 const int blocks = 128;
514 const int grid = (numRestrainedGroup + blocks - 1) / blocks;
515 //first calculate the COM for restraint group and store it in h_group2COM
516 computeCOMKernel<128><<<grid, blocks, 0, stream>>>(
525 d_groupAtomsSOAIndex,
530 #define CALL_LARGE_GROUP_RES(DOENERGY, DOVIRIAL, USEMAGNITUDE) \
531 computeLargeGroupRestraintKernel_1Group<DOENERGY, DOVIRIAL, USEMAGNITUDE>\
532 <<<grid, blocks, 0, stream>>>( \
533 numRestrainedGroup, restraintExp, restraintK, resCenterVec, \
534 resDirection, inv_group2_mass, d_groupAtomsSOAIndex, lat, \
535 d_transform, d_mass, d_pos_x, d_pos_y, d_pos_z, \
536 d_f_normal_x, d_f_normal_y, d_f_normal_z, d_virial, \
537 h_extVirial, h_resEnergy, h_resForce, h_group1COMRef, \
538 h_group2COM, h_diffCOM, d_tbcatomic);
541 case 0: CALL_LARGE_GROUP_RES(0, 0, 0); break;
542 case 1: CALL_LARGE_GROUP_RES(1, 0, 0); break;
543 case 2: CALL_LARGE_GROUP_RES(0, 1, 0); break;
544 case 3: CALL_LARGE_GROUP_RES(1, 1, 0); break;
545 case 4: CALL_LARGE_GROUP_RES(0, 0, 1); break;
546 case 5: CALL_LARGE_GROUP_RES(1, 0, 1); break;
547 case 6: CALL_LARGE_GROUP_RES(0, 1, 1); break;
548 case 7: CALL_LARGE_GROUP_RES(1, 1, 1); break;
551 #undef CALL_LARGE_GROUP_RES
554 // For a small group of restrained atoms (at most 1024), we can just launch
555 // a single thread block
// Must equal the cub::BlockReduce block size (1024) in the small-group kernel.
556 const int blocks = 1024;
559 #define CALL_SMALL_GROUP_RES(DOENERGY, DOVIRIAL, USEMAGNITUDE) \
560 computeSmallGroupRestraintKernel_1Group<DOENERGY, DOVIRIAL, USEMAGNITUDE>\
561 <<<grid, blocks, 0, stream>>>( \
562 numRestrainedGroup, restraintExp, restraintK, resCenterVec, \
563 resDirection, inv_group2_mass, d_groupAtomsSOAIndex, lat, \
564 d_transform, d_mass, d_pos_x, d_pos_y, d_pos_z, \
565 d_f_normal_x, d_f_normal_y, d_f_normal_z, \
566 h_extVirial, h_resEnergy, h_resForce, \
567 h_group1COMRef, h_diffCOM);
571 case 0: CALL_SMALL_GROUP_RES(0, 0, 0); break;
572 case 1: CALL_SMALL_GROUP_RES(1, 0, 0); break;
573 case 2: CALL_SMALL_GROUP_RES(0, 1, 0); break;
574 case 3: CALL_SMALL_GROUP_RES(1, 1, 0); break;
575 case 4: CALL_SMALL_GROUP_RES(0, 0, 1); break;
576 case 5: CALL_SMALL_GROUP_RES(1, 0, 1); break;
577 case 6: CALL_SMALL_GROUP_RES(0, 1, 1); break;
578 case 7: CALL_SMALL_GROUP_RES(1, 1, 1); break;
581 #undef CALL_SMALL_GROUP_RES
586 #endif // NODEGROUP_FORCE_REGISTER