/* revised short range */

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include "pmetest/Vector.h"
#include "pmetest/ComputePme.h"
#include "pmetest/PmeDirect.h"
//#define DEBUG_WATCH
#include "debug/debug.h"

/*
#define NB_DEBUG
*/

#undef DEBUG
//#define DEBUG

/*
 * Local ABORT / ASSERT fallbacks.
 *
 * When DEBUG_SUPPORT is defined, only ABORT is defined here.
 * NOTE(review): in that configuration ASSERT is presumably supplied by
 * "debug/debug.h" -- confirm.
 *
 * Otherwise:
 *   DEBUG defined     -> ABORT prints its message to stderr and aborts;
 *                        ASSERT prints the failed expression with file and
 *                        line, then aborts.
 *   DEBUG not defined -> ABORT just calls abort(); ASSERT does nothing.
 *
 * The diagnostic variants are wrapped in do { } while (0) so that each
 * behaves as exactly one statement.  A bare "if (...) ..." expansion would
 * capture the "else" of a caller writing "if (a) ASSERT(b); else ...".
 */
#ifndef DEBUG_SUPPORT

#ifdef DEBUG
#include <stdio.h>
#define ABORT(msg) \
  do { \
    fprintf(stderr, "%s\n", (msg)); \
    abort(); \
  } while (0)
#define ASSERT(expr) \
  do { \
    if (!(expr)) { \
      fprintf(stderr, "assert failed: %s, %s, %d\n", \
              #expr, __FILE__, __LINE__); \
      abort(); \
    } \
  } while (0)
#else
#define ABORT(msg) abort()
#define ASSERT(expr) ((void) 0)
#endif

#else
#define ABORT(msg) abort()

#endif

// Minimum initial capacity of each cell's atom ID array: pmeDirect sizes
// the array at twice the average atoms per cell, but never below this.
#undef MINATOMS
#define MINATOMS 10

// Not referenced by any compiled code in this file.
// NOTE(review): presumably a distance limit for an exclusion test -- confirm.
#define REXCL 5.0

// Not referenced by any compiled code in this file (spelling is as found).
// NOTE(review): presumably the Coulomb constant in kcal*Angstrom/(mol*e^2)
// -- confirm units before using it.
#define COLOUMB 332.0636


/*
 * AtomCell -- one bucket of the cell grid that pmeDirect uses to find
 * atom pairs within the cutoff.
 *
 * Each cell stores the atoms hashed into it plus a "half shell" of
 * neighboring cells: of the 26 surrounding cells only 13 are listed, so
 * that every pair of adjacent cells is visited from exactly one side.
 */
typedef struct AtomCell_Tag {
  int *atomid;       /* growable array of atom IDs (indices into the */
                     /* pos/force/charge arrays) */
  int maxatomids;    /* number of entries allocated for atomid */
  int natomids;      /* number of atom IDs currently stored in this cell */
  int nbr[13];       /* indices (into the cell array) of the neighbors */
  int nnbrs;         /* number of valid entries in nbr and nbroff */
  Vector nbroff[13]; /* translation belonging to the matching nbr entry: */
                     /* zero, except +/- one system length in each */
                     /* dimension where the neighbor is a periodic image */
} AtomCell;


/*
 * Wrap coordinate p into the periodic interval [lo, lo + len).
 *
 * The first step is a single shift by +/- len, which is all that is needed
 * for a coordinate at most one period outside the interval.  If the result
 * is still outside (coordinate several periods away, or a shift that
 * rounded onto the upper bound) a floor-based shift is applied.
 * p must be finite and len must be positive.
 */
static double pme_wrap_coord(double p, double lo, double len)
{
  const double hi = lo + len;

  if (p < lo)  p += len;
  else if (p >= hi)  p -= len;
  if (p < lo || p >= hi) {
    p -= len * floor((p - lo) / len);
  }
  return p;
}


/*
 * Return the cell index in 0..ncell-1 along one dimension for coordinate p,
 * where lo is the lower bound of the domain and inv_width is the inverse of
 * the cell width.  The result is clamped, so a coordinate that lands on a
 * domain face through roundoff is placed in the adjacent boundary cell
 * rather than producing an out-of-range index.
 */
static int pme_cell_index(double p, double lo, double inv_width, int ncell)
{
  int c = (int) floor((p - lo) * inv_width);

  if (c < 0)  c = 0;
  else if (c >= ncell)  c = ncell - 1;
  return c;
}


/*
 * Evaluate the screened pair terms for squared distance r2:
 *
 *   *fr  = erfc(ewaldcof * r) / r
 *   *dfr = (pi_ewaldcof * exp(-(ewaldcof * r)^2) + *fr) / r^2
 *
 * where r = sqrt(r2) and pi_ewaldcof is 2 / sqrt(pi) * ewaldcof.
 * No correction for excluded pairs is applied.  r2 == 0 is not guarded
 * against and yields non-finite results.
 */
static void pme_pair_terms(double r2, double ewaldcof, double pi_ewaldcof,
                           double *fr, double *dfr)
{
  const double r = sqrt(r2);
  const double r_1 = 1.0 / r;
  const double r_2 = r_1 * r_1;
  const double tmp_a = r * ewaldcof;
  const double tmp_b = erfc(tmp_a);

  *fr = tmp_b * r_1;
  *dfr = (pi_ewaldcof * exp(-(tmp_a*tmp_a)) + *fr) * r_2;
}


/*
 * pmeDirect -- short-range (erfc-screened) pairwise sum over all atom pairs
 * closer than prm->cutoff, including pairs formed with periodic images.
 *
 * Parameters:
 *   natoms  number of atoms; 0 returns 0.0 immediately, negative aborts
 *   f       per-atom vectors, length natoms.  Pair contributions are added
 *           to the values found on entry and each f[i] is then multiplied
 *           by q[i].  NOTE(review): callers presumably pass f zeroed --
 *           confirm.
 *   pos     per-atom positions, length natoms.  MODIFIED: every position
 *           is replaced by its periodic image inside the system cell.
 *   q       per-atom charges, length natoms
 *   prm     reads cutoff, ewaldcof, center and cellvec1..3.  The cell
 *           vectors must lie along the x, y and z axes and each length
 *           must be at least the cutoff.
 *
 * Returns 0.5 * sum_i q[i] * pe[i], where pe[i] is the sum over partners j
 * of q[j] * erfc(ewaldcof * r_ij) / r_ij.  No Coulomb constant or dielectric
 * factor is applied to the energy or to f, and no excluded-pair correction
 * is made.
 *
 * Invalid parameters, non-finite positions and allocation failures end the
 * program through ABORT.
 */
double pmeDirect(int natoms, Vector f[], Vector pos[],
                 const double q[], const PmetestParams *prm)
{
  double energy = 0;
  const double cutoff = prm->cutoff;
  const double cutoff2 = cutoff * cutoff;
  const double ewaldcof = prm->ewaldcof;
  const double pi_ewaldcof = 2.0 / sqrt(M_PI) * ewaldcof;
  double cutoff_1;
  double r2, fr, dfr, fci, fcj;
  double pej;
  Vector r_ij, posj, fj, trans;
  double *pe;
  int n, nn, i, j, k, ii, jj, kk, ni, nj, nk;

/* atom cells for geometric hashing for short-ranged atomic interactions */
  AtomCell *cell;    /* array of cells */
  int ncells;        /* total number of cells (length of cell array) */
  int nxcells, nycells, nzcells;  /* number of cells in each dimension */
  int nxpect;        /* expected number of atoms in each cell */
  Vector celldis;    /* distance across each dimension of a cell */
  Vector celldis_1;  /* inverse distances for hashing */

  Vector syslo;      /* lowest corner of system domain */
  Vector sysdis;     /* distance across each dimension of system domain */

  if (prm->cellvec1.y != 0 || prm->cellvec1.z != 0
      || prm->cellvec2.x != 0 || prm->cellvec2.z != 0
      || prm->cellvec3.x != 0 || prm->cellvec3.y != 0) {
    ABORT("system must be aligned with xyz coordinate axes");
  }
  sysdis.x = prm->cellvec1.x;
  sysdis.y = prm->cellvec2.y;
  sysdis.z = prm->cellvec3.z;
  syslo.x = prm->center.x - 0.5 * sysdis.x;
  syslo.y = prm->center.y - 0.5 * sysdis.y;
  syslo.z = prm->center.z - 0.5 * sysdis.z;

  /* written as !(a > b) so that a NaN cutoff is rejected as well */
  if (!(cutoff > 0.0)) {
    ABORT("cutoff must be positive");
  }
  /* written as !(a >= b) so that NaN domain lengths are rejected as well */
  if (!(sysdis.x >= cutoff) || !(sysdis.y >= cutoff)
      || !(sysdis.z >= cutoff)) {
    ABORT("cutoff must be no larger than domain length");
  }
  if (natoms < 0) {
    ABORT("number of atoms must not be negative");
  }
  if (natoms == 0) {
    /* nothing to sum; also avoids calloc(0, ...), which may return NULL */
    return 0.0;
  }

  /*
   * Use as many cells per dimension as fit while keeping every cell at
   * least one cutoff wide, so that all partners of an atom lie in its own
   * cell or in one of the directly adjacent cells.
   */
  cutoff_1 = 1.0 / cutoff;
  nxcells = (int) (sysdis.x * cutoff_1);
  nycells = (int) (sysdis.y * cutoff_1);
  nzcells = (int) (sysdis.z * cutoff_1);
  ncells = nxcells * nycells * nzcells;
  if (ncells <= 0) {
    ABORT("invalid number of cells");
  }
  nxpect = 2 * (natoms / ncells);
  if (nxpect < MINATOMS)  nxpect = MINATOMS;
  celldis.x = sysdis.x / nxcells;
  celldis.y = sysdis.y / nycells;
  celldis.z = sysdis.z / nzcells;
  celldis_1.x = 1.0 / celldis.x;
  celldis_1.y = 1.0 / celldis.y;
  celldis_1.z = 1.0 / celldis.z;

  /* allocate memory for partial potentials */
  if ((pe = calloc(natoms, sizeof *pe)) == NULL) {
    ABORT("cannot calloc pe");
  }

  /* initialize atom cells for geometric hashing */
  if ((cell = calloc(ncells, sizeof *cell)) == NULL) {
    ABORT("cannot calloc cell");
  }
  for (k = 0;  k < ncells;  k++) {
    if ((cell[k].atomid = malloc(nxpect * sizeof *cell[k].atomid)) == NULL) {
      ABORT("cannot malloc atomid");
    }
    cell[k].maxatomids = nxpect;
  }

#ifdef NB_DEBUG
  printf("map: nxcells=%d nycells=%d nzcells=%d ncells=%d\n",
         nxcells, nycells, nzcells, ncells);
  printf("map: center= %g %g %g\n", prm->center.x,
         prm->center.y, prm->center.z);
  printf("map: sysdis= %g %g %g\n", sysdis.x, sysdis.y, sysdis.z);
  printf("map: syslo= %g %g %g\n", syslo.x, syslo.y, syslo.z);
  printf("map: celldis= %g %g %g\n", celldis.x, celldis.y, celldis.z);
#endif

  /* set up cell neighbor lists */
  for (k = 0;  k < nzcells;  k++) {
    for (j = 0;  j < nycells;  j++) {
      for (i = 0;  i < nxcells;  i++) {
        n = (k * nycells + j) * nxcells + i;
        ASSERT(n >= 0 && n < ncells);
#ifdef NB_DEBUG
        printf("map: cell (%d,%d,%d) => %d\n", i, j, k, n);
#endif
        /*
         * Loop through the 27-cell stencil but keep only the 13 offsets
         * that come after (0,0,0) in (kk,jj,ii) order, so each pair of
         * adjacent cells is listed from one side only.
         */
        for (kk = -1;  kk <= 1;  kk++) {
          for (jj = -1;  jj <= 1;  jj++) {
            for (ii = -1;  ii <= 1;  ii++) {
              if ((kk < 0) || (kk == 0 && jj < 0)
                  || (kk == 0 && jj == 0 && ii <= 0))  continue;
              ASSERT(cell[n].nnbrs < 13);
              ni = i + ii;
              nj = j + jj;
              nk = k + kk;
              /* a neighbor beyond a domain face is a periodic image; */
              /* record the translation that carries it there */
              if (ni < 0)  trans.x = -sysdis.x;
              else if (ni >= nxcells)  trans.x = sysdis.x;
              else  trans.x = 0.0;
              if (nj < 0)  trans.y = -sysdis.y;
              else if (nj >= nycells)  trans.y = sysdis.y;
              else  trans.y = 0.0;
              if (nk < 0)  trans.z = -sysdis.z;
              else if (nk >= nzcells)  trans.z = sysdis.z;
              else  trans.z = 0.0;
              ni = (ni + nxcells) % nxcells;
              nj = (nj + nycells) % nycells;
              nk = (nk + nzcells) % nzcells;
              nn = (nk * nycells + nj) * nxcells + ni;
              ASSERT(nn >= 0 && nn < ncells);
              cell[n].nbr[ cell[n].nnbrs ] = nn;
              cell[n].nbroff[ cell[n].nnbrs ] = trans;
              cell[n].nnbrs++;
            }
          }
        }
        /* sanity check -- every cell should have 13 neighbors */
        ASSERT(cell[n].nnbrs == 13);
      }
    }
  }

  /* geometric hashing of all atoms */
  for (n = 0;  n < natoms;  n++) {
    if (!isfinite(pos[n].x) || !isfinite(pos[n].y) || !isfinite(pos[n].z)) {
      ABORT("atom position is not finite");
    }

    /* replace the position by its image inside the system cell */
    pos[n].x = pme_wrap_coord(pos[n].x, syslo.x, sysdis.x);
    pos[n].y = pme_wrap_coord(pos[n].y, syslo.y, sysdis.y);
    pos[n].z = pme_wrap_coord(pos[n].z, syslo.z, sysdis.z);

    i = pme_cell_index(pos[n].x, syslo.x, celldis_1.x, nxcells);
    j = pme_cell_index(pos[n].y, syslo.y, celldis_1.y, nycells);
    k = pme_cell_index(pos[n].z, syslo.z, celldis_1.z, nzcells);
    nn = (k * nycells + j) * nxcells + i;
    ASSERT(nn >= 0 && nn < ncells);
    if (cell[nn].natomids == cell[nn].maxatomids) {
      /* grow through a temporary so the old array is not lost on failure */
      int newmax = 2 * cell[nn].maxatomids;
      int *grown = realloc(cell[nn].atomid, newmax * sizeof *grown);
      if (grown == NULL)  ABORT("call to realloc atomid");
      cell[nn].atomid = grown;
      cell[nn].maxatomids = newmax;
    }
    cell[nn].atomid[ cell[nn].natomids++ ] = n;
  }

  /* compute direct part */
  for (n = 0;  n < ncells;  n++) {

    /* pairs between this cell and each of its listed neighbors */
    for (nn = 0;  nn < cell[n].nnbrs;  nn++) {
      nk = cell[n].nbr[nn];
      trans = cell[n].nbroff[nn];
      for (nj = 0;  nj < cell[n].natomids;  nj++) {
        j = cell[n].atomid[nj];
        /* shift atom j instead of shifting every neighbor atom */
        posj.x = pos[j].x - trans.x;
        posj.y = pos[j].y - trans.y;
        posj.z = pos[j].z - trans.z;
        /* accumulate for j in locals, written back after the inner loop */
        pej = pe[j];
        fj = f[j];
        for (ni = 0;  ni < cell[nk].natomids;  ni++) {
          i = cell[nk].atomid[ni];
          /*
           * A neighbor can be a periodic image of this same cell.  An atom
           * paired with its own image is at least one domain length (hence
           * at least one cutoff) away, so it never contributes; skipping it
           * also keeps roundoff from admitting it and clashing with the
           * cached pej / fj.
           */
          if (i == j)  continue;
          r_ij.x = posj.x - pos[i].x;
          r_ij.y = posj.y - pos[i].y;
          r_ij.z = posj.z - pos[i].z;
          r2 = r_ij.x * r_ij.x + r_ij.y * r_ij.y + r_ij.z * r_ij.z;
          if (r2 >= cutoff2)  continue;
#ifdef NB_DEBUG
if (i < j) printf("atoms: %d %d\n", i, j);
else       printf("atoms: %d %d\n", j, i);
#endif
          pme_pair_terms(r2, ewaldcof, pi_ewaldcof, &fr, &dfr);
          pe[i] += fr * q[j];
          pej += fr * q[i];
          fci = q[j] * dfr;
          fcj = q[i] * dfr;
          f[i].x -= fci * r_ij.x;
          f[i].y -= fci * r_ij.y;
          f[i].z -= fci * r_ij.z;
          fj.x += fcj * r_ij.x;
          fj.y += fcj * r_ij.y;
          fj.z += fcj * r_ij.z;
        }
        pe[j] = pej;
        f[j] = fj;
      }
    }

    /* pairs inside this cell, each unordered pair visited once */
    for (nj = 1;  nj < cell[n].natomids;  nj++) {
      j = cell[n].atomid[nj];
      posj = pos[j];
      pej = pe[j];
      fj = f[j];
      for (ni = 0;  ni < nj;  ni++) {
        i = cell[n].atomid[ni];
        r_ij.x = posj.x - pos[i].x;
        r_ij.y = posj.y - pos[i].y;
        r_ij.z = posj.z - pos[i].z;
        r2 = r_ij.x * r_ij.x + r_ij.y * r_ij.y + r_ij.z * r_ij.z;
        if (r2 >= cutoff2)  continue;
#ifdef NB_DEBUG
printf("atoms: %d %d\n", i, j);
#endif
        pme_pair_terms(r2, ewaldcof, pi_ewaldcof, &fr, &dfr);
        pe[i] += fr * q[j];
        pej += fr * q[i];
        fci = q[j] * dfr;
        fcj = q[i] * dfr;
        f[i].x -= fci * r_ij.x;
        f[i].y -= fci * r_ij.y;
        f[i].z -= fci * r_ij.z;
        fj.x += fcj * r_ij.x;
        fj.y += fcj * r_ij.y;
        fj.z += fcj * r_ij.z;
      }
      pe[j] = pej;
      f[j] = fj;
    }
  }

  /* accumulate energy and apply each atom's own charge to its vector */
  for (i = 0;  i < natoms;  i++) {
    energy += q[i] * pe[i];
    f[i].x *= q[i];
    f[i].y *= q[i];
    f[i].z *= q[i];
  }
  /* every pair was added to both partners' pe, so halve the total */
  energy *= 0.5;

  /* free all allocated memory */
  free(pe);
  for (i = 0;  i < ncells;  i++) {
    free(cell[i].atomid);
  }
  free(cell);

  /* return potential energy */
  return energy;
}
