/*
 * Copyright (C) 2004-2006 by Wei Wang.  All rights reserved.
 */

/*
 * vdw.c
 *
 * Routines to evaluate cutoff nonbonded forces.
 */


#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <float.h>
#include <string.h>
#include "force.h"
#include "data.h"
#include "helper.h"
#include "vdw.h"
#include "utilities.h"

/*
 * Number of link cells along each of the x, y and z axes; vdw_init()
 * passes it to linkcell_init() as nxcell, nycell and nzcell.
 */
#define NBOX 5

/*
 * prototypes for internal functions
 */

/* Allocate and fill vdw->vdwtable with the per-type-pair vdw coefficients. */
static MD_Errcode build_vdwtable(struct VDW_Tag *vdw, 
				 const MD_Atom_Param *atomprm,
				 const MD_Int natomprms,
				 const MD_Nbfix_Param *nbfixprm,
				 const MD_Int nnbfixprms);

/* Compute all cutoff pair forces into vdw->force and the energy into vdw->energy. */
static void eval_vdw(struct VDW_Tag *vdw);
/* Force/energy of one pair (i, j); accumulates into *f_i, *f_j, returns the pair energy. */
static MD_Double eval_vdwpair(struct VDW_Tag *vdw, MD_Int i, MD_Int j, 
			      MD_Double rlen2,
			      MD_Dvec *r_ij, MD_Dvec *f_i, MD_Dvec *f_j);


/*
 * vdw_init
 *
 * Initialize vdw from init_data:
 *   - copy the run parameters;
 *   - precompute the cutoff and switching constants;
 *   - build the vdw coefficient table;
 *   - allocate the per-atom force array;
 *   - print a short parameter summary;
 *   - set up an NBOX x NBOX x NBOX link-cell grid covering the system,
 *     with its minimum corner at -systemsize/2.
 *
 * Returns OK on success and MD_FAIL on failure.  On failure, every buffer
 * allocated here has already been released and vdw->vdwtable, vdw->force
 * and vdw->linkcell are NULL.
 */
MD_Errcode vdw_init(struct VDW_Tag* vdw, struct VDW_init_Tag *init_data)
{
  struct LinkCell_init_Tag linkcell_init_data;

  /* start from a known state so the failure path can free unconditionally */
  vdw->vdwtable = NULL;
  vdw->force    = NULL;
  vdw->linkcell = NULL;

  vdw->ppos         = init_data->ppos;
  vdw->patom        = init_data->patom;
  vdw->pexcllist    = init_data->pexcllist;
  vdw->pscaled14    = init_data->pscaled14;
  vdw->natomprms    = init_data->natomprms;
  vdw->natoms       = init_data->natoms;
  vdw->systemsize   = init_data->systemsize;
  vdw->cutoff       = init_data->cutoff;
  vdw->switchdist   = init_data->switch_dist;
  vdw->is_switching = init_data->is_switching;

  /* calculate vdw force constants */
  vdw->cutoff2 = vdw->cutoff * vdw->cutoff;
  vdw->inv_cutoff2 = 1.0 / vdw->cutoff2;
  vdw->switchdist2 = vdw->switchdist * vdw->switchdist;
  if (vdw->is_switching) {
    /* sw_denom divides by (cutoff^2 - switchdist^2), which must be positive.
     * Check at run time: an assert alone vanishes under NDEBUG and would let
     * a negative or infinite denominator through. */
    if (!(vdw->cutoff2 > vdw->switchdist2)) {
      fprintf(stderr, "vdw switch distance must be smaller than the cutoff\n");
      return MD_FAIL;
    }
    vdw->sw_denom = 1.0 / (vdw->cutoff2 - vdw->switchdist2);
    vdw->sw_denom = vdw->sw_denom * vdw->sw_denom * vdw->sw_denom;
    vdw->sw_denom_times_four = 4.0 * vdw->sw_denom;
  }

  if (build_vdwtable(vdw, 
		     init_data->data->atomprm, 
		     init_data->data->natomprms,
		     init_data->data->nbfixprm, 
		     init_data->data->nnbfixprms)) {
    fprintf(stderr, "cannot build up vdw table\n");
    goto fail;
  }

  assert(NULL != vdw->pexcllist);  /* by construction */
  vdw->force = my_malloc(vdw->natoms * sizeof(MD_Dvec), "force");
  if (NULL == vdw->force && vdw->natoms > 0) {
    fprintf(stderr, "cannot allocate vdw force array\n");
    goto fail;
  }

  printf("vdw parameters:\n");
  printf("  cutoff distance: %f\n", vdw->cutoff);
  if (vdw->is_switching) {
    printf("  switch distance: %f\n", vdw->switchdist);
  } else {
    printf("  no switching\n");
  }

  linkcell_init_data.nxcell = NBOX;
  linkcell_init_data.nycell = NBOX;
  linkcell_init_data.nzcell = NBOX;
  linkcell_init_data.systemsize = vdw->systemsize;
  MD_vec_mul(vdw->systemsize, -0.5, linkcell_init_data.min);
  linkcell_init_data.cutoff = vdw->cutoff;
  linkcell_init_data.natoms = vdw->natoms;

  vdw->linkcell = my_malloc(sizeof(struct LinkCell_Tag), "linkcell");
  if (NULL == vdw->linkcell
      || linkcell_init(vdw->linkcell, &linkcell_init_data)) {
    fprintf(stderr, "failed to init the linkcell structure\n");
    goto fail;
  }

  return OK;

fail:
  /* NOTE(review): a link cell whose init failed is only freed here, not
   * passed to linkcell_destroy(); this assumes linkcell_init() releases its
   * own partial allocations on failure -- confirm against the linkcell code */
  free(vdw->linkcell);
  vdw->linkcell = NULL;
  free(vdw->force);
  vdw->force = NULL;
  free(vdw->vdwtable);
  vdw->vdwtable = NULL;
  return MD_FAIL;
}



/*
 * vdw_destroy
 *
 * Release everything vdw_init() allocated:
 *   - the force array;
 *   - the vdw coefficient table;
 *   - the link-cell structure, meaning both its contents (linkcell_destroy)
 *     and the LinkCell_Tag itself, which vdw_init() got from my_malloc.
 * The struct is then zeroed, so a second call only frees NULL pointers and
 * skips the link-cell teardown.
 *
 * Returns OK.
 */
MD_Errcode vdw_destroy(struct VDW_Tag *vdw)
{
  free(vdw->force);
  free(vdw->vdwtable);
  if (NULL != vdw->linkcell) {
    linkcell_destroy(vdw->linkcell);
    /* NOTE(review): assumes linkcell_destroy() releases only the contents of
     * the struct (it is initialized in place by linkcell_init()), not the
     * struct itself -- confirm against the linkcell code */
    free(vdw->linkcell);
  }
  memset(vdw, 0, sizeof(struct VDW_Tag)); /* makes a repeated destroy harmless */

  return OK;
}


/*
 * vdw_compute
 *
 * Simple nonbonded cutoff algorithm: the atoms are re-hashed into the link
 * cells from the current positions (vdw->ppos) on every call, and then the
 * vdw forces and total energy are evaluated.
 *
 * Returns MD_FAIL if linkcell_hash_atoms() reports an error (no forces are
 * evaluated in that case), OK otherwise.
 */
MD_Errcode vdw_compute(struct VDW_Tag *vdw)
{
  MD_Errcode status = MD_FAIL;

  if (0 == linkcell_hash_atoms(vdw->linkcell, vdw->ppos)) {
    eval_vdw(vdw);
    status = OK;
  }
  return status;
}

MD_Double vdw_get_energy(const struct VDW_Tag *vdw)
{
  return vdw->energy;
}


/*
 * vdw_get_force
 *
 * Return a read-only view of the per-atom vdw force array owned by vdw
 * (vdw->natoms entries, filled by the most recent force evaluation).
 */
const MD_Dvec *vdw_get_force(const struct VDW_Tag *vdw)
{
  const MD_Dvec *forces = vdw->force;
  return forces;
}



/*
 * eval_vdw
 *
 * Evaluate the vdw forces and energy over the link-cell lists:
 *   - within each cell, every atom is paired with the atoms that follow it
 *     in that cell's list;
 *   - every atom of a cell is also paired with every atom of that cell's
 *     neighbor cells (cell[k].nbr[0 .. numnbrs-1]).
 * Each separation vector r_j - r_i is corrected with SIMPLE_BOUND_VEC using
 * the system size.  Only pairs with r^2 < cutoff^2 are passed to
 * eval_vdwpair().
 *
 * Side effects: vdw->force is zeroed and then accumulated, and the summed
 * pair energies are stored in vdw->energy.
 */
static void eval_vdw(struct VDW_Tag *vdw)
{
  const struct Cell_Tag *cells = vdw->linkcell->cell;
  const MD_Int *first = vdw->linkcell->head;   /* first atom of each cell */
  const MD_Int *next = vdw->linkcell->list;    /* next atom in same cell, <0 ends */
  const MD_Int ncells = vdw->linkcell->numcell;
  const MD_Int n = vdw->natoms;
  const MD_Dvec *pos = vdw->ppos;
  const MD_Dvec box = vdw->systemsize;
  const MD_Double rc2 = vdw->cutoff2;
  MD_Dvec *force = vdw->force;
  MD_Dvec dr;
  MD_Double dist2;
  MD_Double energy = 0.0;
  MD_Int a, b, c, m;

  /* clear the force accumulator */
  for (a = 0; a < n; a++) {
    force[a].x = force[a].y = force[a].z = 0.0;
  }

  for (c = 0; c < ncells; c++) {
    for (a = first[c]; a >= 0; a = next[a]) {
      /* partners that come later in the same cell's list */
      for (b = next[a]; b >= 0; b = next[b]) {
        MD_vec_substract(pos[b], pos[a], dr);
        SIMPLE_BOUND_VEC(dr, box); /* it can be only one cell */
        dist2 = MD_vec_dot(dr, dr);
        if (dist2 < rc2) {
          energy += eval_vdwpair(vdw, a, b, dist2, &dr, force + a, force + b);
        }
      }
      /* every atom of each neighbor cell */
      for (m = 0; m < cells[c].numnbrs; m++) {
        for (b = first[cells[c].nbr[m]]; b >= 0; b = next[b]) {
          MD_vec_substract(pos[b], pos[a], dr);
          SIMPLE_BOUND_VEC(dr, box);
          dist2 = MD_vec_dot(dr, dr);
          if (dist2 < rc2) {
            energy += eval_vdwpair(vdw, a, b, dist2, &dr, force + a, force + b);
          }
        }
      }
    }
  }

  vdw->energy = energy;
}



/*
 * eval_vdwpair
 *
 * Evaluate the vdw interaction between atoms i and j and accumulate the
 * resulting forces into *f_i and *f_j.
 *          
 *     U(r) =  A / r^12 - B / r^6
 *
 * A and B are entry[ATERM] and entry[BTERM] of vdw->vdwtable for the
 * (atom[i].prm, atom[j].prm) type pair.  When vdw->is_switching is set and
 * r^2 > switchdist^2, both U and the force are smoothly switched off
 * (see the switching comment below).
 *
 * Parameters:
 *   vdw       vdw state (table, exclusion lists, cutoff/switch constants)
 *   i, j      atom indices
 *   rlen2     squared separation |r_ij|^2; the pair is assumed to be within
 *             the cutoff distance (rlen2 < cutoff^2)
 *   r_ij      separation vector, r_j - r_i
 *   f_i, f_j  force accumulators for atoms i and j, updated in place
 *
 * Returns the pair energy.  Returns 0.0, with no force update, when j is
 * found in atom i's exclusion list.
 *
 * Calls exit(1) if a scaled 1-4 list (vdw->pscaled14) is present, because
 * scaled 1-4 interactions are not handled yet.
 */
MD_Double eval_vdwpair(struct VDW_Tag *vdw, MD_Int i, MD_Int j, 
		       MD_Double rlen2,
		       MD_Dvec *r_ij, MD_Dvec *f_i, MD_Dvec *f_j)
{
  const MD_Atom *atom = vdw->patom;
  const MD_Double cutoff2 = vdw->cutoff2;
  const MD_Double switchdist2 = vdw->switchdist2; 
  const MD_Int *excl;
  const MD_Double *entry;
  MD_Int **excllist = vdw->pexcllist; /* guaranteed not NULL (asserted in vdw_init) */
  MD_Int **scaled14 = vdw->pscaled14;
  const MD_Int natomprms = vdw->natomprms;
  MD_Dvec f_nb;
  MD_Double inv_rlen2, inv_rlen6, inv_rlen12;
  MD_Double aterm_vdw = 0.0, bterm_vdw = 0.0, ffac_vdw;
  MD_Double swa, swb, swv, dswv;
  MD_Double evdw;

#ifdef DEBUG_VDW
  printf("evaluate vdw force between %d and %d\n", i,j);
#endif

  /* check to see if this pair is excluded: walk atom i's exclusion list
   * until reaching an index >= j; the pair is excluded if that index is j.
   * NOTE(review): this relies on each list being sorted ascending and ending
   * with a sentinel larger than any atom index -- confirm where the lists
   * are built */
  if (NULL != excllist) {
    excl = excllist[i];
    while (*excl < j) {
      excl++;
    }
    if (j == *excl) return 0.0; /* excluded */ 
  }

  /* compute simple functions of rlen2: 1/r^2, 1/r^6, 1/r^12 */
  inv_rlen2 = 1.0 / rlen2;
  inv_rlen6 = inv_rlen2 * inv_rlen2 * inv_rlen2;
  inv_rlen12 = inv_rlen6 * inv_rlen6;

  /* choose correct van der Waals parameters: the table entry for the
   * (atom[i].prm, atom[j].prm) type pair, 4 coefficients per entry */
  entry = vdw->vdwtable + 4 * (atom[i].prm * natomprms 
	        	       + atom[j].prm);

  /* determine whether this is a 1-4 interaction; only the case without a
   * scaled 1-4 list is supported, otherwise the program is terminated */
  if (scaled14 == NULL) {  
    aterm_vdw = entry[ATERM] * inv_rlen12;
    bterm_vdw = entry[BTERM] * inv_rlen6;
  } else {
    fprintf(stderr, "error, should not have scaled14\n");
    exit(1);
  }
#if 0  /* do not deal with scaled 1-4 yet */
 else {  
    for (excl = scaled14[i];  *excl < j;  excl++) ;
    if (j != *excl) {
      aterm_vdw = entry[ATERM] * inv_rlen12;
      bterm_vdw = entry[BTERM] * inv_rlen6;
    }
    else { 
      aterm_vdw = entry[A_14] * inv_rlen12;
      bterm_vdw = entry[B_14] * inv_rlen6;
      c_elec *= epsilon14;
    }
  }
#endif

  /* van der Waals computations */
  evdw = aterm_vdw - bterm_vdw;   /* v(r) */
  ffac_vdw = (12.0 * aterm_vdw - 6.0 * bterm_vdw) * inv_rlen2; /* -v'(r)/r */
  if (vdw->is_switching && rlen2 > switchdist2) {
    /* when r_s < r < r_c (in the switching region), the potential is modified
     * from v(r) --> v(r) * s(r), where v(r) is the vdw potential and s(r) is
     * the switching function:
     *             (r_c^2 - r^2)^2 * (r_c^2 + 2*r^2 - 3*r_s^2)
     *     s(r) = ----------------------------------------------
     *                          (r_c^2 - r_s^2)^3
     * note that:
     *     s(r_s) = 1,   s(r_c) = 0,   s'(r_s) = 0,   s'(r_c) = 0
     * the force is modified accordingly.
     * vdw->sw_denom holds 1/(r_c^2 - r_s^2)^3 and sw_denom_times_four is
     * 4 times that, precomputed in vdw_init.
     */
    swa = cutoff2 - rlen2;
    swb = swa * (cutoff2 + 2.0 * rlen2 - 3.0 * switchdist2);
    swv = swa * swb * vdw->sw_denom;  /* switch function: s(r) */
    dswv = (swa * swa - swb) * vdw->sw_denom_times_four; /* s'(r) / r */
    ffac_vdw = ffac_vdw * swv - evdw * dswv;   /* -(v'*s+v*s')/r */
    evdw *= swv;   
  }

  /* find and accumulate nonbonded forces */
  /* note that r_ij = r_j - r_i, so atom i receives -f_nb and atom j +f_nb */
  MD_vec_mul((*r_ij), ffac_vdw, f_nb);  /* force = (-v'/r)  * vector(r) */
  MD_pvec_substract(f_i, &f_nb, f_i);
  MD_pvec_add(f_j, &f_nb, f_j);

  return evdw;
}


/*
 * build the van der Waals parameter table
 *
 * table is a "square" symmetric matrix, dimension (natomprms * natomprms)
 * each entry of matrix contains A, B, A_14, B_14 parameters
 * matrix is indexed by atom "types" (0..natomprms-1)
 *
 * V_vdw(r) = ATERM / r^12 - BTERM / r^6
 *
 * vdwtable is stored as one-dimensional array
 * index for (i,j) atom pair interaction is:  4 * (i * natomprms + j)
 */
MD_Errcode build_vdwtable(struct VDW_Tag *vdw, 
			  const MD_Atom_Param *atomprm,
			  const MD_Int natomprms,
			  const MD_Nbfix_Param *nbfixprm,
			  const MD_Int nnbfixprms)
{
  MD_Double *vdwtable, *ij_entry, *ji_entry;
  MD_Double neg_emin, rmin, neg_emin14, rmin14;
  MD_Int  i, j, k;

  vdwtable = calloc((size_t) (4 * natomprms * natomprms), sizeof(MD_Double));
  if (vdwtable == NULL) {
    printf("build_vdwtable call to calloc");
    return MD_ERR_MEMALLOC;
  }

  /* compute each table entry given separate i and j atom params */
  for (i = 0;  i < natomprms;  i++) {
    for (j = i;  j < natomprms;  j++) {
      ij_entry = vdwtable + 4 * (i * natomprms + j);
      ji_entry = vdwtable + 4 * (j * natomprms + i);

      /* compute vdw A and B coefficients for atom type ij interaction */
      neg_emin = sqrt(atomprm[i].emin * atomprm[j].emin);
      rmin = 0.5 * (atomprm[i].rmin + atomprm[j].rmin);
      neg_emin14 = sqrt(atomprm[i].emin14 * atomprm[j].emin14);
      rmin14 = 0.5 * (atomprm[i].rmin14 + atomprm[j].rmin14);

      /* raise rmin and rmin14 to 6th power */
      rmin *= rmin * rmin;
      rmin *= rmin;
      rmin14 *= rmin14 * rmin14;
      rmin14 *= rmin14;

      /* set ij entry and its transpose */
      ij_entry[ATERM]= ji_entry[ATERM]= neg_emin * rmin * rmin;
      ij_entry[BTERM]= ji_entry[BTERM]= 2.0 * neg_emin * rmin;
      ij_entry[A_14] = ji_entry[A_14] = neg_emin14 * rmin14 * rmin14;
      ij_entry[B_14] = ji_entry[B_14] = 2.0 * neg_emin14 * rmin14;
    }
  }

  /* now go back and update entries for nbfix params */
  for (k = 0;  k < nnbfixprms;  k++) {
    i = nbfixprm[k].prm[0];
    j = nbfixprm[k].prm[1];

    ij_entry = vdwtable + 4 * (i * natomprms + j);
    ji_entry = vdwtable + 4 * (j * natomprms + i);

    /* compute vdw A and B coefficients for this fixed type interaction */
    neg_emin = -nbfixprm[k].emin;
    rmin = nbfixprm[k].rmin;
    neg_emin14 = -nbfixprm[k].emin14;
    rmin14 = nbfixprm[k].rmin14;

    /* raise rmin and rmin14 to 6th power */
    rmin *= rmin * rmin;
    rmin *= rmin;
    rmin14 *= rmin14 * rmin14;
    rmin14 *= rmin14;

    /* set ij entry and its transpose */
    ij_entry[ATERM]= ji_entry[ATERM]= neg_emin * rmin * rmin;
    ij_entry[BTERM]= ji_entry[BTERM]= 2.0 * neg_emin * rmin;
    ij_entry[A_14] = ji_entry[A_14] = neg_emin14 * rmin14 * rmin14;
    ij_entry[B_14] = ji_entry[B_14] = 2.0 * neg_emin14 * rmin14;
  }

  vdw->vdwtable = vdwtable;
  return OK;
}


