/*
 * Copyright (C) 2004-2006 by Wei Wang.  All rights reserved.
 */


#include <assert.h>
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <math.h>
#include "constant.h"
#include "unit.h"
#include "helper.h"
#include "pme.h"
#include "pme_direct.h"
#include "pme_recip.h"
#include "pme_utilities.h"

/* Explicit prototype for the complementary error function used by
 * calc_ewaldcof().  NOTE(review): presumably kept for build modes in which
 * <math.h> does not declare erfc (e.g. strict C89) -- confirm it is still
 * needed. */
double erfc(double x);

#ifdef TIMING_PME
#include "timer.h"
/* Accumulated time_of_day() differences, updated in pme_compute() and
 * printed by pme_destroy().  Only present when TIMING_PME is defined. */
static MD_Double Tdirsetup = 0.0; /* pmedir_dipole_setup() only */
static MD_Double Trecsetup = 0.0; /* pmerec_dipole_setup() only */
static MD_Double Trec = 0.0;      /* all reciprocal-sum work, setup included */
static MD_Double Tdir = 0.0;      /* all direct-sum work, setup included */
static MD_Double Tsd = 0.0; /* time cost for solving dipole (dsolver_solve) */
static MD_Double Ttotal = 0.0;    /* whole pme_compute() calls */
static MD_Int Tcounter = 0;       /* number of pme_compute() calls timed */
#endif

/* Ewald coefficient search; defined near the end of this file. */
static MD_Double calc_ewaldcof(MD_Double rc, MD_Double errTol);

/*
 * Initialize the PME module from the given parameters.
 *
 * A private copy of *params is stored in pme->savePmeParams, and the Ewald
 * coefficient of that copy is derived from its cutoff and PMETolerance.
 * The total-force and charge arrays are allocated (charges are copied from
 * patom[i].q), the charge self term qq_self_potential is precomputed, and
 * the direct-sum and reciprocal-sum sub-modules are allocated and
 * initialized.  Unless DIPOLE_POLY is defined, the dipole solver is also set
 * up when has_induced_dipole is non-zero.
 *
 * Preconditions (asserted): pme and params are non-NULL, the cell origin is
 * (0,0,0), and the cell basis is diagonal with strictly positive diagonal
 * entries (an axis-aligned rectangular box).
 *
 * Returns OK on success, MD_FAIL if a sub-module fails to initialize.  On
 * failure nothing allocated so far is released by this function.
 */
MD_Errcode pme_init(struct Pme_Tag* pme, const struct PmeParams_Tag* params)
{
  MD_Int i, natoms;
  struct PmeParams_Tag* pmeparams;

  assert(NULL != pme);
  assert(NULL != params);
  assert(0 == params->cellOrigin.x && 0 == params->cellOrigin.y &&
         0 == params->cellOrigin.z);
  assert(0< params->cellBasisVector1.x && 0< params->cellBasisVector2.y &&
         0< params->cellBasisVector3.z &&
         0==params->cellBasisVector1.y && 0==params->cellBasisVector1.z &&
         0==params->cellBasisVector2.x && 0==params->cellBasisVector2.z &&
         0==params->cellBasisVector3.x && 0==params->cellBasisVector3.y);

  /* pme_destroy() and pme_get_dipole() decide whether a dipole solver exists
   * by comparing this pointer with NULL, so it must have a defined value
   * even when no solver is created below. */
  pme->dsolver = NULL;

  pme->savePmeParams = *params;
  pmeparams = &(pme->savePmeParams);
  pmeparams->ewaldcof = calc_ewaldcof(pmeparams->cutoff,
                                      pmeparams->PMETolerance);
  natoms = pmeparams->natoms;
  pme->force = my_calloc((size_t)natoms, sizeof(MD_Dvec), "total force");
  pme->charge= my_calloc((size_t)natoms, sizeof(MD_Double), "charge");
  for (i=0; i<natoms; i++) pme->charge[i] = pmeparams->patom[i].q;

  /* self term: -ewaldcof / sqrt(pi) * sum_i q_i^2 */
  pme->qq_self_potential = - pmeparams->ewaldcof * one_over_sqrtPi *
    DOT(pme->charge, pme->charge, pmeparams->natoms);

  printf("PME module:\n");
  printf("      error tolerance = %g\n", pmeparams->PMETolerance);
  printf("             ewaldcof = %20.15f\n", pmeparams->ewaldcof);
  printf("                 rcut = %20.15f\n", pmeparams->cutoff);
  printf("            grid size = %d, %d, %d\n",
         pmeparams->PMEGridSizeX, pmeparams->PMEGridSizeY,
         pmeparams->PMEGridSizeZ);
  printf("  interpolation order = %d\n", pmeparams->PMEInterpOrder);

  pme->dir = my_calloc((size_t)1, sizeof(*pme->dir), "direct sum");
  if (pmedir_init(pme->dir, pmeparams)) {
    fprintf(stderr, "failed to init direct module of pme\n");
    return MD_FAIL;
  }

  fprintf(stderr, "direct sum is inited\n");

  pme->rec = my_calloc((size_t)1, sizeof(*pme->rec),  "reciprocal sum");
  if (pmerec_init(pme->rec, pmeparams)) {
    fprintf(stderr, "failed to init reciprocal module of pme\n");
    return MD_FAIL;
  }

  fprintf(stderr, "reciprocal sum is inited\n");

#ifndef DIPOLE_POLY
  /* The solver setup reads pme->dir and pme->rec, so it must come last. */
  if (pmeparams->has_induced_dipole) {
    if (pme_setup_dsolver(pme, pmeparams->dsolver_param) ) {
      fprintf(stderr, "failed to set up dipole solver\n");
      return MD_FAIL;
    }
  }
#endif

  return OK;
}


/*
 * Release everything owned by a PME object and zero the structure.
 *
 * When TIMING_PME is defined, the accumulated timings are printed first;
 * the per-call averages are printed only if pme_compute() was timed at
 * least once, so they are never divided by a zero counter.
 *
 * Returns OK on success.  If destroying a sub-module (direct sum,
 * reciprocal sum, dipole solver) fails, MD_FAIL is returned immediately and
 * the remaining members are left untouched.  Every pointer freed up to that
 * point is reset to NULL and NULL sub-modules are skipped, so the structure
 * never holds dangling pointers.
 */
MD_Errcode pme_destroy(struct Pme_Tag *pme)
{
#ifdef TIMING_PME
  printf("Tdirsetup=%f, Trecsetup=%f, Tdir=%f, Trec=%f, Tsd=%f, Ttotal=%f\n",
         Tdirsetup, Trecsetup, Tdir,Trec, Tsd, Ttotal);
  if (Tcounter > 0) {
    printf("avg: dir=%f, rec=%f, sd=%f, total=%f\n", Tdir/Tcounter,
           Trec/Tcounter, Tsd/Tcounter, Ttotal/Tcounter);
  }
#endif

  free(pme->force);
  pme->force = NULL;
  free(pme->charge);
  pme->charge = NULL;

  if (NULL != pme->dir) {
    if (pmedir_destroy(pme->dir)) {
      fprintf(stderr, "failed to destroy direct module of pme\n");
      return MD_FAIL;
    }
    free(pme->dir);
    pme->dir = NULL;
  }

  if (NULL != pme->rec) {
    if (pmerec_destroy(pme->rec)) {
      fprintf(stderr, "failed to destroy reciprocal sum\n");
      return MD_FAIL;
    }
    free(pme->rec);
    pme->rec = NULL;
  }

  /* The dipole solver exists only when induced dipoles were enabled. */
  if (NULL != pme->dsolver) {
    if (dsolver_destroy(pme->dsolver)) {
      return MD_FAIL;
    }
    free(pme->dsolver);
    pme->dsolver = NULL;
  }

  memset(pme, 0, sizeof *pme);
  return OK;
}


MD_Double pme_get_energy(const struct Pme_Tag *pme) 
{
  return pme->potential;
}


/* Return a read-only view of the total-force array held by the PME object.
 * The array is owned by pme; the caller must not free it. */
const MD_Dvec *pme_get_force(const struct Pme_Tag *pme)
{
  const MD_Dvec *total_force = pme->force;

  return total_force;
}


/*
 * Fill d[0 .. 3*natoms-1] with
 *     d[k] = 1/polarizability[k] - 4*ewaldcof^3 / (3*sqrt(pi)).
 *
 * pme : PME object whose saved parameters supply ewaldcof, natoms and the
 *       polarizability array (read with 3*natoms entries).
 * d   : output array; must have room for 3*natoms values.
 */
void pme_fill_diagonal(const struct Pme_Tag *pme, MD_Double *d)
{
  const struct PmeParams_Tag *saved = &(pme->savePmeParams);
  const MD_Double beta = saved->ewaldcof;
  const MD_Double self_term = 4.0*beta*beta*beta*one_over_sqrtPi / 3.0;
  const MD_Int count = saved->natoms * 3;
  MD_Int k;

  for (k = 0; k < count; k++) {
    d[k] = 1.0/saved->polarizability[k] - self_term;
  }
}


MD_Errcode pme_setup_dsolver(struct Pme_Tag *pme,
			     struct Dsolver_Parameters_Type dsolver_param)
{
  struct Dsolver_Init_Tag params;
  const MD_Int natoms = pme->savePmeParams.natoms;
 
  pme->dsolver = my_calloc((size_t)1, sizeof(struct Dsolver_Tag), "dsolver");

  memset(&params, 0, sizeof(struct Dsolver_Init_Tag));
  params.electro = pme;
  params.ewaldmethod = ES_SPME;
  params.mat_vec_mul_mod = pme;
  params.compute_pseudores = (compute_pseudores_type) pme_compute_pseudores;
  params.matrixsize = 3 * natoms;
  params.specified_param = dsolver_param;
  params.specified_param.errTol2 *= DEBYE*DEBYE/3.0; /* rectify */
  params.density = (MD_Double) natoms / lattice_volume(&pme->rec->lattice);
  params.neibrlist = pme->dir->neibrlist;
  params.numneibrs = pme->dir->numneibrs;

  if (dsolver_init(pme->dsolver, &params)) {
    fprintf(stderr, "failed to init dsolver module\n");
    return MD_FAIL;
  }
 
  return OK;
}


MD_Errcode pme_compute(struct Pme_Tag *pme)
{
  struct PmeParams_Tag params = pme->savePmeParams;
#ifdef TIMING_PME
  MD_Double tstart, tend;
  MD_Double t0 = time_of_day();
#endif

  /*
  params.has_induced_dipole = 0;
  */

  if (!params.has_induced_dipole) {
#ifdef TIMING_PME
    tstart = time_of_day();
#endif
    
    pmedir_charge_compute(pme->dir, &params, pme->charge, params.ppos);
    
#ifdef TIMING_PME
    tend = time_of_day();  Tdir += tend - tstart; tstart = tend;
#endif
    pmerec_charge_compute(pme->rec, params.ppos, pme->charge);
#ifdef TIMING_PME
    tend = time_of_day();  Trec += tend - tstart; tstart = tend;
#endif
  } else {
    MD_Double *dipole = NULL;
#ifdef TIMING_PME
    tstart = time_of_day();
#endif
    
    pmedir_dipole_setup(pme->dir, &params, pme->charge, params.ppos);

#ifdef TIMING_PME
    tend=time_of_day(); Tdirsetup+=tend-tstart; Tdir+=tend-tstart; tstart=tend;
#endif
    pmerec_dipole_setup(pme->rec, params.ppos, pme->charge);
#ifdef TIMING_PME
    tend=time_of_day();  Trecsetup+=tend-tstart; Trec+=tend-tstart; tstart=tend;
#endif
#if 1
    if (dsolver_solve(pme->dsolver)) {
      if (dsolver_dump_dipole(pme->dsolver, "emergency_dipoles")) {
	fprintf(stderr, "cannot dump dipoles\n");
	/* not hard failure */
      }
    }
#ifdef TIMING_PME
    tend = time_of_day();  Tsd += tend - tstart; tstart = tend;
#endif
    dipole = dsolver_get_dipole(pme->dsolver);
#else
    /* debug: add it here, so we can assume getting the same dipole as the
    * standard Ewald method */
    dipole = my_calloc((size_t)3*natoms, sizeof(*dipole), "dip");
    binread_array(dipole, 3*params.natoms, "dip");
#endif

#ifdef DEBUG_PME
    {
      printf("<d,d>=%20.15f\n", DOT(dipole,dipole,3*params.natoms));
    }
#endif

    pmedir_dipole_force(pme->dir, params.ppos, pme->charge, dipole);
#ifdef TIMING_PME
    tend = time_of_day();  Tdir += tend - tstart; tstart = tend;
#endif
    /*
    memset(dipole, 0, params.natoms*3*sizeof(MD_Double));
    dipole[0]=1.0; dipole[1]=2.0; dipole[2]=3.0;
    memset(pme->charge, 0, params.natoms*sizeof(MD_Double));
    */
    {
      MD_Int calc_d_arr = (JCG_X != pme->dsolver->method) ||
	(CG != pme->dsolver->method) || (JCG_R != pme->dsolver->method);
      calc_d_arr = 1;
      pmerec_dipole_calc(pme->rec, pme->charge, dipole, calc_d_arr);
    }
#ifdef TIMING_PME
    tend = time_of_day();  Trec += tend - tstart; tstart = tend;
#endif    
  }

#ifdef DEBUG_PME
  printf("potential: dir=%20.15f, rec=%20.15f\n", pme->dir->potential, 
	 pme->rec->potential);
#endif

  {
    const struct PmeDirect_Tag *dir = pme->dir;	
    const struct PmeRecip_Tag *rec = pme->rec;	
    const MD_Dvec *df = dir->force;	
    const MD_Dvec *rf = rec->force;	
    MD_Dvec *tf = pme->force;		
    const MD_Double c = COULOMB_SQR;	
    MD_Int natoms = pme->savePmeParams.natoms;
    MD_Int i;				
    for (i=0;  i<natoms;  i++) {	
      MD_vec_add_mul(df[i], rf[i], c, tf[i]);
    }					
    pme->potential = (pme->qq_self_potential + dir->potential +	
                      rec->potential) * c;		
  }

#ifdef TIMING_PME
  Ttotal += time_of_day() - t0; 
  Tcounter ++;
/*
  printf("Tdirsetup=%f, Trecsetup=%f, Tdir=%f, Trec=%f, Tsd=%f, Ttotal=%f\n", 
	 Tdirsetup, Trecsetup, Tdir,Trec, Tsd, Ttotal);
*/
#endif

  return OK;
}


/*
 * compute -G1q - G2d
 *
 * Forwards (d, flag, pseudores) first to the direct-sum module through
 * pmedir_compute_pseudores() and then to the reciprocal-sum module through
 * pmerec_add_pseudores().  NOTE(review): judging by the helper names, the
 * first call fills pseudores and the second adds to it -- confirm against
 * the sub-module sources.
 */
void pme_compute_pseudores(const struct Pme_Tag* pme, const MD_Double *d,
                           const MD_Int flag, MD_Double *pseudores)
{
  struct PmeDirect_Tag *direct = pme->dir;
  struct PmeRecip_Tag *recip = pme->rec;

  pmedir_compute_pseudores(direct, d, flag, pseudores);
  pmerec_add_pseudores(recip, d, flag, pseudores);
}


MD_Double calc_ewaldcof(MD_Double rc, MD_Double errTol)
{
  MD_Double beta = 8.0, low = 0.0, high;
  MD_Int i;
 
  while (erfc(beta*rc) > errTol)  beta += beta;
  high = beta;
  for (i = 0; i < 100; i++) {
    beta = 0.5 * (high+low);
    if (erfc(beta*rc) > errTol) low=beta;
    else high=beta;
  }
 
  return beta;
}

/*
const MD_Dvec* pme_get_force(const struct Pme_Tag *pme)
{
  return pme->force;
}

  
MD_Double pme_get_potential(const struct Pme_Tag *pme)
{
  return pme->potential;
}
*/

/* Return the dipole array held by the dipole solver, or NULL when pme is
 * NULL or no dipole solver is attached to it. */
MD_Double* pme_get_dipole(const struct Pme_Tag *pme)
{
  if (NULL != pme && NULL != pme->dsolver) {
    return dsolver_get_dipole(pme->dsolver);
  }
  return NULL;
}


/*
 * Write the solver's dipoles to the named file via dsolver_dump_dipole().
 *
 * Returns the result of dsolver_dump_dipole(), or MD_FAIL when pme is NULL
 * or no dipole solver is attached (the same guard pme_get_dipole() uses).
 */
MD_Errcode pme_dump_dipole(const struct Pme_Tag *pme, const char* filename)
{
  if (NULL == pme || NULL == pme->dsolver) {
    return MD_FAIL;
  }
  return dsolver_dump_dipole(pme->dsolver, filename);
}
