/*
 * Copyright (C) 2004-2006 by Wei Wang.  All rights reserved.
 */

/* standard implementation of the Ewald sum */


#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <float.h>
#include <string.h>
#include <assert.h>
#include "timer.h"
#include "unit.h"
#include "constant.h"
#include "utilities.h"
#include "helper.h"
#include "standEwald_rec.h"
#include "data.h"
#include "force.h"


/*
 * Set up the reciprocal-space part of the standard Ewald sum.
 *
 * Reads a1, a2, a3, beta, kcut, natoms and has_induced_dipole from *se and
 * fills in:
 *   b1, b2, b3       reciprocal basis, dot(a_i, b_j) = 2*pi*delta_ij
 *   kmax[0..2]       largest integer multiple of each b_i that is scanned
 *   reclatt, recU,   the retained reciprocal vectors (0 < |k| <= kcut, only
 *   nreclatt         one of each +k/-k pair) and their weights
 *   recforce, neg_recG1q, coskr, sinkr, reS, imS, work
 *                    work arrays obtained from my_calloc
 *   recg2d, recg2q   allocated only when has_induced_dipole is set
 *
 * Always returns OK.
 * NOTE(review): a degenerate cell (zero triple product) divides by zero
 * below; callers are presumably expected to pass a valid cell -- confirm.
 */
MD_Errcode init_rec(struct standEwald_Tag *se)
{
  const MD_Dvec a1 = se->a1;
  const MD_Dvec a2 = se->a2;
  const MD_Dvec a3 = se->a3;
  MD_Dvec b1, b2, b3;
  const MD_Double beta = se->beta;
  const MD_Double kcut = se->kcut;
  MD_Dvec *reclatt;
  MD_Double *recU; /* reciprocal potential */
  MD_Double vol;   /* signed triple product; made non-negative further down */
  const MD_Double symfac = 2.0;  /* each stored k also accounts for -k */
  const MD_Int natoms = se->natoms;
  MD_Int k1max, k2max, k3max;
  MD_Int i,j,k;
  MD_Int nreclatt;
  size_t nscan;    /* number of lattice points in the scanned half box */
  size_t ntrig;    /* entries in each of the coskr / sinkr tables */

  se->recforce = my_calloc((size_t)natoms, sizeof(*se->recforce),
                           "reciprocal force");
  se->neg_recG1q = my_calloc((size_t)natoms*3, sizeof(*se->neg_recG1q),
                             "-recG1q");

  vol = (a1.x) * ((a2.y) * (a3.z) - (a2.z) * (a3.y))
      + (a1.y) * ((a2.z) * (a3.x) - (a2.x) * (a3.z))
      + (a1.z) * ((a2.x) * (a3.y) - (a2.y) * (a3.x));

  /* compute the reciprocal vectors, dot(a_i, b_j) = 2*pi*delta_ij;
   * the signed vol is needed here so that the orientation is right */
  b1.x = (a2.y * a3.z - a2.z * a3.y) / vol * twoPi;
  b1.y = (a2.z * a3.x - a2.x * a3.z) / vol * twoPi;
  b1.z = (a2.x * a3.y - a2.y * a3.x) / vol * twoPi;
  b2.x = (a3.y * a1.z - a3.z * a1.y) / vol * twoPi;
  b2.y = (a3.z * a1.x - a3.x * a1.z) / vol * twoPi;
  b2.z = (a3.x * a1.y - a3.y * a1.x) / vol * twoPi;
  b3.x = (a1.y * a2.z - a1.z * a2.y) / vol * twoPi;
  b3.y = (a1.z * a2.x - a1.x * a2.z) / vol * twoPi;
  b3.z = (a1.x * a2.y - a1.y * a2.x) / vol * twoPi;
  se->b1 = b1;
  se->b2 = b2;
  se->b3 = b3;

  vol = vol < 0 ? (-vol) : vol;

  se->kmax[0] = k1max = (MD_Int) (kcut / MD_vecLen(b1));
  se->kmax[1] = k2max = (MD_Int) (kcut / MD_vecLen(b2));
  se->kmax[2] = k3max = (MD_Int) (kcut / MD_vecLen(b3));

  /* Form the element count in size_t: the product of the three extents can
   * exceed the range of MD_Int even when each extent is small. */
  nscan = (size_t)(k1max+1) * (size_t)(2*k2max+1) * (size_t)(2*k3max+1);
  reclatt = my_calloc(nscan, sizeof(*reclatt), "reclatt");
  recU = my_calloc(nscan, sizeof(*recU), "recU");

  nreclatt = 0;
  for (i = 0; i <= k1max; i++) {   /* use reversal symmetry */
    /* the lower bounds skip k = 0 and the mirror image of every vector
     * that has already been visited */
    for (j = (0==i)?0:-k2max; j <= k2max; j++) {
      for (k = (0==i&&0==j)?1:-k3max; k <= k3max; k++) {
        MD_Dvec rec;
        MD_Double reclensqr;
        rec.x = i*b1.x + j*b2.x + k*b3.x;
        rec.y = i*b1.y + j*b2.y + k*b3.y;
        rec.z = i*b1.z + j*b2.z + k*b3.z;
        reclensqr = MD_vec_dot(rec, rec);
        if (reclensqr > kcut*kcut) continue;   /* outside the cutoff sphere */
        reclatt[nreclatt] = rec;
        recU[nreclatt] = symfac*4.0*Pi * exp(-reclensqr/(4.0*beta*beta))
                         / (reclensqr*vol);
        nreclatt ++;
      }
    }
  }
  se->nreclatt = nreclatt;

  /* Trim both arrays to the number of vectors actually kept.  A failed
   * shrink leaves the original (larger) buffer valid, so it is simply kept.
   * realloc(p, 0) is avoided because its result is implementation-defined
   * and may be NULL. */
  if (nreclatt > 0) {
    MD_Dvec *trimmed_latt = realloc(reclatt,
                                    (size_t)nreclatt * sizeof(*reclatt));
    MD_Double *trimmed_U;
    if (NULL != trimmed_latt) reclatt = trimmed_latt;
    trimmed_U = realloc(recU, (size_t)nreclatt * sizeof(*recU));
    if (NULL != trimmed_U) recU = trimmed_U;
  }
  se->reclatt = reclatt;
  se->recU = recU;
  assert(NULL != se->reclatt && NULL != se->recU);
  printf("  nreclatt = %d,  ", se->nreclatt);
  printf("  Kmax = (%d, %d, %d)\n", k1max, k2max, k3max);

  /* nreclatt * natoms is formed in size_t for the same overflow reason */
  ntrig = (size_t)nreclatt * (size_t)natoms;
  se->coskr = my_calloc(ntrig, sizeof(MD_Double), "coskr");
  se->sinkr = my_calloc(ntrig, sizeof(MD_Double), "sinkr");
  se->reS = my_calloc((size_t)se->nreclatt, sizeof(MD_Double),"reS");
  se->imS = my_calloc((size_t)se->nreclatt, sizeof(MD_Double),"imS");
  if (se->has_induced_dipole) {
    se->recg2d = my_calloc((size_t)natoms*3, sizeof(MD_Double), "recg2d");
    se->recg2q = my_calloc((size_t)natoms*NIND2, sizeof(MD_Double),
                           "recg2q");
  }

  se->work = my_calloc((size_t)natoms*3, sizeof(MD_Double), "work");

  return OK;
}


/*
 * Release every array that init_rec attached to *se.
 *
 * Each pointer is reset to NULL after it is freed, so a repeated call is
 * harmless (free(NULL) is a no-op) and stale pointers cannot be reused by
 * accident.  recg2d / recg2q are touched only when has_induced_dipole is
 * set, because init_rec allocates them only in that case.
 *
 * Always returns OK.
 */
MD_Errcode destroy_rec(struct standEwald_Tag *se)
{
  free(se->recforce);
  se->recforce = NULL;
  free(se->neg_recG1q);
  se->neg_recG1q = NULL;
  free(se->reclatt);
  se->reclatt = NULL;
  free(se->recU);
  se->recU = NULL;
  free(se->coskr);
  se->coskr = NULL;
  free(se->sinkr);
  se->sinkr = NULL;
  free(se->reS);
  se->reS = NULL;
  free(se->imS);
  se->imS = NULL;
  free(se->work);
  se->work = NULL;

  if (se->has_induced_dipole) {
    free(se->recg2d);
    se->recg2d = NULL;
    free(se->recg2q);
    se->recg2q = NULL;
  }

  return OK;
}


/*
 * Return the reciprocal-space cutoff 2*s*beta, where s is located by
 * bisection so that exp(-s*s) is (to rounding) equal to errTol.
 *
 *   beta    scale factor applied to the result
 *   errTol  target value for exp(-s*s)
 *
 * The upper bracket starts at 4 and is doubled until exp(-high^2) no longer
 * exceeds errTol.  A negative errTol can never be reached; in that case the
 * doubling stops as soon as exp() has underflowed to zero, instead of
 * looping forever, and the bisection converges to that upper bracket.
 */
MD_Double calc_kcut(MD_Double beta, MD_Double errTol)
{
  MD_Double low = 0.0, high = 4.0, s = 4.0;
  MD_Int i;

  for (;;) {
    const MD_Double tail = exp(-high*high);
    if (!(tail > errTol)) break;   /* bracket found (also taken for NaN) */
    if (0.0 == tail) break;        /* underflow: only if errTol < 0 */
    high += high;
  }

  /* 100 halvings shrink the bracket well below double precision */
  for (i = 0; i < 100; i++) {
    s = 0.5 * (high + low);
    if (exp(-s*s) > errTol) low = s;
    else high = s;
  }

  return 2.0*s*beta;
}


/*
 * Reciprocal-space setup, written with the MD_vec_* helpers.
 * For every stored reciprocal vector k:
 *   1. tabulate cos(k.r_i) and sin(k.r_i) for all atoms in se->coskr and
 *      se->sinkr (one row of natoms entries per k),
 *   2. store the charge structure factor in se->reS[k] and se->imS[k],
 *   3. add the charge--charge contribution of k to the energy,
 *   4. add the contribution of k to se->neg_recG1q (3 entries per atom).
 * Afterwards se->recEnergy holds half the accumulated energy sum and
 * se->recforce[i] is overwritten with charge[i] times the i-th 3-vector of
 * se->neg_recG1q.
 */
void dipole_rec_setup(struct standEwald_Tag *se)
{
  const MD_Int natoms = se->natoms;
  const MD_Int nreclatt = se->nreclatt;
  const MD_Dvec *pos = se->ppos;
  const MD_Dvec *reclatt = se->reclatt;
  const MD_Double *charge = se->charge;
  const MD_Double *recU = se->recU;
  MD_Dvec *force = se->recforce;
  MD_Double *negG1q = se->neg_recG1q;
  MD_Double *cosrow = se->coskr;   /* row of the cos table for current k */
  MD_Double *sinrow = se->sinkr;   /* row of the sin table for current k */
  MD_Double energy = 0.0;
  MD_Int atom, kidx;

  memset(negG1q, 0, 3*natoms*sizeof(*negG1q));

  for (kidx = 0; kidx < nreclatt; kidx++) {
    MD_Dvec kvec = reclatt[kidx];
    MD_Dvec kre, kim;              /* k scaled by weight*Re(S), weight*Im(S) */
    MD_Double sre = 0.0, sim = 0.0;
    MD_Double ure, uim;

    /* fill the trig tables and build the structure factor */
    for (atom = 0; atom < natoms; atom++) {
      const MD_Double phase = MD_vec_dot(kvec, pos[atom]);
      const MD_Double c = cos(phase);
      const MD_Double s = sin(phase);
      cosrow[atom] = c;
      sinrow[atom] = s;
      sre += charge[atom] * c;
      sim += charge[atom] * s;
    }
    se->reS[kidx] = sre;
    se->imS[kidx] = sim;

    /* reciprocal charge-charge energy */
    energy += recU[kidx] * (sre*sre + sim*sim);

    ure = recU[kidx] * sre;
    uim = recU[kidx] * sim;
    MD_vec_mul(kvec, ure, kre);
    MD_vec_mul(kvec, uim, kim);
    for (atom = 0; atom < natoms; atom++) {
      negG1q[3*atom+X] += kre.x*sinrow[atom] - kim.x*cosrow[atom];
      negG1q[3*atom+Y] += kre.y*sinrow[atom] - kim.y*cosrow[atom];
      negG1q[3*atom+Z] += kre.z*sinrow[atom] - kim.z*cosrow[atom];
    }

    /* step to the table rows of the next reciprocal vector */
    cosrow += natoms;
    sinrow += natoms;
  }

  se->recEnergy = energy * 0.5;
#ifdef DEBUG_STANDEWALD
  printf("rec q-q energy = %f\n", se->recEnergy);
#endif

  for (atom = 0; atom < natoms; atom++) {
    force[atom].x = charge[atom] * negG1q[3*atom+X];
    force[atom].y = charge[atom] * negG1q[3*atom+Y];
    force[atom].z = charge[atom] * negG1q[3*atom+Z];
  }
}


/*
 * Add the reciprocal-space contribution of the matrix-vector product G2 * v
 * to G2v.  Both v and G2v hold 3 entries per atom.  G2v is only incremented,
 * never cleared, so the caller must have initialized it.
 * Uses the cos/sin tables stored in se->coskr and se->sinkr.
 */
void add_recG2v(const struct standEwald_Tag *se, const MD_Double *v,
                MD_Double *G2v)
{
  const MD_Int natoms = se->natoms;
  const MD_Int nreclatt = se->nreclatt;
  const MD_Dvec *reclatt = se->reclatt;
  const MD_Double *recU = se->recU;
  const MD_Double *cosrow = se->coskr;   /* table row for the current k */
  const MD_Double *sinrow = se->sinkr;
  MD_Int kidx, atom;

  for (kidx = 0; kidx < nreclatt; kidx++) {
    const MD_Double kx = reclatt[kidx].x;
    const MD_Double ky = reclatt[kidx].y;
    const MD_Double kz = reclatt[kidx].z;
    MD_Double cx = 0.0, cy = 0.0, cz = 0.0;   /* sum_i v_i cos(k.r_i) */
    MD_Double sx = 0.0, sy = 0.0, sz = 0.0;   /* sum_i v_i sin(k.r_i) */
    MD_Double kdotc, kdots;
    MD_Double wkx, wky, wkz;

    for (atom = 0; atom < natoms; atom++) {
      const MD_Double *va = v + 3*atom;
      cx += va[X] * cosrow[atom];
      cy += va[Y] * cosrow[atom];
      cz += va[Z] * cosrow[atom];
      sx += va[X] * sinrow[atom];
      sy += va[Y] * sinrow[atom];
      sz += va[Z] * sinrow[atom];
    }

    kdotc = cx*kx + cy*ky + cz*kz;
    kdots = sx*kx + sy*ky + sz*kz;
    wkx = recU[kidx] * kx;
    wky = recU[kidx] * ky;
    wkz = recU[kidx] * kz;

    for (atom = 0; atom < natoms; atom++) {
      MD_Double *out = G2v + 3*atom;
      const MD_Double mix = cosrow[atom]*kdotc + sinrow[atom]*kdots;
      out[X] += wkx * mix;
      out[Y] += wky * mix;
      out[Z] += wkz * mix;
    }

    cosrow += natoms;
    sinrow += natoms;
  }
}


/*
 * Update pseudores (3 entries per atom) with the reciprocal-space terms:
 *   flag != 0 :  pseudores += neg_recG1q - recG2 * v
 *   flag == 0 :  pseudores -= recG2 * v
 * The product recG2 * v is built in the scratch array se->work, which is
 * overwritten.
 */
void add_rec_pseudores(struct standEwald_Tag *se,
                       const MD_Double *v, MD_Int flag,
                       MD_Double *pseudores)
{
  MD_Double *g2v = se->work;
  const MD_Double *negG1q = se->neg_recG1q;
  const MD_Int len = se->natoms * 3;
  MD_Int idx;

  ASSERT(NULL != g2v);

  /* add_recG2v only accumulates, so start from a zeroed scratch vector */
  memset(g2v, 0, len*sizeof(*g2v));
  add_recG2v(se, v, g2v);

  if (!flag) {
    for (idx = 0; idx < len; idx++) pseudores[idx] -= g2v[idx];
    return;
  }

  for (idx = 0; idx < len; idx++) pseudores[idx] += negG1q[idx] - g2v[idx];
}


/*
 * Compute the reciprocal contribution to sum_j G2_{i alpha, j beta} * q_j
 * and store it in se->recg2q.
 * The output holds NIND2 entries per atom; it is cleared first and then the
 * six components XX, XY, XZ, YY, YZ, ZZ of every atom are accumulated over
 * all stored reciprocal vectors, using the structure factor (se->reS,
 * se->imS) and the cos/sin tables (se->coskr, se->sinkr).
 */
void compute_recG2q(struct standEwald_Tag *se)
{
  const MD_Int natoms = se->natoms;
  const MD_Int nreclatt = se->nreclatt;
  const MD_Int nout = NIND2 * natoms;
  const MD_Dvec *reclatt = se->reclatt;
  const MD_Double *recU = se->recU;
  const MD_Double *reS = se->reS;
  const MD_Double *imS = se->imS;
  const MD_Double *cosrow = se->coskr;   /* table row for the current k */
  const MD_Double *sinrow = se->sinkr;
  MD_Double * const out = se->recg2q;
  MD_Int kidx, atom, idx;

  for (idx = 0; idx < nout; idx++) out[idx] = 0.0;

  for (kidx = 0; kidx < nreclatt; kidx++) {
    const MD_Double w = recU[kidx];
    const MD_Double kx = reclatt[kidx].x;
    const MD_Double ky = reclatt[kidx].y;
    const MD_Double kz = reclatt[kidx].z;
    /* weighted outer product k k^T, upper triangle only */
    const MD_Double wxx = w * kx * kx;
    const MD_Double wxy = w * kx * ky;
    const MD_Double wxz = w * kx * kz;
    const MD_Double wyy = w * ky * ky;
    const MD_Double wyz = w * ky * kz;
    const MD_Double wzz = w * kz * kz;
    const MD_Double sre = reS[kidx];
    const MD_Double sim = imS[kidx];

    for (atom = 0; atom < natoms; atom++) {
      MD_Double *row = out + NIND2*atom;
      const MD_Double phase = cosrow[atom]*sre + sinrow[atom]*sim;
      row[XX] += wxx * phase;
      row[XY] += wxy * phase;
      row[XZ] += wxz * phase;
      row[YY] += wyy * phase;
      row[YZ] += wyz * phase;
      row[ZZ] += wzz * phase;
    }

    cosrow += natoms;
    sinrow += natoms;
  }
}


/*
 * Add the reciprocal-space force from dipole--dipole interactions to
 * se->recforce.  The dipoles come from dsolver_get_dipole(se->dsolver) and
 * are read as 3 consecutive values per atom; the cos/sin tables in
 * se->coskr and se->sinkr are used as stored.
 */
void add_recddforce(struct standEwald_Tag* se)
{
  const MD_Int natoms = se->natoms;
  const MD_Int nreclatt = se->nreclatt;
  const MD_Dvec * const reclatt = se->reclatt;
  const MD_Double * const recU = se->recU;
  const MD_Double * const dipole = dsolver_get_dipole(se->dsolver);
  const MD_Double *cosrow = se->coskr;   /* table row for the current k */
  const MD_Double *sinrow = se->sinkr;
  MD_Dvec * const force = se->recforce;
  MD_Int kidx, atom;

  for (kidx = 0; kidx < nreclatt; kidx++) {
    const MD_Double kx = reclatt[kidx].x;
    const MD_Double ky = reclatt[kidx].y;
    const MD_Double kz = reclatt[kidx].z;
    const MD_Double wkx = recU[kidx] * kx;
    const MD_Double wky = recU[kidx] * ky;
    const MD_Double wkz = recU[kidx] * kz;
    MD_Double dcx = 0.0, dcy = 0.0, dcz = 0.0;   /* sum_i d_i cos(k.r_i) */
    MD_Double dsx = 0.0, dsy = 0.0, dsz = 0.0;   /* sum_i d_i sin(k.r_i) */
    MD_Double kdotdcos, kdotdsin;

    for (atom = 0; atom < natoms; atom++) {
      const MD_Double *da = dipole + 3*atom;
      dcx += da[0] * cosrow[atom];
      dcy += da[1] * cosrow[atom];
      dcz += da[2] * cosrow[atom];
      dsx += da[0] * sinrow[atom];
      dsy += da[1] * sinrow[atom];
      dsz += da[2] * sinrow[atom];
    }

    kdotdcos = kx*dcx + ky*dcy + kz*dcz;
    kdotdsin = kx*dsx + ky*dsy + kz*dsz;

    for (atom = 0; atom < natoms; atom++) {
      const MD_Double *da = dipole + 3*atom;
      /* the cos-weighted sum pairs with sin(k.r_i) and vice versa */
      const MD_Double pair =
          (kdotdcos * sinrow[atom] - kdotdsin * cosrow[atom])
        * (da[0]*kx + da[1]*ky + da[2]*kz);
      force[atom].x += wkx * pair;
      force[atom].y += wky * pair;
      force[atom].z += wkz * pair;
    }

    cosrow += natoms;
    sinrow += natoms;
  }
}


/*
 * Add the dipole-related reciprocal contributions to se->recforce and
 * se->recEnergy:
 *   - charge--dipole force, from recG2 * d (built in se->recg2d) and from
 *     se->recg2q (refreshed by compute_recG2q),
 *   - dipole--dipole force, through add_recddforce,
 *   - energy: 0.5 * dot(d, neg_recG1q) is subtracted from se->recEnergy.
 * The dipoles d come from dsolver_get_dipole(se->dsolver), 3 per atom.
 */
void compute_rec_dipole_force(struct standEwald_Tag *se)
{
  const MD_Int natoms = se->natoms;
  const MD_Int nvec = 3*natoms;
  const MD_Double *dipole = dsolver_get_dipole(se->dsolver);
  const MD_Double *charge = se->charge;
  MD_Dvec * const force = se->recforce;
  MD_Double *g2d = se->recg2d;
  MD_Double *g2q = se->recg2q;
  MD_Int atom, idx;

  /* compute charge--dipole force. */
  for (idx = 0; idx < nvec; idx++) g2d[idx] = 0.0;
  add_recG2v(se, dipole, g2d);
  compute_recG2q(se);
  for (atom = 0; atom < natoms; atom++) {
    const MD_Double *d = dipole + 3*atom;
    const MD_Double *gd = g2d + 3*atom;
    const MD_Double *gq = g2q + NIND2*atom;
    force[atom].x -= charge[atom] * gd[X]
                   - (gq[XX]*d[X] + gq[XY]*d[Y] + gq[XZ]*d[Z]);
    force[atom].y -= charge[atom] * gd[Y]
                   - (gq[YX]*d[X] + gq[YY]*d[Y] + gq[YZ]*d[Z]);
    force[atom].z -= charge[atom] * gd[Z]
                   - (gq[ZX]*d[X] + gq[ZY]*d[Y] + gq[ZZ]*d[Z]);
  }

  add_recddforce(se);  /* compute dipole--dipole force */
#ifdef DEBUG_STANDEWALD
  printf("rec q-q energy = %f\n", se->recEnergy);
#endif

  se->recEnergy -= 0.5 * DOT(dsolver_get_dipole(se->dsolver),
                             se->neg_recG1q, 3 * natoms);
#ifdef DEBUG_STANDEWALD
  printf("rec total energy = %f\n", se->recEnergy);
#endif
}


/*
 * Reciprocal-space setup, written out component by component.
 * For every stored reciprocal vector k:
 *   1. tabulate cos(k.r_i) and sin(k.r_i) for all atoms in se->coskr and
 *      se->sinkr (one row of natoms entries per k),
 *   2. store the charge structure factor in se->reS[k] and se->imS[k],
 *   3. add the charge--charge contribution of k to the energy,
 *   4. add the contribution of k to se->neg_recG1q (3 entries per atom).
 * Afterwards se->recEnergy holds half the accumulated energy sum and
 * se->recforce[i] is overwritten with charge[i] times the i-th 3-vector of
 * se->neg_recG1q.
 */
void charge_rec_setup(struct standEwald_Tag *se)
{
  const MD_Int natoms = se->natoms;
  const MD_Int nreclatt = se->nreclatt;
  const MD_Dvec *pos = se->ppos;
  const MD_Dvec *reclatt = se->reclatt;
  const MD_Double *charge = se->charge;
  const MD_Double *recU = se->recU;
  MD_Dvec *force = se->recforce;
  MD_Double *negG1q = se->neg_recG1q;
  MD_Double *cosrow = se->coskr;   /* row of the cos table for current k */
  MD_Double *sinrow = se->sinkr;   /* row of the sin table for current k */
  MD_Double energy = 0.0;
  MD_Int atom, kidx;

  memset(negG1q, 0, 3*natoms*sizeof(*negG1q));

  for (kidx = 0; kidx < nreclatt; kidx++) {
    const MD_Double w = recU[kidx];
    const MD_Double kx = reclatt[kidx].x;
    const MD_Double ky = reclatt[kidx].y;
    const MD_Double kz = reclatt[kidx].z;
    MD_Double sre = 0.0, sim = 0.0;
    MD_Double xre, xim, yre, yim, zre, zim;

    /* fill the trig tables and build the structure factor */
    for (atom = 0; atom < natoms; atom++) {
      const MD_Double phase = kx*pos[atom].x + ky*pos[atom].y
                            + kz*pos[atom].z;
      const MD_Double c = cos(phase);
      const MD_Double s = sin(phase);
      cosrow[atom] = c;
      sinrow[atom] = s;
      sre += charge[atom] * c;
      sim += charge[atom] * s;
    }
    se->reS[kidx] = sre;
    se->imS[kidx] = sim;

    /* reciprocal charge-charge energy */
    energy += w * (sre*sre + sim*sim);

    xre = w * kx * sre;
    xim = w * kx * sim;
    yre = w * ky * sre;
    yim = w * ky * sim;
    zre = w * kz * sre;
    zim = w * kz * sim;
    for (atom = 0; atom < natoms; atom++) {
      negG1q[3*atom+X] += xre*sinrow[atom] - xim*cosrow[atom];
      negG1q[3*atom+Y] += yre*sinrow[atom] - yim*cosrow[atom];
      negG1q[3*atom+Z] += zre*sinrow[atom] - zim*cosrow[atom];
    }

    /* step to the table rows of the next reciprocal vector */
    cosrow += natoms;
    sinrow += natoms;
  }

  se->recEnergy = energy * 0.5;
#ifdef DEBUG_STANDEWALD
  printf("rec q-q energy: %20.15f\n", se->recEnergy);
#endif

  for (atom = 0; atom < natoms; atom++) {
    force[atom].x = charge[atom] * negG1q[3*atom+X];
    force[atom].y = charge[atom] * negG1q[3*atom+Y];
    force[atom].z = charge[atom] * negG1q[3*atom+Z];
  }
}



