// NAMD
// ComputeGlobalMasterVirialCUDAKernel.cu
1 #ifdef NAMD_CUDA
2 #if __CUDACC_VER_MAJOR__ >= 11
3 #include <cub/cub.cuh>
4 #else
5 #include <namd_cub/cub.cuh>
6 #endif
7 #else // NAMD_HIP
8 #include <hip/hip_runtime.h>
9 #include <hipcub/hipcub.hpp>
10 #define cub hipcub
11 #endif // end NAMD_CUDA vs. NAMD_HIP
12 
13 #include "HipDefines.h"
14 
15 #include "ComputeGlobalMasterVirialCUDAKernel.h"
16 
17 #ifdef NODEGROUP_FORCE_REGISTER
18 
/*! \brief Reduce global-master (external) forces into a net force and a virial tensor.
 *
 *  One thread per atom. Each thread forms the outer product f_i * r_i — using the
 *  position unwrapped through lat.reverse_transform with the stored transform
 *  flags — plus the raw force, block-reduces the ten scalars with
 *  cub::BlockReduce, and thread 0 of each block atomically accumulates the block
 *  totals into the device-wide accumulators d_virial / d_extForce.
 *
 *  The last block to finish (detected via atomicInc on tbcatomic[0]) copies the
 *  device totals into h_extVirial / h_extForce, then zeroes the device
 *  accumulators and the counter so the buffers are ready for the next call.
 *
 *  NOTE(review): h_extVirial / h_extForce are written directly from device code,
 *  so they are presumably host-mapped (zero-copy) allocations — confirm at the
 *  allocation site.  atomicAdd(double*) requires SM60+.
 *
 *  \tparam BLOCK_SIZE must equal the launch blockDim.x (cub::BlockReduce contract).
 *  localRecords is currently unused by this kernel.
 */
template <int BLOCK_SIZE>
__global__ void computeGlobalMasterVirialKernel(
  const int numAtoms,
  CudaLocalRecord* localRecords,
  const double* __restrict d_pos_x,
  const double* __restrict d_pos_y,
  const double* __restrict d_pos_z,
  const char3* __restrict d_transform,
  double* __restrict f_global_x,
  double* __restrict f_global_y,
  double* __restrict f_global_z,
  double3* __restrict d_extForce,
  double3* __restrict h_extForce,
  cudaTensor* __restrict d_virial,
  cudaTensor* __restrict h_extVirial,
  const Lattice lat,
  unsigned int* __restrict tbcatomic
)
{
  // Per-thread partials; threads past numAtoms keep these zeros so they
  // contribute nothing to the block reductions below.
  double3 r_netForce = {0, 0, 0};
  cudaTensor r_virial;
  r_virial.xx = 0.0; r_virial.xy = 0.0; r_virial.xz = 0.0;
  r_virial.yx = 0.0; r_virial.yy = 0.0; r_virial.yz = 0.0;
  r_virial.zx = 0.0; r_virial.zy = 0.0; r_virial.zz = 0.0;
  int totaltb = gridDim.x;                      // total thread blocks in this launch
  int i = threadIdx.x + blockIdx.x*blockDim.x;  // global atom index
  __shared__ bool isLastBlockDone;              // set by thread 0, read by whole block

  if(threadIdx.x == 0){
    isLastBlockDone = 0;
  }

  // Make isLastBlockDone's initialization visible to all threads in the block.
  __syncthreads();

  if (i < numAtoms) {
    double3 pos, pos_i;
    pos.x = d_pos_x[i];
    pos.y = d_pos_y[i];
    pos.z = d_pos_z[i];
    // Unwrap the wrapped coordinate back to its home image so the virial is
    // computed with consistent (untranslated) positions.
    const char3 t = d_transform[i];
    pos_i = lat.reverse_transform(pos, t);
    // Virial contribution of this atom: outer product force x position.
    r_virial.xx = f_global_x[i] * pos_i.x;
    r_virial.xy = f_global_x[i] * pos_i.y;
    r_virial.xz = f_global_x[i] * pos_i.z;
    r_virial.yx = f_global_y[i] * pos_i.x;
    r_virial.yy = f_global_y[i] * pos_i.y;
    r_virial.yz = f_global_y[i] * pos_i.z;
    r_virial.zx = f_global_z[i] * pos_i.x;
    r_virial.zy = f_global_z[i] * pos_i.y;
    r_virial.zz = f_global_z[i] * pos_i.z;
    // Net external force contribution of this atom.
    r_netForce.x = f_global_x[i];
    r_netForce.y = f_global_y[i];
    r_netForce.z = f_global_z[i];
  }
  __syncthreads();

  typedef cub::BlockReduce<double, BLOCK_SIZE> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  // Ten scalar block reductions sharing one TempStorage; the __syncthreads()
  // between calls is required before temp_storage may be reused.
  r_netForce.x = BlockReduce(temp_storage).Sum(r_netForce.x);
  __syncthreads();
  r_netForce.y = BlockReduce(temp_storage).Sum(r_netForce.y);
  __syncthreads();
  r_netForce.z = BlockReduce(temp_storage).Sum(r_netForce.z);
  __syncthreads();

  r_virial.xx = BlockReduce(temp_storage).Sum(r_virial.xx);
  __syncthreads();
  r_virial.xy = BlockReduce(temp_storage).Sum(r_virial.xy);
  __syncthreads();
  r_virial.xz = BlockReduce(temp_storage).Sum(r_virial.xz);
  __syncthreads();

  r_virial.yx = BlockReduce(temp_storage).Sum(r_virial.yx);
  __syncthreads();
  r_virial.yy = BlockReduce(temp_storage).Sum(r_virial.yy);
  __syncthreads();
  r_virial.yz = BlockReduce(temp_storage).Sum(r_virial.yz);
  __syncthreads();

  r_virial.zx = BlockReduce(temp_storage).Sum(r_virial.zx);
  __syncthreads();
  r_virial.zy = BlockReduce(temp_storage).Sum(r_virial.zy);
  __syncthreads();
  r_virial.zz = BlockReduce(temp_storage).Sum(r_virial.zz);
  __syncthreads();

  if(threadIdx.x == 0){
    // Thread 0 holds the block totals (BlockReduce returns the sum to lane 0);
    // fold them into the device-wide accumulators.
    atomicAdd(&(d_virial->xx), r_virial.xx);
    atomicAdd(&(d_virial->xy), r_virial.xy);
    atomicAdd(&(d_virial->xz), r_virial.xz);

    atomicAdd(&(d_virial->yx), r_virial.yx);
    atomicAdd(&(d_virial->yy), r_virial.yy);
    atomicAdd(&(d_virial->yz), r_virial.yz);

    atomicAdd(&(d_virial->zx), r_virial.zx);
    atomicAdd(&(d_virial->zy), r_virial.zy);
    atomicAdd(&(d_virial->zz), r_virial.zz);

    atomicAdd(&(d_extForce->x), r_netForce.x);
    atomicAdd(&(d_extForce->y), r_netForce.y);
    atomicAdd(&(d_extForce->z), r_netForce.z);

    // Publish the atomic adds before bumping the completion counter so the
    // last block observes every other block's contribution.
    __threadfence();
    // atomicInc wraps at totaltb; the block that sees totaltb-1 is the last
    // one to finish.
    unsigned int value = atomicInc(&tbcatomic[0], totaltb);
    isLastBlockDone = (value == (totaltb -1));
  }
  __syncthreads();

  if(isLastBlockDone){
    if(threadIdx.x == 0){
      // Last block: copy the finished totals out to the host-visible buffers.
      h_extVirial->xx = d_virial->xx;
      h_extVirial->xy = d_virial->xy;
      h_extVirial->xz = d_virial->xz;
      h_extVirial->yx = d_virial->yx;
      h_extVirial->yy = d_virial->yy;
      h_extVirial->yz = d_virial->yz;
      h_extVirial->zx = d_virial->zx;
      h_extVirial->zy = d_virial->zy;
      h_extVirial->zz = d_virial->zz;

      //reset the device virial value
      d_virial->xx = 0;
      d_virial->xy = 0;
      d_virial->xz = 0;
      d_virial->yx = 0;
      d_virial->yy = 0;
      d_virial->yz = 0;
      d_virial->zx = 0;
      d_virial->zy = 0;
      d_virial->zz = 0;

      h_extForce->x = d_extForce->x;
      h_extForce->y = d_extForce->y;
      h_extForce->z = d_extForce->z;
      d_extForce->x =0 ;
      d_extForce->y =0 ;
      d_extForce->z =0 ;
      //resets atomic counter
      tbcatomic[0] = 0;
      __threadfence();
    }
  }
}
164 
/*! \brief Host launcher for computeGlobalMasterVirialKernel.
 *
 *  Reduces the global-master (external) forces into a net force (h_extForce)
 *  and a virial tensor (h_virial) on \p stream; the kernel's last finishing
 *  block publishes the totals to the host-visible buffers and re-zeroes the
 *  device accumulators and d_tbcatomic for the next invocation.
 *
 *  numPatches and localRecords are accepted but not used by the kernel.
 *  d_tbcatomic must be zero on entry (it is re-zeroed by the kernel itself).
 */
void computeGlobalMasterVirial(
  const int numPatches,
  const int numAtoms,
  CudaLocalRecord* localRecords,
  const double* __restrict d_pos_x,
  const double* __restrict d_pos_y,
  const double* __restrict d_pos_z,
  const char3* __restrict d_transform,
  double* __restrict f_global_x,
  double* __restrict f_global_y,
  double* __restrict f_global_z,
  double3* __restrict d_extForce,
  double3* __restrict h_extForce,
  cudaTensor* __restrict d_virial,
  cudaTensor* __restrict h_virial,
  const Lattice lat,
  unsigned int* __restrict d_tbcatomic,
  cudaStream_t stream
)
{
  // Guard the degenerate case: with numAtoms == 0 the grid size below would be
  // 0, and a <<<0, ...>>> launch is a CUDA invalid-configuration error.
  if (numAtoms <= 0) return;

  const int atom_blocks = 128;  // block size; must match the kernel's BLOCK_SIZE template arg
  const int grid = (numAtoms + atom_blocks - 1) / atom_blocks;  // ceil-div over atoms
  computeGlobalMasterVirialKernel<atom_blocks><<<grid, atom_blocks, 0, stream>>>(
    numAtoms,
    localRecords,
    d_pos_x,
    d_pos_y,
    d_pos_z,
    d_transform,
    f_global_x,
    f_global_y,
    f_global_z,
    d_extForce,
    h_extForce,
    d_virial,
    h_virial,
    lat,
    d_tbcatomic
  );
}
206 
207 #endif // NODEGROUP_FORCE_REGISTER
208