/*
 * Copyright (C) 2004-2006 by Wei Wang.  All rights reserved.
 */


#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "dipole_poly.h"
#include "standEwald_dir.h"
#include "standEwald_rec.h"
#include "helper.h"
#include "unit.h"
#include "constant.h"


/*
 * File-local helpers, defined further down in this file.  Each one
 * updates the force array of dp->se; dp_compute_rec_dipole_force calls
 * dp_add_recddforce as its last step.
 */
static void dp_compute_dir_dipole_force(struct Dipole_Poly_Tag *dp);
static void dp_compute_rec_dipole_force(struct Dipole_Poly_Tag *dp);
static void dp_add_recddforce(struct Dipole_Poly_Tag *dp);




/*
 * Initialise a dipole-polynomial object.
 *
 * Allocates the embedded standard Ewald module and initialises it from
 * init_data->e_init, copies the coefficients a and b, and allocates the
 * two work arrays dp->dipole and dp->aneg_G1q (3 * natoms MD_Double each).
 *
 * dp        - object to initialise; must not be NULL.
 * init_data - initialisation parameters; must not be NULL and must have
 *             e_init.has_induced_dipole set (checked with assert).
 *
 * Returns OK on success.  Returns MD_FAIL when stdEw_init reports an
 * error; in that case the Ewald structure allocated here is released
 * again and dp->se, dp->dipole and dp->aneg_G1q are all left NULL, so
 * nothing allocated by this function is leaked.
 *
 * NOTE(review): the my_calloc results are used unchecked, as before;
 * presumably my_calloc reports/aborts on allocation failure - confirm.
 */
MD_Errcode dipole_poly_init(struct Dipole_Poly_Tag *dp,
			                struct dipole_poly_init *init_data)
{
  size_t ncomp;

  assert(NULL != dp);
  assert(NULL != init_data);
  assert(init_data->e_init.has_induced_dipole);

  dp->se = my_calloc(1, sizeof(struct standEwald_Tag), "stand Ewald");
  if (stdEw_init(dp->se, &(init_data->e_init))) {
    /* errors go to stderr, as in dipole_poly_destroy */
    fprintf(stderr, "failed to init standard Ewald module\n");
    /* do not leak the structure allocated above */
    free(dp->se);
    dp->se = NULL;
    dp->dipole = NULL;
    dp->aneg_G1q = NULL;
    return MD_FAIL;
  }

  dp->a = init_data->a;
  dp->b = init_data->b;

  /* widen before multiplying so the product is computed in size_t */
  ncomp = (size_t) dp->se->natoms * 3;
  dp->dipole = my_calloc(ncomp, sizeof(MD_Double), "dipole");
  dp->aneg_G1q = my_calloc(ncomp, sizeof(MD_Double), "aneg_G1q");

  return OK;
}

/*
 * Release everything owned by a dipole-polynomial object.
 *
 * Shuts down the embedded standard Ewald module, frees it, and frees the
 * dipole and aneg_G1q arrays.  Every freed pointer is reset to NULL, so
 * calling this function again on the same object is harmless instead of
 * a double free.  A NULL dp->se is skipped.
 *
 * dp - object to tear down; must not be NULL.  The structure itself is
 *      not freed.
 *
 * Returns OK on success.  Returns MD_FAIL when stdEw_destroy reports an
 * error; as before, nothing is freed in that case.
 */
MD_Errcode dipole_poly_destroy(struct Dipole_Poly_Tag *dp)
{
  assert(NULL != dp);

  if (NULL != dp->se) {
    if (stdEw_destroy(dp->se)) {
      fprintf(stderr, "failed to destroy standard Ewald module\n");
      return MD_FAIL;
    }
    free(dp->se);
    dp->se = NULL;
  }

  /* free(NULL) is a no-op, so no guards are needed here */
  free(dp->dipole);
  dp->dipole = NULL;
  free(dp->aneg_G1q);
  dp->aneg_G1q = NULL;

  return OK;
}


/*
 * Accessor for the dipole array stored in dp (allocated with
 * 3 * natoms entries by dipole_poly_init).  The array stays owned by
 * dp; the caller gets a read-only pointer and must not free it.
 */
const MD_Double * dipole_poly_get_dipole(struct Dipole_Poly_Tag *dp)
{
  const MD_Double *stored_dipole = dp->dipole;

  return stored_dipole;
}


/*
 * Compute the dipoles, the energy and the forces for the current state
 * of the embedded Ewald module.
 *
 * Steps, in order:
 *   1. run the direct- and reciprocal-space dipole setup;
 *   2. fill aneg_G1q[i] = palpha[i] * (neg_dirG1q[i] + neg_recG1q[i]);
 *   3. call stdEw_compute_pseudores once and combine its output with
 *      aneg_G1q using the coefficients a, b and c to get dp->dipole;
 *   4. store the energy in se->energy, multiplied by COULOMB_SQR;
 *   5. zero se->force, add the direct and reciprocal dipole forces and
 *      multiply every force by COULOMB_SQR.
 *
 * Always returns OK.
 */
MD_Errcode dipole_poly_compute(struct Dipole_Poly_Tag *dp)
{
  /* all pointers and scalars are fetched before the setup calls run */
  struct standEwald_Tag *ewald = dp->se;
  const MD_Double coef_a = dp->a;
  const MD_Double coef_b = dp->b;
  const MD_Double ewald_beta = ewald->beta;
  const MD_Double coef_c =
    coef_a * 4.0*ewald_beta*ewald_beta*ewald_beta*one_over_sqrtPi/3.0;
  const MD_Double *palpha = ewald->palpha;
  MD_Double *neg_dir_g1q = ewald->neg_dirG1q;
  MD_Double *neg_rec_g1q = ewald->neg_recG1q;
  MD_Double * const scaled_g1q = dp->aneg_G1q;
  MD_Double * const dip = dp->dipole;
  const MD_Int atom_count = ewald->natoms;
  const MD_Int comp_count = 3 * atom_count;
  MD_Dvec * const force = ewald->force;
  MD_Double unit_scale = COULOMB_SQR;
  MD_Int k;

  dipole_dir_setup(ewald);
  dipole_rec_setup(ewald);

  for (k = 0; k < comp_count; k++) {
    scaled_g1q[k] = palpha[k] * (neg_dir_g1q[k] + neg_rec_g1q[k]);
  }

  stdEw_compute_pseudores(ewald, scaled_g1q, 0, dip);
  for (k = 0; k < comp_count; k++) {
    dip[k] = -coef_a*palpha[k]*dip[k]
           + (coef_b - coef_c*palpha[k]) * scaled_g1q[k];
  }

  ewald->energy = ewald->dirEnergy + ewald->recEnergy
    + ewald->self_energy_qq - 0.5 * DOT(dip, neg_dir_g1q, comp_count)
    - 0.5 * DOT(dip, neg_rec_g1q, comp_count);
  ewald->energy *= unit_scale;

  /* start from a clean force array, then accumulate both parts */
  for (k = 0; k < atom_count; k++) {
    force[k].z = 0.0;
    force[k].y = 0.0;
    force[k].x = 0.0;
  }
  dp_compute_dir_dipole_force(dp);
  dp_compute_rec_dipole_force(dp);

  /* apply the same unit factor to the forces as to the energy */
  for (k = 0; k < atom_count; k++) {
    MD_vec_mul(force[k], unit_scale, force[k]);
  }

  return OK;
}


/* similar to compute_dir_dipole_force */
void dp_compute_dir_dipole_force(struct Dipole_Poly_Tag *dp)
{
  struct standEwald_Tag *se = dp->se;
  MD_Dvec * const f = se->force;
  const MD_Int natoms = se->natoms;
  const MD_Double *q = se->charge;
  MD_Double qi, qj;
  const MD_Double *dipole = dp->dipole;
  const MD_Double *di, *dj;
  const MD_Double *g2ij, *g3ij;
  const MD_Double a = dp->a;
  const MD_Double * const aneg_G1q = dp->aneg_G1q;
  const MD_Double *ang1qi, *ang1qj;
  MD_Double **dirG2 = se->dirG2;
  MD_Double **dirG3 = se->dirG3;
  MD_Double dq[3];
  register MD_Dvec df;  /* dipole force */
  MD_Double s[6];
  MD_Int numneibrs=-1;
  const MD_Int *neibrlist=NULL;
  MD_Int i, j, jj;

  /* note that diagonal elements of G2 do not contribute */
  for (i = 0; i < natoms; i++) {
    qi = q[i];
    di = dipole + 3*i;
    g2ij = dirG2[i];
    g3ij = dirG3[i];
    numneibrs = se->numneibrs[i];
    neibrlist = se->neibrlist[i];
    ang1qi = aneg_G1q + 3*i;
    for (jj = 0; jj < numneibrs; jj++) { /* charge-dipole force */
      j = neibrlist[jj];
      ASSERT(0 <= j && j < natoms);
      qj = q[j];
      dj = dipole + 3*j;
      /* charge-dipole force */
      dq[X] = di[X] * qj - dj[X] * qi;
      dq[Y] = di[Y] * qj - dj[Y] * qi;
      dq[Z] = di[Z] * qj - dj[Z] * qi;
      df.x = g2ij[XX]*dq[X] + g2ij[XY]*dq[Y] + g2ij[XZ]*dq[Z];
      df.y = g2ij[YX]*dq[X] + g2ij[YY]*dq[Y] + g2ij[YZ]*dq[Z];
      df.z = g2ij[ZX]*dq[X] + g2ij[ZY]*dq[Y] + g2ij[ZZ]*dq[Z];
      /* dipole-dipole force */
      ang1qj = aneg_G1q + 3*j;
      s[XX] = g3ij[XXX]*ang1qj[X] + g3ij[XXY]*ang1qj[Y] + g3ij[XXZ]*ang1qj[Z];
      s[XY] = g3ij[XYX]*ang1qj[X] + g3ij[XYY]*ang1qj[Y] + g3ij[XYZ]*ang1qj[Z];
      s[XZ] = g3ij[XZX]*ang1qj[X] + g3ij[XZY]*ang1qj[Y] + g3ij[XZZ]*ang1qj[Z];
      s[YY] = g3ij[YYX]*ang1qj[X] + g3ij[YYY]*ang1qj[Y] + g3ij[YYZ]*ang1qj[Z];
      s[YZ] = g3ij[YZX]*ang1qj[X] + g3ij[YZY]*ang1qj[Y] + g3ij[YZZ]*ang1qj[Z];
      s[ZZ] = g3ij[ZZX]*ang1qj[X] + g3ij[ZZY]*ang1qj[Y] + g3ij[ZZZ]*ang1qj[Z];
      df.x += a * (ang1qi[X]*s[XX] + ang1qi[Y]*s[XY] + ang1qi[Z]*s[XZ]);
      df.y += a * (ang1qi[X]*s[YX] + ang1qi[Y]*s[YY] + ang1qi[Z]*s[YZ]);
      df.z += a * (ang1qi[X]*s[ZX] + ang1qi[Y]*s[ZY] + ang1qi[Z]*s[ZZ]);
      MD_vec_add(f[i], df, f[i]);
      MD_vec_substract(f[j], df, f[j]);      
      g2ij += 6;  g3ij += 10;
    }
  }

  return;
}

/*
 * Reciprocal-space part of the dipole force (counterpart of
 * compute_rec_dipole_force).
 *
 * Clears se->recg2d, fills it through stdEw_compute_pseudores from the
 * dipole array, refreshes the per-atom G2q data with compute_recG2q,
 * and adds to each atom's force
 *     q_i * g2d_i  +  G2q_i * dipole_i
 * (3 g2d values and 6 G2q values per atom).  The dipole-dipole term is
 * added last by dp_add_recddforce.
 */
void dp_compute_rec_dipole_force(struct Dipole_Poly_Tag *dp)
{
  struct standEwald_Tag *ewald = dp->se;
  MD_Dvec * const force = ewald->force;
  const MD_Double * const dipole = dp->dipole;
  const MD_Double * const charge = ewald->charge;
  MD_Double * const g2d_base = ewald->recg2d;
  MD_Double * const g2q_base = ewald->recg2q;
  const MD_Int atom_count = ewald->natoms;
  const MD_Int comp_count = 3 * atom_count;
  const MD_Double *d_i;
  const MD_Double *gd_i;
  const MD_Double *gq_i;
  MD_Int k;

  for (k = 0; k < comp_count; k++) {
    g2d_base[k] = 0.0;
  }
  stdEw_compute_pseudores(ewald, dipole, 0, g2d_base);
  compute_recG2q(ewald);

  /* walk the three per-atom arrays in lock step */
  d_i = dipole;
  gd_i = g2d_base;
  gq_i = g2q_base;
  for (k = 0; k < atom_count; k++, d_i += 3, gd_i += 3, gq_i += 6) {
    force[k].x += charge[k]*gd_i[X]
                + (gq_i[XX]*d_i[X] + gq_i[XY]*d_i[Y] + gq_i[XZ]*d_i[Z]);
    force[k].y += charge[k]*gd_i[Y]
                + (gq_i[YX]*d_i[X] + gq_i[YY]*d_i[Y] + gq_i[YZ]*d_i[Z]);
    force[k].z += charge[k]*gd_i[Z]
                + (gq_i[ZX]*d_i[X] + gq_i[ZY]*d_i[Y] + gq_i[ZZ]*d_i[Z]);
  }

  dp_add_recddforce(dp);
}


/*
 * Add the reciprocal-space dipole-dipole force (close relative of
 * add_recddforce in stdEwald_rec.c), using aneg_G1q as the per-atom
 * vectors and dp->a as an extra prefactor.
 *
 * For each reciprocal lattice vector k:
 *   pass 1 - accumulate, over all atoms, the six sums of
 *            aneg_G1q components weighted by cos(k.r) and sin(k.r);
 *   pass 2 - add to every atom a force along k with magnitude
 *            a * recU[k] * (cfac*cos_i - sfac*sin_i) * (aneg_G1q_i . k).
 * coskr / sinkr hold natoms values per lattice vector, stored row after
 * row.
 */
void dp_add_recddforce(struct Dipole_Poly_Tag *dp)
{
  struct standEwald_Tag *ewald = dp->se;
  MD_Dvec * const force = ewald->force;
  const MD_Dvec * const kvecs = ewald->reclatt;
  const MD_Double * const recU = ewald->recU;
  const MD_Double *sin_row = ewald->sinkr;
  const MD_Double *cos_row = ewald->coskr;
  const MD_Double * const ang1q_base = dp->aneg_G1q;
  const MD_Double coef_a = dp->a;
  const MD_Int atom_count = ewald->natoms;
  const MD_Int kvec_count = ewald->nreclatt;
  MD_Int ik;

  for (ik = 0; ik < kvec_count; ik++) {
    const MD_Double *v;
    MD_Double cx = 0.0, cy = 0.0, cz = 0.0;   /* cos-weighted sums */
    MD_Double sx = 0.0, sy = 0.0, sz = 0.0;   /* sin-weighted sums */
    MD_Double kx, ky, kz;
    MD_Double sfac, cfac;
    MD_Double au, aukx, auky, aukz;
    MD_Double tmp;
    MD_Int ia;

    /* pass 1: structure-factor-like sums over all atoms */
    v = ang1q_base;
    for (ia = 0; ia < atom_count; ia++, v += 3) {
      cx += v[X] * cos_row[ia];
      cy += v[Y] * cos_row[ia];
      cz += v[Z] * cos_row[ia];
      sx += v[X] * sin_row[ia];
      sy += v[Y] * sin_row[ia];
      sz += v[Z] * sin_row[ia];
    }

    kx = kvecs[ik].x;
    ky = kvecs[ik].y;
    kz = kvecs[ik].z;
    /* note the cross naming: sfac uses the cos sums, cfac the sin sums */
    sfac = kx*cx + ky*cy + kz*cz;
    cfac = kx*sx + ky*sy + kz*sz;
    au = coef_a * recU[ik];
    aukx = au * kx;
    auky = au * ky;
    aukz = au * kz;

    /* pass 2: distribute the force onto the atoms */
    v = ang1q_base;
    for (ia = 0; ia < atom_count; ia++, v += 3) {
      tmp = (cfac*cos_row[ia] - sfac*sin_row[ia])
          * (v[X]*kx + v[Y]*ky + v[Z]*kz);
      force[ia].x += aukx * tmp;
      force[ia].y += auky * tmp;
      force[ia].z += aukz * tmp;
    }

    /* advance to the next lattice vector's row */
    cos_row += atom_count;
    sin_row += atom_count;
  }
}


/*
 * Compute average molecular dipole magnitudes and the total dipole of
 * the system.
 *
 * Atoms are grouped into molecules of three consecutive atoms
 * (indices 3m, 3m+1, 3m+2); natoms / 3 molecules are processed and any
 * leftover atoms are ignored.
 * NOTE(review): the io/ih1/ih2 naming suggests a three-site water
 * model with the oxygen first - confirm against the callers.
 *
 * For each molecule:
 *   permanent dipole = sum of charge * position over its three atoms;
 *   induced dipole   = sum of the three atoms' entries in dp->dipole
 *                      (only when se->has_induced_dipole is set,
 *                      otherwise it stays zero);
 *   total dipole     = permanent + induced.
 *
 * Outputs (all divided by the DEBYE constant):
 *   *mol_perm_dipole     - mean magnitude of the permanent dipoles
 *   *mol_induc_dipole    - mean magnitude of the induced dipoles
 *   *mol_total_dipole    - mean magnitude of the total dipoles
 *   *system_total_dipole - vector sum of all molecular total dipoles
 *
 * With fewer than three atoms there is no molecule to average over;
 * all outputs are then zero instead of the result of a division by
 * zero.
 */
void dp_compute_mol_sys_dipole(struct Dipole_Poly_Tag *dp,
			       MD_Double *mol_perm_dipole,
			       MD_Double *mol_induc_dipole,
			       MD_Double *mol_total_dipole,
			       MD_Dvec *system_total_dipole)
{
  const struct standEwald_Tag *se = dp->se;
  const MD_Double* q = se->charge;
  const MD_Dvec* pos = se->prealpos;
  const MD_Double* idipole = dp->dipole; /* induced dipole */
  const MD_Int natoms = se->natoms;
  const MD_Int nmols = natoms / 3;
  const MD_Double inv_DEBYE = 1.0 / DEBYE;
  MD_Dvec mpd, mid, mtd; /* molecule (permanent | induced | total) dipole */
  MD_Dvec sys_totd;
  MD_Double mol_permd, mol_induced, mol_totd;
  MD_Int imol, io, ih1, ih2;

  mol_permd = 0.0;
  mol_induced = 0.0;
  mol_totd = 0.0;
  mid.x = mid.y = mid.z = 0.0; /* needed in case there is no induced dipole */
  sys_totd.x = sys_totd.y = sys_totd.z = 0.0;
  for (imol = 0; imol < nmols; imol++) {
    /* atom indices of this molecule (charges and positions) */
    io = imol*3; ih1 = io+1; ih2 = io+2;
    mpd.x = q[io ] * pos[io ].x + q[ih1] * pos[ih1].x + q[ih2] * pos[ih2].x;
    mpd.y = q[io ] * pos[io ].y + q[ih1] * pos[ih1].y + q[ih2] * pos[ih2].y;
    mpd.z = q[io ] * pos[io ].z + q[ih1] * pos[ih1].z + q[ih2] * pos[ih2].z;
    mol_permd += sqrt(MD_vec_dot(mpd, mpd));
    if (se->has_induced_dipole) {
      /* same three atoms, now as offsets into the 3-per-atom dipole array */
      io = imol*9; ih1 = io+3; ih2 = io+6;
      mid.x = idipole[io+X] + idipole[ih1+X] + idipole[ih2+X];
      mid.y = idipole[io+Y] + idipole[ih1+Y] + idipole[ih2+Y];
      mid.z = idipole[io+Z] + idipole[ih1+Z] + idipole[ih2+Z];
      mol_induced += sqrt(MD_vec_dot(mid, mid));
    }
    MD_vec_add(mpd, mid, mtd);
    MD_vec_add(sys_totd, mtd, sys_totd);
    mol_totd += sqrt(MD_vec_dot(mtd, mtd));
  }

  /* average only when there is at least one molecule: avoids 1.0 / 0 */
  if (nmols > 0) {
    const MD_Double inv_nmols = 1.0 / (MD_Double) nmols;

    mol_permd   *= inv_nmols;
    mol_induced *= inv_nmols;
    mol_totd    *= inv_nmols;
  }

  *mol_perm_dipole  = mol_permd * inv_DEBYE;
  *mol_induc_dipole = mol_induced * inv_DEBYE;
  *mol_total_dipole = mol_totd * inv_DEBYE;
  MD_vec_mul(sys_totd, inv_DEBYE, (*system_total_dipole));

}
