#if __CUDACC_VER_MAJOR__ >= 11
#include <namd_cub/cub.cuh>
#include <hip/hip_runtime.h>
#include <hipcub/hipcub.hpp>
#include "ComputeSMDCUDAKernel.h"
#include "ComputeCOMCudaKernel.h"
#include "HipDefines.h"

#ifdef NODEGROUP_FORCE_REGISTER
/*! Calculate SMD force and virial for a large atom group (numSMDAtoms > 1024).
    Multiple thread blocks are launched to perform this operation.
    The current COM (curCOM) must already be calculated and passed to this function. */
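/* For reference, a summary of the expressions computed below, with
   Delta = (COM - origCM) . direction:
     E     = (k/2) (velocity*t - Delta)^2
           + (k2/2) (|COM - origCM|^2 - Delta^2)
     F_grp = k (velocity*t - Delta) * direction
           + k2 (Delta * direction - (COM - origCM))
   i.e. F_grp = -dE/dCOM; each atom then receives F_grp * m_i / M_total. */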
template<bool T_DOENERGY, bool T_MGPUON>
__global__ void computeSMDForceWithCOMKernel(
  const int numSMDAtoms,
  const double inv_group_mass,
  const double velocity,
  const double3 direction,
  const int currentTime,
  const float* __restrict mass,
  const double* __restrict pos_x,
  const double* __restrict pos_y,
  const double* __restrict pos_z,
  const char3* __restrict transform,
  double* __restrict f_normal_x,
  double* __restrict f_normal_y,
  double* __restrict f_normal_z,
  const int* __restrict smdAtomsSOAIndex,
  cudaTensor* __restrict d_virial,
  double3* __restrict h_curCOM,
  double3* __restrict d_curCOM,
  double3** __restrict d_peerCOM,
  double* __restrict h_extEnergy,
  double3* __restrict h_extForce,
  cudaTensor* __restrict h_extVirial,
  unsigned int* __restrict tbcatomic)
  int tid = threadIdx.x + blockIdx.x * blockDim.x;
  int totaltb = gridDim.x;
  bool isLastBlockDone = false;
  double3 group_f = {0, 0, 0};
  double3 pos = {0, 0, 0};
  double3 f = {0, 0, 0};
  double3 cm = {h_curCOM->x, h_curCOM->y, h_curCOM->z};
  r_virial.xx = 0.0; r_virial.xy = 0.0; r_virial.xz = 0.0;
  r_virial.yx = 0.0; r_virial.yy = 0.0; r_virial.yz = 0.0;
  r_virial.zx = 0.0; r_virial.zy = 0.0; r_virial.zz = 0.0;
    cm.x = d_curCOM->x * inv_group_mass;
    cm.y = d_curCOM->y * inv_group_mass;
    cm.z = d_curCOM->z * inv_group_mass;
  if (tid < numSMDAtoms) {
    SOAindex = smdAtomsSOAIndex[tid];
    // uncoalesced memory access: too bad!
    double m = mass[SOAindex]; // cast from float to double here
    pos.x = pos_x[SOAindex];
    pos.y = pos_y[SOAindex];
    pos.z = pos_z[SOAindex];

    // calculate the distance difference along direction
    diffCOM.x = cm.x - origCM.x;
    diffCOM.y = cm.y - origCM.y;
    diffCOM.z = cm.z - origCM.z;
    double diff = diffCOM.x*direction.x + diffCOM.y*direction.y +
                  diffCOM.z*direction.z;

    // Ok so we've calculated the new center of mass, now we can calculate the bias
    double preFactor = (velocity*currentTime - diff);
    group_f.x = k*preFactor*direction.x + k2*(diff*direction.x - diffCOM.x);
    group_f.y = k*preFactor*direction.y + k2*(diff*direction.y - diffCOM.y);
    group_f.z = k*preFactor*direction.z + k2*(diff*direction.z - diffCOM.z);
    // calculate the force on each atom
    f.x = group_f.x * m * inv_group_mass;
    f.y = group_f.y * m * inv_group_mass;
    f.z = group_f.z * m * inv_group_mass;

    f_normal_x[SOAindex] += f.x;
    f_normal_y[SOAindex] += f.y;
    f_normal_z[SOAindex] += f.z;

    // energy for restraint along the direction
    energy = 0.5*k*preFactor*preFactor;
    // energy for transverse restraint
    energy += 0.5*k2*(diffCOM.x*diffCOM.x + diffCOM.y*diffCOM.y +
                      diffCOM.z*diffCOM.z - diff*diff);
    // unwrap coordinates before calculating the virial
    char3 t = transform[SOAindex];
    pos = lat.reverse_transform(pos, t);
    r_virial.xx = f.x * pos.x;
    r_virial.xy = f.x * pos.y;
    r_virial.xz = f.x * pos.z;
    r_virial.yx = f.y * pos.x;
    r_virial.yy = f.y * pos.y;
    r_virial.yz = f.y * pos.z;
    r_virial.zx = f.z * pos.x;
    r_virial.zy = f.z * pos.y;
    r_virial.zz = f.z * pos.z;
  typedef cub::BlockReduce<double, 128> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
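  // NOTE: the same shared TempStorage is reused for all nine tensor
  // components below; cub::BlockReduce requires a block-wide barrier
  // (__syncthreads()) between successive reductions that share storage.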
  r_virial.xx = BlockReduce(temp_storage).Sum(r_virial.xx);
  r_virial.xy = BlockReduce(temp_storage).Sum(r_virial.xy);
  r_virial.xz = BlockReduce(temp_storage).Sum(r_virial.xz);
  r_virial.yx = BlockReduce(temp_storage).Sum(r_virial.yx);
  r_virial.yy = BlockReduce(temp_storage).Sum(r_virial.yy);
  r_virial.yz = BlockReduce(temp_storage).Sum(r_virial.yz);
  r_virial.zx = BlockReduce(temp_storage).Sum(r_virial.zx);
  r_virial.zy = BlockReduce(temp_storage).Sum(r_virial.zy);
  r_virial.zz = BlockReduce(temp_storage).Sum(r_virial.zz);
  if (threadIdx.x == 0) {
    atomicAdd(&(d_virial->xx), r_virial.xx);
    atomicAdd(&(d_virial->xy), r_virial.xy);
    atomicAdd(&(d_virial->xz), r_virial.xz);
    atomicAdd(&(d_virial->yx), r_virial.yx);
    atomicAdd(&(d_virial->yy), r_virial.yy);
    atomicAdd(&(d_virial->yz), r_virial.yz);
    atomicAdd(&(d_virial->zx), r_virial.zx);
    atomicAdd(&(d_virial->zy), r_virial.zy);
    atomicAdd(&(d_virial->zz), r_virial.zz);

    unsigned int value = atomicInc(&tbcatomic[0], totaltb);
    isLastBlockDone = (value == (totaltb - 1));
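    // atomicInc returns the previous counter value, so the block that
    // observes totaltb - 1 is the last block to finish its reduction;
    // that block publishes the results to the host-mapped buffers and
    // resets the counter for the next invocation.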
  // Last block will set the host values
  if (threadIdx.x == 0) {
    h_extEnergy[0] = energy;
    h_extForce->x = group_f.x;
    h_extForce->y = group_f.y;
    h_extForce->z = group_f.z;

    h_extVirial->xx = d_virial->xx;
    h_extVirial->xy = d_virial->xy;
    h_extVirial->xz = d_virial->xz;
    h_extVirial->yx = d_virial->yx;
    h_extVirial->yy = d_virial->yy;
    h_extVirial->yz = d_virial->yz;
    h_extVirial->zx = d_virial->zx;
    h_extVirial->zy = d_virial->zy;
    h_extVirial->zz = d_virial->zz;
    // reset the device virial value
  { // compute isLastBlockDone in the non-energy steps
    if (threadIdx.x == 0) {
      unsigned int value = atomicInc(&tbcatomic[0], totaltb);
      isLastBlockDone = (value == (totaltb - 1));

    if (threadIdx.x == 0) {
      // reset the atomic counter
/*! Calculate SMD force, virial, and COM for a small atom group (numSMDAtoms <= 1024).
    A single thread block is launched to perform this operation.
    The current COM is calculated here and stored in h_curCM. */
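/* For reference: because the whole group fits in one thread block, this launch
   both reduces the mass-weighted positions to the group COM and applies the
   bias force, so no separate COM kernel launch is needed for groups of this
   size (in the multi-GPU case the cross-device COM is read from d_curCM). */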
template<bool T_DOENERGY, bool T_MGPUON>
__global__ void computeSMDForceKernel(
  const int numSMDAtoms,
  const double inv_group_mass,
  const double velocity,
  const double3 direction,
  const int currentTime,
  const double3 origCM,
  const float* __restrict mass,
  const double* __restrict pos_x,
  const double* __restrict pos_y,
  const double* __restrict pos_z,
  const char3* __restrict transform,
  double* __restrict f_normal_x,
  double* __restrict f_normal_y,
  double* __restrict f_normal_z,
  const int* __restrict smdAtomsSOAIndex,
  double3* __restrict h_curCM,
  double3* __restrict d_curCM,
  double3** __restrict d_peerCOM,
  double* __restrict h_extEnergy,
  double3* __restrict h_extForce,
  cudaTensor* __restrict h_extVirial,
  unsigned int* __restrict tbcatomic)
  __shared__ double3 group_f;
  __shared__ double energy;
  int tid = threadIdx.x + blockIdx.x * blockDim.x;
  int totaltb = gridDim.x;
  bool isLastBlockDone = false;
  double3 cm = {0, 0, 0};
  double3 pos = {0, 0, 0};
  double3 f = {0, 0, 0};
  r_virial.xx = 0.0; r_virial.xy = 0.0; r_virial.xz = 0.0;
  r_virial.yx = 0.0; r_virial.yy = 0.0; r_virial.yz = 0.0;
  r_virial.zx = 0.0; r_virial.zy = 0.0; r_virial.zz = 0.0;

  // in the mGpuOn case the COM must be calculated across devices and passed in
  if (tid < numSMDAtoms) {
    // First, recompute the center of mass: each thread loads its atom's
    // mass-weighted position; thread zero finalizes the COM after the
    // block-wide reduction below.
    SOAindex = smdAtomsSOAIndex[tid];
    m = mass[SOAindex]; // cast from float to double here
    pos.x = pos_x[SOAindex];
    pos.y = pos_y[SOAindex];
    pos.z = pos_z[SOAindex];

    // unwrap the coordinate to calculate the COM
    char3 t = transform[SOAindex];
    pos = lat.reverse_transform(pos, t);

  // now reduce the values and add them to thread zero
  typedef cub::BlockReduce<double, 1024> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  cm.x = BlockReduce(temp_storage).Sum(cm.x);
  cm.y = BlockReduce(temp_storage).Sum(cm.y);
  cm.z = BlockReduce(temp_storage).Sum(cm.z);

  // Calculate group force and acceleration
  if (threadIdx.x == 0) {
    // multi-GPU case: read the cross-device COM sum from d_curCM
      cm.x = d_curCM->x * inv_group_mass;
      cm.y = d_curCM->y * inv_group_mass;
      cm.z = d_curCM->z * inv_group_mass;
    // single-GPU case: finish the block-reduced sum from above
      cm.x *= inv_group_mass; // calculates the current center of mass
      cm.y *= inv_group_mass; // calculates the current center of mass
      cm.z *= inv_group_mass; // calculates the current center of mass
    // calculate the distance difference along direction
    diffCOM.x = cm.x - origCM.x;
    diffCOM.y = cm.y - origCM.y;
    diffCOM.z = cm.z - origCM.z;
    double diff = diffCOM.x*direction.x + diffCOM.y*direction.y +
                  diffCOM.z*direction.z;

    // Ok so we've calculated the new center of mass, now we can calculate the bias
    double preFactor = (velocity*currentTime - diff);
    group_f.x = k*preFactor*direction.x + k2*(diff*direction.x - diffCOM.x);
    group_f.y = k*preFactor*direction.y + k2*(diff*direction.y - diffCOM.y);
    group_f.z = k*preFactor*direction.z + k2*(diff*direction.z - diffCOM.z);

    // energy for restraint along the direction
    energy = 0.5*k*preFactor*preFactor;
    // energy for transverse restraint
    energy += 0.5*k2*(diffCOM.x*diffCOM.x + diffCOM.y*diffCOM.y +
                      diffCOM.z*diffCOM.z - diff*diff);
  if (tid < numSMDAtoms) {
    // calculate the force on each atom
    f.x = group_f.x * m * inv_group_mass;
    f.y = group_f.y * m * inv_group_mass;
    f.z = group_f.z * m * inv_group_mass;

    f_normal_x[SOAindex] += f.x;
    f_normal_y[SOAindex] += f.y;
    f_normal_z[SOAindex] += f.z;

    r_virial.xx = f.x * pos.x;
    r_virial.xy = f.x * pos.y;
    r_virial.xz = f.x * pos.z;
    r_virial.yx = f.y * pos.x;
    r_virial.yy = f.y * pos.y;
    r_virial.yz = f.y * pos.z;
    r_virial.zx = f.z * pos.x;
    r_virial.zy = f.z * pos.y;
    r_virial.zz = f.z * pos.z;

  if (threadIdx.x == 0) {
    unsigned int value = atomicInc(&tbcatomic[0], totaltb);
    isLastBlockDone = (value == (totaltb - 1));
    r_virial.xx = BlockReduce(temp_storage).Sum(r_virial.xx);
    r_virial.xy = BlockReduce(temp_storage).Sum(r_virial.xy);
    r_virial.xz = BlockReduce(temp_storage).Sum(r_virial.xz);
    r_virial.yx = BlockReduce(temp_storage).Sum(r_virial.yx);
    r_virial.yy = BlockReduce(temp_storage).Sum(r_virial.yy);
    r_virial.yz = BlockReduce(temp_storage).Sum(r_virial.yz);
    r_virial.zx = BlockReduce(temp_storage).Sum(r_virial.zx);
    r_virial.zy = BlockReduce(temp_storage).Sum(r_virial.zy);
    r_virial.zz = BlockReduce(temp_storage).Sum(r_virial.zz);
    if (threadIdx.x == 0) {
      // thread zero updates the values
      h_curCM->x = cm.x; // update current center of mass
      h_curCM->y = cm.y; // update current center of mass
      h_curCM->z = cm.z; // update current center of mass
      h_extEnergy[0] = energy;   // bias energy
      h_extForce->x = group_f.x; // bias force
      h_extForce->y = group_f.y;
      h_extForce->z = group_f.z;

      h_extVirial->xx = r_virial.xx;
      h_extVirial->xy = r_virial.xy;
      h_extVirial->xz = r_virial.xz;
      h_extVirial->yx = r_virial.yx;
      h_extVirial->yy = r_virial.yy;
      h_extVirial->yz = r_virial.yz;
      h_extVirial->zx = r_virial.zx;
      h_extVirial->zy = r_virial.zy;
      h_extVirial->zz = r_virial.zz;
  // last block cleans up
  if (threadIdx.x == 0) {
    // zero out for the next iteration
    // reset the atomic counter
/*! Compute the SMD force and virial on a group of atoms */
void computeSMDForce(
  const double inv_group_mass,
  const double spring_constant,
  const double transverse_spring_constant,
  const double velocity,
  const double3 direction,
  const int currentTime,
  const double3 origCM,
  const double* d_pos_x,
  const double* d_pos_y,
  const double* d_pos_z,
  const char3* d_transform,
  double* d_f_normal_x,
  double* d_f_normal_y,
  double* d_f_normal_z,
  const int numSMDAtoms,
  const int* d_smdAtomsSOAIndex,
  cudaTensor* d_extVirial,
  cudaTensor* h_extVirial,
  unsigned int* d_tbcatomic,
  const int numDevices,
  const int deviceIndex,
  const int blocks = (numSMDAtoms > 1024) ? 128 : 1024;
  const int grid = (numSMDAtoms > 1024) ? (numSMDAtoms + blocks - 1) / blocks : 1;
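  // Note: `blocks` is the thread-block size and `grid` the number of blocks.
  // Large groups are spread over many 128-thread blocks (COM computed first,
  // then computeSMDForceWithCOMKernel), while a small group is handled by a
  // single 1024-thread block in computeSMDForceKernel.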
#define CALL_WITH_COM(DOENERGY, MGPUON) \
  computeSMDForceWithCOMKernel<DOENERGY, MGPUON> \
    <<<grid, blocks, 0, stream>>> \
    (numSMDAtoms, lat, inv_group_mass, spring_constant, \
     transverse_spring_constant, velocity, direction, currentTime, \
     origCM, d_mass, d_pos_x, d_pos_y, d_pos_z, d_transform, \
     d_f_normal_x, d_f_normal_y, d_f_normal_z, d_smdAtomsSOAIndex, \
     d_extVirial, h_curCM, d_curCM, d_peerCOM, h_extEnergy, h_extForce, \
     h_extVirial, d_tbcatomic);
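// The CALL_WITH_COM macro above and the CALL macro below simply wrap the
// kernel launches so the runtime doEnergy / mGpuOn flags can be mapped onto
// the compile-time template specializations without repeating the long
// argument list.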
#define CALL(DOENERGY, MGPUON) \
  computeSMDForceKernel<DOENERGY, MGPUON> \
    <<<grid, blocks, 0, stream>>> \
    (numSMDAtoms, lat, inv_group_mass, spring_constant, \
     transverse_spring_constant, velocity, direction, currentTime, \
     origCM, d_mass, d_pos_x, d_pos_y, d_pos_z, d_transform, \
     d_f_normal_x, d_f_normal_y, d_f_normal_z, d_smdAtomsSOAIndex, \
     h_curCM, d_curCM, d_peerCOM, h_extEnergy, h_extForce, h_extVirial, \
  if (numSMDAtoms > 1024) {
    { // first calculate the COM for the SMD group and store it in h_curCM
      computeCOMKernel<128><<<grid, blocks, 0, stream>>>(
    { // sum up the COMs across devices to this device
      computeDistCOMKernelMgpu<<<grid, blocks, 0, stream>>>(d_peerCOM,
    if (doEnergy && mGpuOn) CALL_WITH_COM(true, true);
    if (doEnergy && !mGpuOn) CALL_WITH_COM(true, false);
    if (!doEnergy && mGpuOn) CALL_WITH_COM(false, true);
    if (!doEnergy && !mGpuOn) CALL_WITH_COM(false, false);
    { // sum up the COMs across devices to this device
      computeDistCOMKernelMgpu<<<grid, blocks, 0, stream>>>(d_peerCOM,
    if (doEnergy && mGpuOn) CALL(true, true);
    if (doEnergy && !mGpuOn) CALL(true, false);
    if (!doEnergy && mGpuOn) CALL(false, true);
    if (!doEnergy && !mGpuOn) CALL(false, false);
void initPeerCOMmgpu(
  const int numDevices,
  const int deviceIndex,
  double3** d_peerPool,
  const int blocks = numDevices;
  initPeerCOMKernel<<<grid, blocks, 0, stream>>>(numDevices,
/* Called in an earlier phase to handle the multi-device COM */
void computeCOMSMDMgpu(
  const int numSMDAtoms,
  const double* d_pos_x,
  const double* d_pos_y,
  const double* d_pos_z,
  const char3* d_transform,
  const int* d_smdAtomsSOAIndex,
  double3** d_peer_curCM,
  unsigned int* d_tbcatomic,
  const int numDevices,
  const int deviceIndex,
  // block it up if large, otherwise all in one go
  const int blocks = (numSMDAtoms > 1024) ? 128 : 1024;
  const int grid = (numSMDAtoms > 1024) ? (numSMDAtoms + blocks - 1) / blocks : 1;
  // initialize the device memory to zero here
  cudaCheck(cudaMemset(d_peerCOM, 0, sizeof(double3)));
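  // the memset clears this device's COM accumulator (one double3) before the
  // per-device partial reduction below; the per-device partials are later
  // combined across devices (see computeDistCOMKernelMgpu in computeSMDForce).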
  if (numSMDAtoms > 1024)
    computeCOMKernelMgpu<128><<<grid, blocks, 0, stream>>>(numSMDAtoms,
      d_pos_x, d_pos_y, d_pos_z,
    computeCOMKernelMgpu<1024><<<grid, blocks, 0, stream>>>(numSMDAtoms,
      d_pos_x, d_pos_y, d_pos_z,

#endif // NODEGROUP_FORCE_REGISTER