NAMD
Classes | Public Member Functions | List of all members
CudaPmeKSpaceCompute Class Reference

#include <CudaPmeSolverUtil.h>

Inheritance diagram for CudaPmeKSpaceCompute:
PmeKSpaceCompute

Public Member Functions

 CudaPmeKSpaceCompute (PmeGrid pmeGrid, const int permutation, const int jblock, const int kblock, double kappa, int deviceID, cudaStream_t stream)
 
 ~CudaPmeKSpaceCompute ()
 
void solve (Lattice &lattice, const bool doEnergy, const bool doVirial, float *data)
 
double getEnergy ()
 
void getVirial (double *virial)
 
void energyAndVirialSetCallback (CudaPmePencilXYZ *pencilPtr)
 
void energyAndVirialSetCallback (CudaPmePencilZ *pencilPtr)
 
- Public Member Functions inherited from PmeKSpaceCompute
 PmeKSpaceCompute (PmeGrid pmeGrid, const int permutation, const int jblock, const int kblock, double kappa)
 
virtual ~PmeKSpaceCompute ()
 

Additional Inherited Members

- Protected Attributes inherited from PmeKSpaceCompute
PmeGrid pmeGrid
 
double * bm1
 
double * bm2
 
double * bm3
 
double kappa
 
const int permutation
 
const int jblock
 
const int kblock
 
int size1
 
int size2
 
int size3
 
int j0
 
int k0
 

Detailed Description

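CUDA implementation of PmeKSpaceCompute. Each instance is bound to one CUDA device and stream. solve() launches scalar_sum() on the (permuted) complex grid data. When energy or virial is requested, the results are copied back to the host asynchronously and their completion is marked with an event. energyAndVirialSetCallback() registers either a CudaPmePencilXYZ or a CudaPmePencilZ as the recipient and arms a timed callback (energyAndVirialCheck, first run after 0.1 ms). getEnergy() and getVirial() then read the host-side results.

Definition at line 86 of file CudaPmeSolverUtil.h.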

Constructor & Destructor Documentation

CudaPmeKSpaceCompute::CudaPmeKSpaceCompute ( PmeGrid  pmeGrid,
const int  permutation,
const int  jblock,
const int  kblock,
double  kappa,
int  deviceID,
cudaStream_t  stream 
)

Definition at line 332 of file CudaPmeSolverUtil.C.

References PmeKSpaceCompute::bm1, PmeKSpaceCompute::bm2, PmeKSpaceCompute::bm3, cudaCheck, PmeGrid::K1, PmeGrid::K2, and PmeGrid::K3.

333  :
335  deviceID(deviceID), stream(stream) {
336 
337  cudaCheck(cudaSetDevice(deviceID));
338 
339  // Copy bm1 -> prefac_x on GPU memory
340  float *bm1f = new float[pmeGrid.K1];
341  float *bm2f = new float[pmeGrid.K2];
342  float *bm3f = new float[pmeGrid.K3];
343  for (int i=0;i < pmeGrid.K1;i++) bm1f[i] = (float)bm1[i];
344  for (int i=0;i < pmeGrid.K2;i++) bm2f[i] = (float)bm2[i];
345  for (int i=0;i < pmeGrid.K3;i++) bm3f[i] = (float)bm3[i];
346  allocate_device<float>(&d_bm1, pmeGrid.K1);
347  allocate_device<float>(&d_bm2, pmeGrid.K2);
348  allocate_device<float>(&d_bm3, pmeGrid.K3);
349  copy_HtoD_sync<float>(bm1f, d_bm1, pmeGrid.K1);
350  copy_HtoD_sync<float>(bm2f, d_bm2, pmeGrid.K2);
351  copy_HtoD_sync<float>(bm3f, d_bm3, pmeGrid.K3);
352  delete [] bm1f;
353  delete [] bm2f;
354  delete [] bm3f;
355  allocate_device<EnergyVirial>(&d_energyVirial, 1);
356  allocate_host<EnergyVirial>(&h_energyVirial, 1);
357  // cudaCheck(cudaEventCreateWithFlags(&copyEnergyVirialEvent, cudaEventDisableTiming));
358  cudaCheck(cudaEventCreate(&copyEnergyVirialEvent));
359  // ncall = 0;
360 }
int K2
Definition: PmeBase.h:18
int K1
Definition: PmeBase.h:18
__thread cudaStream_t stream
PmeKSpaceCompute(PmeGrid pmeGrid, const int permutation, const int jblock, const int kblock, double kappa)
int K3
Definition: PmeBase.h:18
const int permutation
#define cudaCheck(stmt)
Definition: CudaUtils.h:95
CudaPmeKSpaceCompute::~CudaPmeKSpaceCompute ( )

Definition at line 362 of file CudaPmeSolverUtil.C.

References cudaCheck.

362  {
363  cudaCheck(cudaSetDevice(deviceID));
364  deallocate_device<float>(&d_bm1);
365  deallocate_device<float>(&d_bm2);
366  deallocate_device<float>(&d_bm3);
367  deallocate_device<EnergyVirial>(&d_energyVirial);
368  deallocate_host<EnergyVirial>(&h_energyVirial);
369  cudaCheck(cudaEventDestroy(copyEnergyVirialEvent));
370 }
#define cudaCheck(stmt)
Definition: CudaUtils.h:95

Member Function Documentation

void CudaPmeKSpaceCompute::energyAndVirialSetCallback ( CudaPmePencilXYZ pencilPtr)

Definition at line 522 of file CudaPmeSolverUtil.C.

References CcdCallBacksReset(), and cudaCheck.

522  {
523  cudaCheck(cudaSetDevice(deviceID));
524  pencilXYZPtr = pencilPtr;
525  pencilZPtr = NULL;
526  checkCount = 0;
527  CcdCallBacksReset(0, CmiWallTimer());
528  // Set the call back at 0.1ms
529  CcdCallFnAfter(energyAndVirialCheck, this, 0.1);
530 }
void CcdCallBacksReset(void *ignored, double curWallTime)
#define cudaCheck(stmt)
Definition: CudaUtils.h:95
void CudaPmeKSpaceCompute::energyAndVirialSetCallback ( CudaPmePencilZ pencilPtr)

Definition at line 532 of file CudaPmeSolverUtil.C.

References CcdCallBacksReset(), and cudaCheck.

532  {
533  cudaCheck(cudaSetDevice(deviceID));
534  pencilXYZPtr = NULL;
535  pencilZPtr = pencilPtr;
536  checkCount = 0;
537  CcdCallBacksReset(0, CmiWallTimer());
538  // Set the call back at 0.1ms
539  CcdCallFnAfter(energyAndVirialCheck, this, 0.1);
540 }
void CcdCallBacksReset(void *ignored, double curWallTime)
#define cudaCheck(stmt)
Definition: CudaUtils.h:95
double CudaPmeKSpaceCompute::getEnergy ( )
virtual

Implements PmeKSpaceCompute.

Definition at line 542 of file CudaPmeSolverUtil.C.

542  {
543  return h_energyVirial->energy;
544 }
void CudaPmeKSpaceCompute::getVirial ( double *  virial)
virtual

Implements PmeKSpaceCompute.

Definition at line 546 of file CudaPmeSolverUtil.C.

References Perm_cX_Y_Z, Perm_Z_cX_Y, and PmeKSpaceCompute::permutation.

546  {
547  if (permutation == Perm_Z_cX_Y) {
548  // h_energyVirial->virial is storing ZZ, ZX, ZY, XX, XY, YY
549  virial[0] = h_energyVirial->virial[3];
550  virial[1] = h_energyVirial->virial[4];
551  virial[2] = h_energyVirial->virial[1];
552 
553  virial[3] = h_energyVirial->virial[4];
554  virial[4] = h_energyVirial->virial[5];
555  virial[5] = h_energyVirial->virial[2];
556 
557  virial[6] = h_energyVirial->virial[1];
558  virial[7] = h_energyVirial->virial[7];
559  virial[8] = h_energyVirial->virial[0];
560  } else if (permutation == Perm_cX_Y_Z) {
561  // h_energyVirial->virial is storing XX, XY, XZ, YY, YZ, ZZ
562  virial[0] = h_energyVirial->virial[0];
563  virial[1] = h_energyVirial->virial[1];
564  virial[2] = h_energyVirial->virial[2];
565 
566  virial[3] = h_energyVirial->virial[1];
567  virial[4] = h_energyVirial->virial[3];
568  virial[5] = h_energyVirial->virial[4];
569 
570  virial[6] = h_energyVirial->virial[2];
571  virial[7] = h_energyVirial->virial[4];
572  virial[8] = h_energyVirial->virial[5];
573  }
574 }
const int permutation
void CudaPmeKSpaceCompute::solve ( Lattice lattice,
const bool  doEnergy,
const bool  doVirial,
float *  data 
)
virtual

Implements PmeKSpaceCompute.

Definition at line 372 of file CudaPmeSolverUtil.C.

References Lattice::a(), Lattice::a_r(), Lattice::b(), Lattice::b_r(), Lattice::c(), Lattice::c_r(), cudaCheck, PmeKSpaceCompute::j0, PmeKSpaceCompute::k0, PmeGrid::K1, PmeGrid::K2, PmeGrid::K3, PmeKSpaceCompute::kappa, NAMD_bug(), Perm_cX_Y_Z, Perm_Z_cX_Y, PmeKSpaceCompute::permutation, PmeKSpaceCompute::pmeGrid, scalar_sum(), PmeKSpaceCompute::size1, PmeKSpaceCompute::size2, PmeKSpaceCompute::size3, Lattice::volume(), Vector::x, Vector::y, and Vector::z.

372  {
373 #if 0
374  // Check lattice to make sure it is updating for constant pressure
375  fprintf(stderr, "K-SPACE LATTICE %g %g %g %g %g %g %g %g %g\n",
376  lattice.a().x, lattice.a().y, lattice.a().z,
377  lattice.b().x, lattice.b().y, lattice.b().z,
378  lattice.c().x, lattice.c().y, lattice.c().z);
379 #endif
380  cudaCheck(cudaSetDevice(deviceID));
381 
382  const bool doEnergyVirial = (doEnergy || doVirial);
383 
384  int nfft1, nfft2, nfft3;
385  float *prefac1, *prefac2, *prefac3;
386 
387  BigReal volume = lattice.volume();
388  Vector a_r = lattice.a_r();
389  Vector b_r = lattice.b_r();
390  Vector c_r = lattice.c_r();
391  float recip1x, recip1y, recip1z;
392  float recip2x, recip2y, recip2z;
393  float recip3x, recip3y, recip3z;
394 
395  if (permutation == Perm_Z_cX_Y) {
396  // Z, X, Y
397  nfft1 = pmeGrid.K3;
398  nfft2 = pmeGrid.K1;
399  nfft3 = pmeGrid.K2;
400  prefac1 = d_bm3;
401  prefac2 = d_bm1;
402  prefac3 = d_bm2;
403  recip1x = c_r.z;
404  recip1y = c_r.x;
405  recip1z = c_r.y;
406  recip2x = a_r.z;
407  recip2y = a_r.x;
408  recip2z = a_r.y;
409  recip3x = b_r.z;
410  recip3y = b_r.x;
411  recip3z = b_r.y;
412  } else if (permutation == Perm_cX_Y_Z) {
413  // X, Y, Z
414  nfft1 = pmeGrid.K1;
415  nfft2 = pmeGrid.K2;
416  nfft3 = pmeGrid.K3;
417  prefac1 = d_bm1;
418  prefac2 = d_bm2;
419  prefac3 = d_bm3;
420  recip1x = a_r.x;
421  recip1y = a_r.y;
422  recip1z = a_r.z;
423  recip2x = b_r.x;
424  recip2y = b_r.y;
425  recip2z = b_r.z;
426  recip3x = c_r.x;
427  recip3y = c_r.y;
428  recip3z = c_r.z;
429  } else {
430  NAMD_bug("CudaPmeKSpaceCompute::solve, invalid permutation");
431  }
432 
433  // ncall++;
434  // if (ncall == 1) {
435  // char filename[256];
436  // sprintf(filename,"dataf_%d_%d.txt",jblock,kblock);
437  // writeComplexToDisk((float2*)data, size1*size2*size3, filename, stream);
438  // }
439 
440  // if (ncall == 1) {
441  // float2* h_data = new float2[size1*size2*size3];
442  // float2* d_data = (float2*)data;
443  // copy_DtoH<float2>(d_data, h_data, size1*size2*size3, stream);
444  // cudaCheck(cudaStreamSynchronize(stream));
445  // FILE *handle = fopen("dataf.txt", "w");
446  // for (int z=0;z < pmeGrid.K3;z++) {
447  // for (int y=0;y < pmeGrid.K2;y++) {
448  // for (int x=0;x < pmeGrid.K1/2+1;x++) {
449  // int i;
450  // if (permutation == Perm_cX_Y_Z) {
451  // i = x + y*size1 + z*size1*size2;
452  // } else {
453  // i = z + x*size1 + y*size1*size2;
454  // }
455  // fprintf(handle, "%f %f\n", h_data[i].x, h_data[i].y);
456  // }
457  // }
458  // }
459  // fclose(handle);
460  // delete [] h_data;
461  // }
462 
463  // Clear energy and virial array if needed
464  if (doEnergyVirial) clear_device_array<EnergyVirial>(d_energyVirial, 1, stream);
465 
466  scalar_sum(permutation == Perm_cX_Y_Z, nfft1, nfft2, nfft3, size1, size2, size3, kappa,
467  recip1x, recip1y, recip1z, recip2x, recip2y, recip2z, recip3x, recip3y, recip3z,
468  volume, prefac1, prefac2, prefac3, j0, k0, doEnergyVirial,
469  &d_energyVirial->energy, d_energyVirial->virial, (float2*)data,
470  stream);
471 
472  // Copy energy and virial to host if needed
473  if (doEnergyVirial) {
474  copy_DtoH<EnergyVirial>(d_energyVirial, h_energyVirial, 1, stream);
475  cudaCheck(cudaEventRecord(copyEnergyVirialEvent, stream));
476  // cudaCheck(cudaStreamSynchronize(stream));
477  }
478 
479 }
Vector a_r() const
Definition: Lattice.h:268
void scalar_sum(const bool orderXYZ, const int nfft1, const int nfft2, const int nfft3, const int size1, const int size2, const int size3, const double kappa, const float recip1x, const float recip1y, const float recip1z, const float recip2x, const float recip2y, const float recip2z, const float recip3x, const float recip3y, const float recip3z, const double volume, const float *prefac1, const float *prefac2, const float *prefac3, const int k2_00, const int k3_00, const bool doEnergyVirial, double *energy, double *virial, float2 *data, cudaStream_t stream)
Definition: Vector.h:64
int K2
Definition: PmeBase.h:18
int K1
Definition: PmeBase.h:18
Vector c_r() const
Definition: Lattice.h:270
BigReal z
Definition: Vector.h:66
Vector b_r() const
Definition: Lattice.h:269
__thread cudaStream_t stream
void NAMD_bug(const char *err_msg)
Definition: common.C:129
BigReal x
Definition: Vector.h:66
BigReal volume(void) const
Definition: Lattice.h:277
int K3
Definition: PmeBase.h:18
const int permutation
BigReal y
Definition: Vector.h:66
Vector b() const
Definition: Lattice.h:253
#define cudaCheck(stmt)
Definition: CudaUtils.h:95
Vector a() const
Definition: Lattice.h:252
Vector c() const
Definition: Lattice.h:254
double BigReal
Definition: common.h:114

The documentation for this class was generated from the following files: