/*
 * Copyright (C) 2004-2005 by David J. Hardy.  All rights reserved.
 *
 * short.c
 */

#include <string.h>
#include <math.h>
#include "mgrid/split.h"
#include "debug/debug.h"

/*
 * Forward declarations of the file-local routines.
 *
 * short_force and short_done are installed as the Mgrid "virtual methods"
 * by mgrid_short_setup(); the remaining routines are helpers that they
 * (or the setup routine) call.
 */
static void short_done(Mgrid *mg);
static int short_force(Mgrid *mg, MgridSystem *sys);
static int geometric_hashing(Mgrid *mg, const MgridSystem *sys);
static int setup_nbrcells(Mgrid *mg);
static int exclusions(Mgrid *mg, MgridSystem *sys);

/* splitting uses only even powers of (r/a) */
/* (selected by mgrid_short_setup() when mg->is_split_even_powers is set) */
static int cell_interactions_even(Mgrid *mg, MgridSystem *sys);

/* splitting uses some odd power of (r/a) */
/* (selected by mgrid_short_setup() otherwise) */
static int cell_interactions_odd(Mgrid *mg, MgridSystem *sys);


/*
 * Prepare the short-range part of an Mgrid object.
 *
 * Installs the short_force / short_done methods, derives the cell grid
 * (cell count per dimension, inverse cell size, lower corner) from the
 * domain parameters, allocates the cell array and the per-atom "next"
 * cursor array, builds the neighbor-cell lists, and picks the pair
 * evaluation routine that matches the splitting.
 *
 * Returns 0 on success, MGRID_FAIL if an allocation or the neighbor-cell
 * setup fails.  The arrays allocated here are released by short_done().
 */
int mgrid_short_setup(Mgrid *mg)
{
  const int32 natoms = mg->param.natoms;
  const int32 boundary = mg->param.boundary;
  const double cutoff = mg->param.cutoff;
  const double len = mg->param.length;
  const MD_Dvec center = mg->param.center;
  double halfspan;  /* half the edge length covered by the cell grid */

  /* hook the short-range entry points into the object */
  mg->short_done = short_done;
  mg->short_force = short_force;

  ASSERT(cutoff > 0.0);
  ASSERT(cutoff <= len || boundary == MGRID_NONPERIODIC);
  ASSERT(boundary == MGRID_PERIODIC || boundary == MGRID_NONPERIODIC);
  ASSERT(natoms > 0);

  mg->inv_cutoff = 1.0 / cutoff;

  if (boundary != MGRID_PERIODIC) {
    /*
     * nonperiodic: cells are exactly one cutoff wide; use enough of them
     * to cover the domain and center the resulting grid on the domain
     */
    mg->inv_cellsize = mg->inv_cutoff;
    mg->ndimcells = (int32) ceil(len * mg->inv_cellsize);
    ASSERT(mg->ndimcells > 0);
    halfspan = 0.5 * (mg->ndimcells * cutoff);
  }
  else {
    /*
     * periodic: take as many whole cutoffs as fit into the domain, then
     * stretch the cells so that they tile the domain exactly
     */
    mg->ndimcells = (int32) (len * mg->inv_cutoff);
    if (mg->ndimcells == 0) {
      mg->ndimcells = 1;  /* guards against rounding error */
    }
    mg->inv_cellsize = mg->ndimcells / len;
    halfspan = 0.5 * len;
  }

  /* lower corner of the cell grid */
  mg->lo.x = center.x - halfspan;
  mg->lo.y = center.y - halfspan;
  mg->lo.z = center.z - halfspan;

  /* storage for the cubic array of cells */
  mg->ncells = mg->ndimcells * mg->ndimcells * mg->ndimcells;
  ASSERT(mg->ncells > 0);
  mg->cell = (MgridCell *) malloc(mg->ncells * sizeof(MgridCell));
  if (mg->cell == NULL) {
    return MGRID_FAIL;
  }

  /* one "next atom" index per atom for the cursor linked lists */
  mg->next = (int32 *) malloc(natoms * sizeof(int32));
  if (mg->next == NULL) {
    return MGRID_FAIL;
  }

  /* fill in each cell's list of neighbor cells */
  if (setup_nbrcells(mg) != 0) {
    return MGRID_FAIL;
  }

  /* the pair kernel depends on which kind of smoothing is in use */
  mg->cell_interactions = (mg->is_split_even_powers
      ? cell_interactions_even : cell_interactions_odd);

  return 0;
}


/*
 * Release the arrays allocated by mgrid_short_setup(): the grid cells
 * (mg->cell) and the per-atom cursor list (mg->next).
 *
 * The pointers are reset to NULL afterwards so that calling this routine
 * a second time, or after a setup that failed part way, frees nothing
 * twice (free(NULL) is a no-op).
 */
void short_done(Mgrid *mg)
{
  free(mg->cell);
  mg->cell = NULL;
  free(mg->next);
  mg->next = NULL;
}


/*
 * Evaluate the short-range contribution for the current positions.
 *
 * Steps: bin the atoms into grid cells, clear the separate short-range
 * force array (if the caller supplied one), accumulate the pair
 * interactions between neighboring cells, correct for excluded pairs (if
 * an exclusion list is present), and finally fold the short-range results
 * into the electrostatic totals.
 *
 * Returns 0 on success, MGRID_FAIL as soon as any step reports failure.
 */
int short_force(Mgrid *mg, MgridSystem *sys)
{
  const int32 natoms = mg->param.natoms;
  MD_Dvec *f_short = sys->f_short;
  MD_Dvec *f_elec = sys->f_elec;
  int32 n;

  /* sort atoms into the grid cells */
  if (geometric_hashing(mg, sys) != 0) {
    return MGRID_FAIL;
  }

  /* a separate short-range force array starts from zero */
  if (f_short != NULL) {
    memset(f_short, 0, natoms * sizeof(MD_Dvec));
  }

  /* pair interactions between atoms of neighboring cells */
  if (mg->cell_interactions(mg, sys) != 0) {
    return MGRID_FAIL;
  }

  /* take excluded atom pairs back out */
  if (sys->excl_list != NULL) {
    if (exclusions(mg, sys) != 0) {
      return MGRID_FAIL;
    }
  }

  /*
   * when the short-range forces live in their own array they still have
   * to be added into the total electrostatic force (otherwise the steps
   * above already accumulated straight into f_elec)
   */
  if (f_short != NULL) {
    for (n = 0;  n < natoms;  n++) {
      const MD_Dvec *src = &f_short[n];
      MD_Dvec *dst = &f_elec[n];

      dst->x += src->x;
      dst->y += src->y;
      dst->z += src->z;
    }
  }

  /* sys->u_short was accumulated by the steps above; add it to the total */
  sys->u_elec += sys->u_short;
  return 0;
}


int exclusions(Mgrid *mg, MgridSystem *sys)
{
  const MD_Dvec *p = sys->pos;
  MD_Dvec *f = (sys->f_short ? sys->f_short : sys->f_elec);
  const double *q = sys->charge;
  int32 **excl_list = sys->excl_list;
  int32 **scaled14_list = sys->scaled14_list;
  const double scaling14 = sys->scaling14;
  MD_Dvec r_ij;         /* r_ij=p[j]-p[i] */
  MD_Dvec pj, fj;       /* pj="temp p[j]", fj="temp f[j]" */
  double qj;            /* qj="temp q[j]" */
  double r2, r_2, r_1;  /* r2=||r_ij||^2, r_2=1/r2, r_1=1/r */
  double qc, uc, fc;    /* temp vars to accumulate mults */
  double f_x, f_y, f_z; /* force in x, y, z directions */
  double u = 0.0;       /* accumulate potential energy */
  int32 *excl;
  const int32 natoms = mg->param.natoms;
  int32 i, j;

  ASSERT(excl_list != NULL);

  /* loop over all atoms, process exclusions */
  for (j = 0;  j < natoms;  j++) {
    pj.x = p[j].x;
    pj.y = p[j].y;
    pj.z = p[j].z;
    qj = q[j];
    fj = f[j];

    for (excl = excl_list[j];  *excl < j;  excl++) {
      i = *excl;

      r_ij.x = pj.x - p[i].x;
      r_ij.y = pj.y - p[i].y;
      r_ij.z = pj.z - p[i].z;
      r2 = r_ij.x * r_ij.x + r_ij.y * r_ij.y + r_ij.z * r_ij.z;

      r_2 = 1.0 / r2;
      r_1 = sqrt(r_2);

      qc = q[i] * qj;
      uc = qc * r_1;
      fc = uc * r_2;

      f_x = fc * r_ij.x;
      f_y = fc * r_ij.y;
      f_z = fc * r_ij.z;

      /* subtract this contribution */
      /* (reverse signs from evaluation routine) */
      fj.x -= f_x;
      fj.y -= f_y;
      fj.z -= f_z;
      f[i].x += f_x;
      f[i].y += f_y;
      f[i].z += f_z;
      u -= uc;
    }
    f[j] = fj;
  }

  if (scaled14_list != NULL) {

    /* loop over all atoms, process scaled 1-4 exclusions */
    for (j = 0;  j < natoms;  j++) {
      pj.x = p[j].x;
      pj.y = p[j].y;
      pj.z = p[j].z;
      qj = q[j];
      fj = f[j];

      for (excl = scaled14_list[j];  *excl < j;  excl++) {
        i = *excl;

        r_ij.x = pj.x - p[i].x;
        r_ij.y = pj.y - p[i].y;
        r_ij.z = pj.z - p[i].z;
        r2 = r_ij.x * r_ij.x + r_ij.y * r_ij.y + r_ij.z * r_ij.z;

        r_2 = 1.0 / r2;
        r_1 = sqrt(r_2);

        /* this subtracts full contribution, adds in scaled contribution */
        qc = (1.0 - scaling14) * q[i] * qj;
        uc = qc * r_1;
        fc = uc * r2;

        f_x = fc * r_ij.x;
        f_y = fc * r_ij.y;
        f_z = fc * r_ij.z;

        /* subtract this contribution */
        /* (reverse signs from evaluation routine) */
        fj.x -= f_x;
        fj.y -= f_y;
        fj.z -= f_z;
        f[i].x += f_x;
        f[i].y += f_y;
        f[i].z += f_z;
        u -= uc;
      }
      f[j] = fj;
    }
  }

  /* add in accumulated short-range potential */
  sys->u_short += u;

  return 0;
}



/******************************************************************************
 *
 * Evaluation routine for splittings with only even powers of (r/a).
 *
 *****************************************************************************/

/*
 * Accumulate short-range energy and forces for all atom pairs in the same
 * or in neighboring cells that are closer than the cutoff a.
 *
 * For a pair with squared distance r2, dgamma() is evaluated at
 * s = r2/a^2 to obtain g and dg; the pair then contributes
 *   energy       q_i q_j (1/r - g/a)
 *   force scale  q_i q_j (1/r^3 + 2 dg/a^3)   (multiplied by r_ij)
 * Forces go to sys->f_short if that array exists, otherwise to
 * sys->f_elec; the energy is added to sys->u_short.
 *
 * Always returns 0.
 */
int cell_interactions_even(Mgrid *mg, MgridSystem *sys)
{
  const MgridCell *cell = mg->cell;
  const int32 *next = mg->next;
  const MD_Dvec *pos = sys->pos;
  const double *charge = sys->charge;
  MD_Dvec *force = sys->f_short;
  const double a2 = mg->param.cutoff * mg->param.cutoff;
  const double a_1 = mg->inv_cutoff;
  const double a_2 = a_1 * a_1;
  const double two_a3 = 2.0 * a_2 * a_1;
  const int32 split = mg->param.split;
  const int32 ncells = mg->ncells;
  double energy = 0.0;    /* running sum of pair energies */
  int32 n;

  /* no separate short-range array: accumulate into the total instead */
  if (force == NULL) {
    force = sys->f_elec;
  }

  /* visit every cell ... */
  for (n = 0;  n < ncells;  n++) {
    const MgridCell *home = &cell[n];
    const int32 nnbrs = home->nnbrs;
    const int32 jhead = home->head;
    int32 k;

    /* ... paired with each cell on its neighbor list */
    for (k = 0;  k < nnbrs;  k++) {
      const MD_Dvec shift = home->offset[k];  /* periodic displacement */
      const int32 nbrhead = cell[ home->nbr[k] ].head;
      int32 j;

      /* atoms j of the home cell */
      for (j = jhead;  j != -1;  j = next[j]) {
        MD_Dvec pj, fj;   /* shifted position of j, local copy of force[j] */
        double qj;
        int32 i;

        /* shifting atom j by -shift stands in for shifting atom i by +shift */
        pj.x = pos[j].x - shift.x;
        pj.y = pos[j].y - shift.y;
        pj.z = pos[j].z - shift.z;
        qj = charge[j];
        fj = force[j];

        /*
         * neighbor 0 is the home cell itself: start behind j in the list
         * so that each pair is visited once and j is never paired with j
         */
        i = (k == 0 ? next[j] : nbrhead);

        /* atoms i of the neighbor cell */
        for ( ;  i != -1;  i = next[i]) {
          MD_Dvec r_ij;           /* pj - pos[i] */
          double r2, r_1, r_2;    /* |r_ij|^2, 1/r, 1/r^2 */
          double s, g, dg;        /* s=r2/a2, smoothing value, derivative */
          double qq, upair, fscale;
          double f_x, f_y, f_z;

          r_ij.x = pj.x - pos[i].x;
          r_ij.y = pj.y - pos[i].y;
          r_ij.z = pj.z - pos[i].z;
          r2 = r_ij.x * r_ij.x + r_ij.y * r_ij.y + r_ij.z * r_ij.z;

          /* pairs at or beyond the cutoff contribute nothing */
          if (r2 >= a2) continue;
          ASSERT(i != j);

          r_2 = 1.0 / r2;
          r_1 = sqrt(r_2);

          s = r2 * a_2;
          dgamma(&g, &dg, s, split);

          qq = charge[i] * qj;
          upair = qq * (r_1 - a_1 * g);
          fscale = qq * (r_1 * r_2 + two_a3 * dg);

          f_x = fscale * r_ij.x;
          f_y = fscale * r_ij.y;
          f_z = fscale * r_ij.z;

          /* equal and opposite forces on j and i */
          fj.x += f_x;
          fj.y += f_y;
          fj.z += f_z;
          force[i].x -= f_x;
          force[i].y -= f_y;
          force[i].z -= f_z;
          energy += upair;
        }

        /* write the accumulated force on j back */
        force[j] = fj;
      }
    }
  }

  /* hand the energy of all pairs to the system */
  sys->u_short += energy;

  return 0;
}



/******************************************************************************
 *
 * Evaluation routine for splittings with some odd power of (r/a).
 *
 *****************************************************************************/

/*
 * Accumulate short-range energy and forces for all atom pairs in the same
 * or in neighboring cells that are closer than the cutoff a.
 *
 * For a pair at distance r, dgamma_odd() is evaluated at r_a = r/a to
 * obtain g and dg; the pair then contributes
 *   energy       q_i q_j (1/r - g/a)
 *   force scale  q_i q_j (1/r) (1/r^2 + dg/a^2)   (multiplied by r_ij)
 * Forces go to sys->f_short if that array exists, otherwise to
 * sys->f_elec; the energy is added to sys->u_short.
 *
 * Always returns 0.
 */
int cell_interactions_odd(Mgrid *mg, MgridSystem *sys)
{
  const MgridCell *cell = mg->cell;
  const int32 *next = mg->next;
  const MD_Dvec *pos = sys->pos;
  const double *charge = sys->charge;
  MD_Dvec *force = sys->f_short;
  const double a2 = mg->param.cutoff * mg->param.cutoff;
  const double a_1 = mg->inv_cutoff;
  const double a_2 = a_1 * a_1;
  const int32 split = mg->param.split;
  const int32 ncells = mg->ncells;
  double energy = 0.0;    /* running sum of pair energies */
  int32 n;

  /* no separate short-range array: accumulate into the total instead */
  if (force == NULL) {
    force = sys->f_elec;
  }

  /* visit every cell ... */
  for (n = 0;  n < ncells;  n++) {
    const MgridCell *home = &cell[n];
    const int32 nnbrs = home->nnbrs;
    const int32 jhead = home->head;
    int32 k;

    /* ... paired with each cell on its neighbor list */
    for (k = 0;  k < nnbrs;  k++) {
      const MD_Dvec shift = home->offset[k];  /* periodic displacement */
      const int32 nbrhead = cell[ home->nbr[k] ].head;
      int32 j;

      /* atoms j of the home cell */
      for (j = jhead;  j != -1;  j = next[j]) {
        MD_Dvec pj, fj;   /* shifted position of j, local copy of force[j] */
        double qj;
        int32 i;

        /* shifting atom j by -shift stands in for shifting atom i by +shift */
        pj.x = pos[j].x - shift.x;
        pj.y = pos[j].y - shift.y;
        pj.z = pos[j].z - shift.z;
        qj = charge[j];
        fj = force[j];

        /*
         * neighbor 0 is the home cell itself: start behind j in the list
         * so that each pair is visited once and j is never paired with j
         */
        i = (k == 0 ? next[j] : nbrhead);

        /* atoms i of the neighbor cell */
        for ( ;  i != -1;  i = next[i]) {
          MD_Dvec r_ij;             /* pj - pos[i] */
          double r, r2, r_1, r_2;   /* r, |r_ij|^2, 1/r, 1/r^2 */
          double r_a, g, dg;        /* r/a, smoothing value, derivative */
          double qq, upair, fscale;
          double f_x, f_y, f_z;

          r_ij.x = pj.x - pos[i].x;
          r_ij.y = pj.y - pos[i].y;
          r_ij.z = pj.z - pos[i].z;
          r2 = r_ij.x * r_ij.x + r_ij.y * r_ij.y + r_ij.z * r_ij.z;

          /* pairs at or beyond the cutoff contribute nothing */
          if (r2 >= a2) continue;
          ASSERT(i != j);

          r = sqrt(r2);
          r_1 = 1.0 / r;
          r_2 = r_1 * r_1;

          r_a = r * a_1;
          dgamma_odd(&g, &dg, r_a, split);

          qq = charge[i] * qj;
          upair = qq * (r_1 - a_1 * g);
          fscale = qq * r_1 * (r_2 + a_2 * dg);

          f_x = fscale * r_ij.x;
          f_y = fscale * r_ij.y;
          f_z = fscale * r_ij.z;

          /* equal and opposite forces on j and i */
          fj.x += f_x;
          fj.y += f_y;
          fj.z += f_z;
          force[i].x -= f_x;
          force[i].y -= f_y;
          force[i].z -= f_z;
          energy += upair;
        }

        /* write the accumulated force on j back */
        force[j] = fj;
      }
    }
  }

  /* hand the energy of all pairs to the system */
  sys->u_short += energy;

  return 0;
}



/******************************************************************************
 *
 * Setup routines.
 *
 *****************************************************************************/

/*
 * Sort the atoms into the grid cells.
 *
 * Every cell list is emptied, then each atom n is pushed onto the front of
 * the list of the cell that contains pos[n]: cell[index].head becomes n
 * and next[n] links to the previous head (-1 terminates a list).
 *
 * Returns 0 on success.  Returns MGRID_FAIL if an atom maps to a cell
 * index outside the grid; the cell lists are then only partially built
 * and must not be used.  This check is done at run time (not only through
 * ASSERT) because an out-of-range index would otherwise either link the
 * atom into the wrong cell or write outside the cell array.
 */
int geometric_hashing(Mgrid *mg, const MgridSystem *sys)
{
  MD_Dvec lo = mg->lo;
  const double inv_cellsize = mg->inv_cellsize;
  const MD_Dvec *pos = sys->pos;
  int32 *next = mg->next;
  MgridCell *cell = mg->cell;
  const int32 ncells = mg->ncells;
  const int32 ndimcells = mg->ndimcells;
  const int32 natoms = mg->param.natoms;
  int32 i, j, k, n, index;

  /* clear cells */
  for (n = 0;  n < ncells;  n++) {
    cell[n].head = -1;
    cell[n].cnt = 0;
  }

  /* place each atom in its cell */
  for (n = 0;  n < natoms;  n++) {

    /* determine (i,j,k) cell index from offset to lower grid corner */
    i = (int32) ((pos[n].x - lo.x) * inv_cellsize);
    j = (int32) ((pos[n].y - lo.y) * inv_cellsize);
    k = (int32) ((pos[n].z - lo.z) * inv_cellsize);

    /* atoms are expected to be within cubic domain */
    ASSERT(i >= 0 && i < ndimcells);
    ASSERT(j >= 0 && j < ndimcells);
    ASSERT(k >= 0 && k < ndimcells);

    /* same condition enforced when ASSERT is compiled out */
    if (i < 0 || i >= ndimcells
        || j < 0 || j >= ndimcells
        || k < 0 || k >= ndimcells) {
      return MGRID_FAIL;
    }
    index = (k * ndimcells + j) * ndimcells + i;

    /* insert atom into front of (i,j,k)th cell list */
    ASSERT(index >= 0 && index < ncells);
    next[n] = cell[index].head;
    cell[index].head = n;
    cell[index].cnt++;
  }

  return 0;
}


/*
 * Build the neighbor list of every grid cell.
 *
 * The cell array is zeroed, then for each cell the neighbors are
 * enumerated over a "half shell": the cell itself first, followed by the
 * adjacent cells that come later in (z, y, x) order, so that each pair of
 * adjacent cells is listed by only one of the two.  With periodic
 * boundaries a neighbor index that leaves the grid wraps around to the
 * opposite side and the stored offset records a shift of +/- the domain
 * length in that direction; with nonperiodic boundaries such neighbors are
 * skipped.
 *
 * Always returns 0.
 */
int setup_nbrcells(Mgrid *mg)
{
  const double boxlen = mg->param.length;   /* edge of the cubic domain */
  MgridCell *cell = mg->cell;
  const int32 ncells = mg->ncells;
  const int32 ndimcells = mg->ndimcells;
  const int32 wraps = (mg->param.boundary == MGRID_PERIODIC);
  MD_Dvec shift;        /* periodic displacement of the current neighbor */
  int32 cx, cy, cz;     /* coordinates of the current cell */
  int32 sx, sy, sz;     /* step from the cell to the neighbor: -1, 0, +1 */
  int32 nx, ny, nz;     /* coordinates of the neighbor after wrapping */
  int32 n, nn;          /* flat indices of the cell and of the neighbor */

  /* the caller must have allocated the cells already */
  ASSERT(ndimcells > 0);
  ASSERT(ncells == ndimcells * ndimcells * ndimcells);
  ASSERT(cell != NULL);

  /* start with empty neighbor lists */
  memset(cell, 0, ncells * sizeof(MgridCell));

  for (cz = 0;  cz < ndimcells;  cz++) {
    for (cy = 0;  cy < ndimcells;  cy++) {
      for (cx = 0;  cx < ndimcells;  cx++) {
        MgridCell *home;

        n = (cz * ndimcells + cy) * ndimcells + cx;
        ASSERT(n >= 0 && n < ncells);
        home = &cell[n];

        /* z step: same layer or the layer above */
        for (sz = 0;  sz <= 1;  sz++) {
          nz = cz + sz;
          shift.z = 0.0;
          if (nz == ndimcells) {
            if (!wraps) continue;
            nz = 0;
            shift.z = boxlen;
          }

          /* y step: within the same layer only forward in y */
          for (sy = (sz == 0 ? 0 : -1);  sy <= 1;  sy++) {
            ny = cy + sy;
            shift.y = 0.0;
            if (ny == ndimcells) {
              if (!wraps) continue;
              ny = 0;
              shift.y = boxlen;
            }
            else if (ny == -1) {
              if (!wraps) continue;
              ny = ndimcells - 1;
              shift.y = -boxlen;
            }

            /* x step: within the same row only forward in x */
            for (sx = (sz == 0 && sy == 0 ? 0 : -1);  sx <= 1;  sx++) {
              nx = cx + sx;
              shift.x = 0.0;
              if (nx == ndimcells) {
                if (!wraps) continue;
                nx = 0;
                shift.x = boxlen;
              }
              else if (nx == -1) {
                if (!wraps) continue;
                nx = ndimcells - 1;
                shift.x = -boxlen;
              }

              nn = (nz * ndimcells + ny) * ndimcells + nx;
              ASSERT(nn >= 0 && nn < ncells);

              /* append neighbor and its displacement to this cell's list */
              home->nbr[ home->nnbrs ] = nn;
              home->offset[ home->nnbrs ] = shift;
              home->nnbrs++;
            }
          }
        }

        /* the pair kernels rely on entry 0 being the cell itself */
        ASSERT(cell[n].nbr[0] == n);

      }
    }
  }

  return 0;
}
