2 #if __CUDACC_VER_MAJOR__ >= 11
5 #include <namd_cub/cub.cuh>
10 #include <hip/hip_runtime.h>
11 #include <hipcub/hipcub.hpp>
15 #include "ComputeSMDCUDAKernel.h"
16 #include "ComputeCOMCudaKernel.h"
17 #include "HipDefines.h"
19 #ifdef NODEGROUP_FORCE_REGISTER
22 /*! Calculate SMD force and virial for a large atom group (numSMDAtoms > 1024).
23 Multiple thread blocks are launched to do this operation (one SMD atom per thread).
24 The current COM (curCOM) must already be calculated and passed to this function. */
// Launch requirements and data flow, as far as the visible lines show:
//  - cub::BlockReduce below is instantiated for 128 threads, so blockDim.x must be
//    128 (the host wrapper computeSMDForce launches this kernel with blocks = 128).
//  - Every thread with tid < numSMDAtoms adds its mass-weighted share of the group
//    force to f_normal_{x,y,z} and contributes the outer product f * pos^T (pos
//    unwrapped) to the virial.
//  - Thread 0 of each block atomically adds its block's virial sum into *d_virial
//    and increments tbcatomic[0]; thread 0 of the last block to arrive copies the
//    energy, group force and accumulated virial into the h_* outputs. Per the
//    trailing comments, *d_virial and the counter are then reset (code not visible).
// NOTE(review): several source lines of this kernel are not visible in this excerpt
// (e.g. the declarations of k, k2, origCM, lat, SOAindex, energy, diffCOM and
// r_virial); the comments below describe only the visible lines.
25 template<bool T_DOENERGY>
26 __global__ void computeSMDForceWithCOMKernel(
27 const int numSMDAtoms,
29 const double inv_group_mass,
32 const double velocity,
33 const double3 direction,
34 const int currentTime,
36 const float * __restrict mass,
37 const double* __restrict pos_x,
38 const double* __restrict pos_y,
39 const double* __restrict pos_z,
40 const char3* __restrict transform,
41 double* __restrict f_normal_x,
42 double* __restrict f_normal_y,
43 double* __restrict f_normal_z,
44 const int* __restrict smdAtomsSOAIndex,
45 cudaTensor* __restrict d_virial,
46 const double3* __restrict curCOM,
47 double* __restrict h_extEnergy,
48 double3* __restrict h_extForce,
49 cudaTensor* __restrict h_extVirial,
50 unsigned int* __restrict tbcatomic)
52 int tid = threadIdx.x + blockIdx.x * blockDim.x;
53 int totaltb = gridDim.x;
// Per-thread flag: only thread 0 ever sets it (after atomicInc below), and the
// host-output writes below are also restricted to threadIdx.x == 0, so a
// __shared__ flag is not required.
54 bool isLastBlockDone = 0;
56 double3 group_f = {0, 0, 0};
58 double3 pos = {0, 0, 0};
59 double3 f = {0, 0, 0};
// Zero the per-thread virial so threads with tid >= numSMDAtoms contribute
// nothing to the block reductions below.
61 r_virial.xx = 0.0; r_virial.xy = 0.0; r_virial.xz = 0.0;
62 r_virial.yx = 0.0; r_virial.yy = 0.0; r_virial.yz = 0.0;
63 r_virial.zx = 0.0; r_virial.zy = 0.0; r_virial.zz = 0.0;
65 if(tid < numSMDAtoms){
66 // Each thread with tid < numSMDAtoms handles exactly one SMD atom.
67 // (The COM is not recomputed here; it is read from curCOM.)
68 SOAindex = smdAtomsSOAIndex[tid];
70 // uncoalesced memory access: too bad!
71 double m = mass[SOAindex]; // Cast from float to double here
72 pos.x = pos_x[SOAindex];
73 pos.y = pos_y[SOAindex];
74 pos.z = pos_z[SOAindex];
76 // displacement of the COM from origCM, and its projection (diff) onto direction
78 diffCOM.x = curCOM->x - origCM.x;
79 diffCOM.y = curCOM->y - origCM.y;
80 diffCOM.z = curCOM->z - origCM.z;
81 double diff = diffCOM.x*direction.x + diffCOM.y*direction.y +
82 diffCOM.z*direction.z;
84 // With the COM displacement known, compute the bias (group) force:
// the k term pulls diff toward velocity*currentTime; the k2 term opposes the part
// of diffCOM perpendicular to direction (assuming direction is a unit vector —
// confirm with the caller). group_f depends only on uniform inputs, so every
// thread computes the same value.
85 double preFactor = (velocity*currentTime - diff);
86 group_f.x = k*preFactor*direction.x + k2*(diff*direction.x - diffCOM.x);
87 group_f.y = k*preFactor*direction.y + k2*(diff*direction.y - diffCOM.y);
88 group_f.z = k*preFactor*direction.z + k2*(diff*direction.z - diffCOM.z);
90 // this atom's mass-weighted share of the group force (inv_group_mass is presumably 1/total mass)
91 f.x = group_f.x * m * inv_group_mass;
92 f.y = group_f.y * m * inv_group_mass;
93 f.z = group_f.z * m * inv_group_mass;
// Plain (non-atomic) +=: assumes each SOA index appears at most once in
// smdAtomsSOAIndex — confirm, otherwise concurrent threads would race here.
96 f_normal_x[SOAindex] += f.x ;
97 f_normal_y[SOAindex] += f.y ;
98 f_normal_z[SOAindex] += f.z ;
100 // energy for restraint along the direction: 0.5*k*(velocity*currentTime - diff)^2
101 energy = 0.5*k*preFactor*preFactor;
102 // energy for transverse restraint: 0.5*k2*(|diffCOM|^2 - diff^2)
103 energy += 0.5*k2*(diffCOM.x*diffCOM.x + diffCOM.y*diffCOM.y +
104 diffCOM.z*diffCOM.z - diff*diff);
105 // unwrap coordinates before calculating the virial
106 char3 t = transform[SOAindex];
107 pos = lat.reverse_transform(pos, t);
108 r_virial.xx = f.x * pos.x;
109 r_virial.xy = f.x * pos.y;
110 r_virial.xz = f.x * pos.z;
111 r_virial.yx = f.y * pos.x;
112 r_virial.yy = f.y * pos.y;
113 r_virial.yz = f.y * pos.z;
114 r_virial.zx = f.z * pos.x;
115 r_virial.zy = f.z * pos.y;
116 r_virial.zz = f.z * pos.z;
// Block-wide sums of the nine virial components (128 threads per block).
// The summed value is only valid in thread 0, which is why only thread 0 uses it.
// NOTE(review): cub requires a __syncthreads() before temp_storage is reused by the
// next BlockReduce call; the lines between these calls are not visible in this
// excerpt — confirm the barriers are present.
122 typedef cub::BlockReduce<double, 128> BlockReduce;
123 __shared__ typename BlockReduce::TempStorage temp_storage;
125 r_virial.xx = BlockReduce(temp_storage).Sum(r_virial.xx);
127 r_virial.xy = BlockReduce(temp_storage).Sum(r_virial.xy);
129 r_virial.xz = BlockReduce(temp_storage).Sum(r_virial.xz);
132 r_virial.yx = BlockReduce(temp_storage).Sum(r_virial.yx);
134 r_virial.yy = BlockReduce(temp_storage).Sum(r_virial.yy);
136 r_virial.yz = BlockReduce(temp_storage).Sum(r_virial.yz);
139 r_virial.zx = BlockReduce(temp_storage).Sum(r_virial.zx);
141 r_virial.zy = BlockReduce(temp_storage).Sum(r_virial.zy);
143 r_virial.zz = BlockReduce(temp_storage).Sum(r_virial.zz);
// Thread 0 of every block accumulates its block's partial virial into the
// device-wide tensor, then bumps the block-completion counter.
// NOTE(review): the usual "last block" pattern issues __threadfence() before the
// counter increment so that every block's d_virial additions are visible to the
// last block when it reads d_virial below; any line(s) before atomicInc are not
// visible here — confirm the fence exists.
146 if(threadIdx.x == 0){
147 atomicAdd(&(d_virial->xx), r_virial.xx);
148 atomicAdd(&(d_virial->xy), r_virial.xy);
149 atomicAdd(&(d_virial->xz), r_virial.xz);
151 atomicAdd(&(d_virial->yx), r_virial.yx);
152 atomicAdd(&(d_virial->yy), r_virial.yy);
153 atomicAdd(&(d_virial->yz), r_virial.yz);
155 atomicAdd(&(d_virial->zx), r_virial.zx);
156 atomicAdd(&(d_virial->zy), r_virial.zy);
157 atomicAdd(&(d_virial->zz), r_virial.zz);
// atomicInc returns the previous counter value (wrapping to 0 once it reaches
// totaltb); the block that observes totaltb-1 is the last one to arrive.
160 unsigned int value = atomicInc(&tbcatomic[0], totaltb);
161 isLastBlockDone = (value == (totaltb -1));
165 // Last block will set the host values
167 if(threadIdx.x == 0){
// energy and group_f are this thread's own values. Thread 0 of any block has
// tid < numSMDAtoms because grid = ceil(numSMDAtoms / blockDim.x) in the host
// wrapper, so both were computed above.
168 h_extEnergy[0] = energy;
169 h_extForce->x = group_f.x;
170 h_extForce->y = group_f.y;
171 h_extForce->z = group_f.z;
// d_virial now holds the sum over all blocks.
173 h_extVirial->xx = d_virial->xx;
174 h_extVirial->xy = d_virial->xy;
175 h_extVirial->xz = d_virial->xz;
176 h_extVirial->yx = d_virial->yx;
177 h_extVirial->yy = d_virial->yy;
178 h_extVirial->yz = d_virial->yz;
179 h_extVirial->zx = d_virial->zx;
180 h_extVirial->zy = d_virial->zy;
181 h_extVirial->zz = d_virial->zz;
183 //reset the device virial value
195 //resets atomic counter
204 /*! Calculate SMD force, virial, and COM for a small atom group (numSMDAtoms <= 1024).
205 A single thread block is launched to do this operation (one SMD atom per thread).
206 The current COM will be calculated and stored in h_curCM. */
// Launch requirements, as far as the visible lines show:
//  - cub::BlockReduce below is instantiated for 1024 threads, so blockDim.x must be
//    1024; the host wrapper computeSMDForce launches this kernel with blocks = 1024
//    and describes the launch as a single thread block.
//  - Reduced sums (cm, r_virial) are only valid in thread 0, so only thread 0
//    finishes the COM/force/energy and writes the h_* outputs.
// NOTE(review): several source lines of this kernel are not visible in this excerpt
// (e.g. the declarations of k, k2, lat, m, SOAindex, diffCOM, r_virial, and the
// statements that accumulate cm); the comments below describe only the visible lines.
207 template<bool T_DOENERGY>
208 __global__ void computeSMDForceKernel(
209 const int numSMDAtoms,
211 const double inv_group_mass,
214 const double velocity,
215 const double3 direction,
216 const int currentTime,
217 const double3 origCM,
218 const float * __restrict mass,
219 const double* __restrict pos_x,
220 const double* __restrict pos_y,
221 const double* __restrict pos_z,
222 const char3* __restrict transform,
223 double* __restrict f_normal_x,
224 double* __restrict f_normal_y,
225 double* __restrict f_normal_z,
226 const int* __restrict smdAtomsSOAIndex,
227 double3* __restrict h_curCM,
228 double* __restrict h_extEnergy,
229 double3* __restrict h_extForce,
230 cudaTensor* __restrict h_extVirial)
// group_f is written by thread 0 and then read by every thread (per-atom force),
// so it lives in shared memory; energy is likewise computed by thread 0.
232 __shared__ double3 group_f;
233 __shared__ double energy;
234 int tid = threadIdx.x + blockIdx.x * blockDim.x;
236 double3 cm = {0, 0, 0};
237 double3 pos = {0, 0, 0};
238 double3 f = {0, 0, 0};
// Zero the per-thread virial so threads with tid >= numSMDAtoms contribute
// nothing to the block reductions below.
240 r_virial.xx = 0.0; r_virial.xy = 0.0; r_virial.xz = 0.0;
241 r_virial.yx = 0.0; r_virial.yy = 0.0; r_virial.yz = 0.0;
242 r_virial.zx = 0.0; r_virial.zy = 0.0; r_virial.zz = 0.0;
244 if(tid < numSMDAtoms){
245 // Each thread with tid < numSMDAtoms loads exactly one SMD atom; its unwrapped
246 // position (presumably mass-weighted, on lines not visible) feeds the COM sum.
247 SOAindex = smdAtomsSOAIndex[tid];
249 // uncoalesced memory access: too bad!
250 m = mass[SOAindex]; // Cast from float to double here
251 pos.x = pos_x[SOAindex];
252 pos.y = pos_y[SOAindex];
253 pos.z = pos_z[SOAindex];
255 // unwrap the coordinate to calculate COM
256 char3 t = transform[SOAindex];
257 pos = lat.reverse_transform(pos, t);
264 // now reduce the values; the block-wide sums are valid only in thread zero
// NOTE(review): cub requires a __syncthreads() before temp_storage is reused by the
// next BlockReduce call; the lines between these calls are not visible in this
// excerpt — confirm the barriers are present.
265 typedef cub::BlockReduce<double, 1024> BlockReduce;
266 __shared__ typename BlockReduce::TempStorage temp_storage;
268 cm.x = BlockReduce(temp_storage).Sum(cm.x);
270 cm.y = BlockReduce(temp_storage).Sum(cm.y);
272 cm.z = BlockReduce(temp_storage).Sum(cm.z);
275 // Thread 0 finishes the COM and computes the group (bias) force and bias energy
276 if(threadIdx.x == 0){
277 cm.x *= inv_group_mass; // calculates the current center of mass
278 cm.y *= inv_group_mass; // calculates the current center of mass
279 cm.z *= inv_group_mass; // calculates the current center of mass
281 // displacement of the COM from origCM, and its projection (diff) onto direction
283 diffCOM.x = cm.x - origCM.x;
284 diffCOM.y = cm.y - origCM.y;
285 diffCOM.z = cm.z - origCM.z;
286 double diff = diffCOM.x*direction.x + diffCOM.y*direction.y +
287 diffCOM.z*direction.z;
289 // With the new center of mass known, compute the bias (group) force:
// the k term pulls diff toward velocity*currentTime; the k2 term opposes the part
// of diffCOM perpendicular to direction (assuming direction is a unit vector —
// confirm with the caller).
290 double preFactor = (velocity*currentTime - diff);
291 group_f.x = k*preFactor*direction.x + k2*(diff*direction.x - diffCOM.x);
292 group_f.y = k*preFactor*direction.y + k2*(diff*direction.y - diffCOM.y);
293 group_f.z = k*preFactor*direction.z + k2*(diff*direction.z - diffCOM.z);
295 // energy for restraint along the direction: 0.5*k*(velocity*currentTime - diff)^2
296 energy = 0.5*k*preFactor*preFactor;
297 // energy for transverse restraint: 0.5*k2*(|diffCOM|^2 - diff^2)
298 energy += 0.5*k2*(diffCOM.x*diffCOM.x + diffCOM.y*diffCOM.y +
299 diffCOM.z*diffCOM.z - diff*diff);
// NOTE(review): group_f was written by thread 0 above and is read by every thread
// below; a __syncthreads() is required in between. The lines here are not visible
// in this excerpt — confirm the barrier exists.
304 if(tid < numSMDAtoms){
305 // this atom's mass-weighted share of the group force (inv_group_mass is presumably 1/total mass)
306 f.x = group_f.x * m * inv_group_mass;
307 f.y = group_f.y * m * inv_group_mass;
308 f.z = group_f.z * m * inv_group_mass;
// Plain (non-atomic) +=: assumes each SOA index appears at most once in
// smdAtomsSOAIndex — confirm.
311 f_normal_x[SOAindex] += f.x ;
312 f_normal_y[SOAindex] += f.y ;
313 f_normal_z[SOAindex] += f.z ;
// pos was already unwrapped above (before the COM sum), so it is used directly here.
315 r_virial.xx = f.x * pos.x;
316 r_virial.xy = f.x * pos.y;
317 r_virial.xz = f.x * pos.z;
318 r_virial.yx = f.y * pos.x;
319 r_virial.yy = f.y * pos.y;
320 r_virial.yz = f.y * pos.z;
321 r_virial.zx = f.z * pos.x;
322 r_virial.zy = f.z * pos.y;
323 r_virial.zz = f.z * pos.z;
// Block-wide sums of the nine virial components (valid only in thread 0).
327 r_virial.xx = BlockReduce(temp_storage).Sum(r_virial.xx);
329 r_virial.xy = BlockReduce(temp_storage).Sum(r_virial.xy);
331 r_virial.xz = BlockReduce(temp_storage).Sum(r_virial.xz);
334 r_virial.yx = BlockReduce(temp_storage).Sum(r_virial.yx);
336 r_virial.yy = BlockReduce(temp_storage).Sum(r_virial.yy);
338 r_virial.yz = BlockReduce(temp_storage).Sum(r_virial.yz);
341 r_virial.zx = BlockReduce(temp_storage).Sum(r_virial.zx);
343 r_virial.zy = BlockReduce(temp_storage).Sum(r_virial.zy);
345 r_virial.zz = BlockReduce(temp_storage).Sum(r_virial.zz);
348 if(threadIdx.x == 0){
349 // thread zero holds the reduced cm and r_virial, so it alone writes the outputs
350 h_curCM->x = cm.x; // update current center of mass
351 h_curCM->y = cm.y; // update current center of mass
352 h_curCM->z = cm.z; // update current center of mass
354 h_extEnergy[0] = energy; // bias energy
355 h_extForce->x = group_f.x; // bias force
356 h_extForce->y = group_f.y;
357 h_extForce->z = group_f.z;
359 h_extVirial->xx = r_virial.xx;
360 h_extVirial->xy = r_virial.xy;
361 h_extVirial->xz = r_virial.xz;
362 h_extVirial->yx = r_virial.yx;
363 h_extVirial->yy = r_virial.yy;
364 h_extVirial->yz = r_virial.yz;
365 h_extVirial->zx = r_virial.zx;
366 h_extVirial->zy = r_virial.zy;
367 h_extVirial->zz = r_virial.zz;
372 /*! Compute SMD force and virial on a group of atoms.
Dispatches on the group size:
 - numSMDAtoms > 1024: computeCOMKernel<128> first computes the group COM, then
   computeSMDForceWithCOMKernel applies the force using 128-thread blocks and
   grid = ceil(numSMDAtoms / 128) blocks; that path accumulates the virial in
   d_extVirial and uses d_tbcatomic as a block-completion counter.
 - otherwise: computeSMDForceKernel computes COM, force and virial in a single
   launch with 1024 threads per block (described below as a single thread block).
All launches are enqueued on `stream`.
The 128 / 1024 thread counts must stay in sync with the cub::BlockReduce sizes
hard-coded in the kernels (and with the 1024 threshold below).
NOTE(review): the lines that choose between the <true>/<false> (T_DOENERGY)
instantiations, and any launch error checking (cudaGetLastError), are not visible
in this excerpt — confirm they exist. */
373 void computeSMDForce(
375 const double inv_group_mass,
376 const double spring_constant,
377 const double transverse_spring_constant,
378 const double velocity,
379 const double3 direction,
381 const int currentTime,
382 const double3 origCM,
384 const double* d_pos_x,
385 const double* d_pos_y,
386 const double* d_pos_z,
387 const char3* d_transform,
388 double * d_f_normal_x,
389 double * d_f_normal_y,
390 double * d_f_normal_z,
391 const int numSMDAtoms,
392 const int* d_smdAtomsSOAIndex,
395 cudaTensor* d_extVirial,
398 cudaTensor* h_extVirial,
399 unsigned int* d_tbcatomic,
// Large groups: more atoms than one 1024-thread block can hold.
402 if (numSMDAtoms > 1024) {
// `blocks` is the number of threads per block (second launch argument); it must
// equal the 128 used by computeCOMKernel<128> and by BlockReduce in the force kernel.
403 const int blocks = 128;
404 const int grid = (numSMDAtoms + blocks - 1) / blocks;
405 //first calculate the COM for SMD group and store it in h_curCM
406 computeCOMKernel<128><<<grid, blocks, 0, stream>>>(
421 computeSMDForceWithCOMKernel<true><<<grid, blocks, 0, stream>>>(
426 transverse_spring_constant,
447 computeSMDForceWithCOMKernel<false><<<grid, blocks, 0, stream>>>(
452 transverse_spring_constant,
474 // SMD groups usually contain only a few atoms, so a single block of 1024 threads
475 // is enough here (computeSMDForceKernel handles numSMDAtoms <= 1024)
// 1024 must match cub::BlockReduce<double, 1024> in computeSMDForceKernel.
476 const int blocks = 1024;
480 computeSMDForceKernel<true><<<grid, blocks, 0, stream>>>(
485 transverse_spring_constant,
504 computeSMDForceKernel<false><<<grid, blocks, 0, stream>>>(
509 transverse_spring_constant,
531 #endif // NODEGROUP_FORCE_REGISTER