/*
 * Copyright (C) 2004-2006 by Wei Wang.  All rights reserved.
 */

/* standard implementation of the Ewald sum */


#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <float.h>
#include <string.h>
#include <assert.h>
#include "timer.h"
#include "unit.h"
#include "constant.h"
#include "utilities.h"
#include "helper.h"
#include "standEwald.h"
#include "standEwald_dir.h"
#include "standEwald_rec.h"
#include "data.h"
#include "force.h"


/* calling sequence: **************************************************
  stdEw_init:
    + a1, a2, a3, volume
    + rcut, beta, kcut
    + allocate general-purpose arrays: force, charge, neg_G1q
    + init_dir
    + init_rec
    + stdEw_setup_dsolver (only when induced dipoles are enabled)
  stdEw_compute  <-- used in every step of molecular dynamics integration
    + charge_dir_setup() / dipole_dir_setup()
    + charge_rec_setup() / dipole_rec_setup() <-- must be called after the
                                                  direct-space setup
    + dsolver_solve: solve the dipole equation (induced dipoles only).
    + compute_dir_dipole_force()
    + compute_rec_dipole_force()
    + overall energy and force
 ***********************************************************************/

/*
#define EXCLUDE_REC
#define EXCLUDE_DIR 
#define DEBUG_G
#define EXACT_EWALD
#define TIMING
#define DEBUG_D 
*/

#ifdef TIMING
/* Cumulative times measured with time_of_day(); only compiled in when the
   module is built with TIMING defined, and printed by stdEw_compute. */
static MD_Double Tdir   = 0.0;  /* direct-space work */
static MD_Double Trec   = 0.0;  /* reciprocal-space work */
static MD_Double Tsd    = 0.0;  /* time cost for solving dipole */
static MD_Double Ttotal = 0.0;  /* whole stdEw_compute calls */
#endif


/***********************************************************************/
/*********************** implementations *******************************/
/***********************************************************************/


/*
 * stdEw_init: set up the standard Ewald sum module.
 *
 * Copies the system description from init_data into se, builds the
 * axis-aligned lattice vectors a1, a2, a3 from the box dimensions and
 * computes the box volume.  The direct-space cutoff rcut is half of the
 * root-mean-square box edge length (L/2 for a cubic box of side L); beta
 * and kcut are then derived from rcut and the error tolerance through
 * calc_beta and calc_kcut.  The force and charge arrays (natoms entries
 * each) are allocated, the charge-charge self energy is computed when
 * per-atom charges are supplied, and the direct part, the reciprocal part
 * and (if induced dipoles are enabled) the dipole solver are initialized.
 *
 * se->dsolver is always given a defined value here (NULL unless the dipole
 * solver is created), because stdEw_destroy and the dipole accessors test
 * it against NULL.
 *
 * When init_data->patom is NULL, qq and self_energy_qq are set to 0.0 as
 * placeholders; the caller is expected to assign the charges and the self
 * energy afterwards.
 *
 * Parameters:
 *   se        - module state to fill in (must not be NULL).
 *   init_data - initialization parameters (must not be NULL).
 *
 * Returns OK on success, MD_FAIL if any sub-initialization fails.
 * Memory allocated before a failure is not released here.
 */
MD_Errcode stdEw_init(struct standEwald_Tag *se, 
		       struct standEwald_init_Tag *init_data)
{
  MD_Int i;
  MD_Int natoms;

  assert(NULL != se && NULL != init_data);

  /* no solver until stdEw_setup_dsolver creates one */
  se->dsolver = NULL;

  se->palpha             = init_data->polarizability;
  se->natoms             = init_data->natoms;
  se->ppos               = init_data->wrapped_pos;
  se->prealpos           = init_data->realpos;
  se->pexcllist          = init_data->pexcllist;
  se->systemsize         = init_data->systemsize;
  se->errTol             = init_data->errTol;
  se->has_induced_dipole = init_data->has_induced_dipole;

  /* orthogonal (axis-aligned) box only, for now */
  se->a1.x = init_data->systemsize.x;
  se->a2.y = init_data->systemsize.y;
  se->a3.z = init_data->systemsize.z;
  se->a1.y = se->a1.z = 0.0;
  se->a2.x = se->a2.z = 0.0;
  se->a3.x = se->a3.y = 0.0;

  /* w/o fabs, volume can be negative */
  se->volume = fabs(se->systemsize.x * se->systemsize.y * 
                    se->systemsize.z);

  /* half of the RMS edge length: sqrt((x^2 + y^2 + z^2) / 3) / 2 */
  se->rcut = 0.5 * sqrt(MD_vec_dot(se->systemsize, 
					se->systemsize) / 3.0);
  se->rcut2 = se->rcut * se->rcut;
  se->beta = calc_beta(se->rcut, se->errTol);
  se->kcut = calc_kcut(se->beta, se->errTol);  
  printf("Ewald sum module:\n");
  printf("  error tolerance = %g, beta = %20.15f\n", se->errTol, 
	 se->beta);
  printf("  rcut = %20.15f, kcut = %20.15f\n", se->rcut, se->kcut);

  natoms = se->natoms;
  se->force   = my_calloc((size_t) natoms, sizeof(MD_Dvec), "force");
  se->charge  = my_calloc((size_t) natoms, sizeof(MD_Double), "charge");

  if (NULL != init_data->patom) {
    for (i=0; i<se->natoms; i++) se->charge[i] = init_data->patom[i].q;
    se->qq = DOT(se->charge, se->charge, se->natoms);
    se->self_energy_qq = - se->beta * one_over_sqrtPi * se->qq;
    printf("  charge-charge self_energy:%20.15f\n", se->self_energy_qq);
  } else {
    /* defined placeholders until the caller assigns real values */
    se->qq = 0.0;
    se->self_energy_qq = 0.0;
    printf(" need to assign charge and self-energy seperately\n");
  }

  if (init_dir(se)) {
    fprintf(stderr, "cannot init dir sum part\n");
    return MD_FAIL;
  }

  if (init_rec(se)) {
    fprintf(stderr, "cannot init reciprocal lattice\n");
    return MD_FAIL;
  }

  if (se->has_induced_dipole) {
    if (stdEw_setup_dsolver(se, init_data->dsolver_param) ) {
      fprintf(stderr, "failed to set up dipole solver\n");
      return MD_FAIL;
    }
  }

  return OK;
}


MD_Errcode stdEw_destroy(struct standEwald_Tag *se)
{
  free(se->force);
  free(se->charge);

  destroy_dir(se);
  destroy_rec(se);

  if (NULL != se->dsolver) {
    if (dsolver_destroy(se->dsolver)) {
      return FAILURE;
    }
    free(se->dsolver);
  }

  memset(se, 0, sizeof(struct standEwald_Tag));

  return OK;
}


/*
 * stdEw_compute: evaluate electrostatic energy and forces for the current
 * configuration (called every MD step).
 *
 * Steps:
 *   1. direct-space setup (dipole_dir_setup or charge_dir_setup);
 *   2. reciprocal-space setup (dipole_rec_setup or charge_rec_setup).  This
 *      must follow step 1 because the reciprocal contribution to the -G1q
 *      vector is accumulated, not cleared;
 *   3. with induced dipoles: solve the dipole equation, then compute the
 *      direct and reciprocal dipole forces;
 *   4. combine se->dirforce and se->recforce into se->force with
 *      MD_vec_add_mul and the constant COULOMB_SQR, and set
 *      se->energy = (dirEnergy + recEnergy + self_energy_qq) * COULOMB_SQR.
 *
 * If dsolver_solve fails, the failure is reported on stderr and the
 * predictor's stored vectors are dumped to "emergency_dipoles".  The step
 * then continues with whatever dipoles the solver left behind (deliberate
 * best-effort behavior).
 *
 * With TIMING defined, the elapsed time is accumulated into Tdir, Trec,
 * Tsd and Ttotal, and each interval is charged exactly once.
 *
 * Always returns OK.
 */
MD_Errcode stdEw_compute(struct standEwald_Tag *se)
{
#ifdef TIMING 
  MD_Double toverallstart, tstart, tend;
#endif
  MD_Int natoms = se->natoms; 
  MD_Int i;

#ifdef TIMING 
  toverallstart = tstart = time_of_day();
#endif

  /* should be called before rec_setup */
  if (se->has_induced_dipole) dipole_dir_setup(se); 
  else charge_dir_setup(se);

#ifdef TIMING
  tend = time_of_day();
  Tdir += tend - tstart;
  tstart = tend;
#endif

  /* should be called after dir_setup, because the contribution from the
     reciprocal sum to the -G1q vector is added up, not cleared.*/
  if (se->has_induced_dipole) dipole_rec_setup(se);
  else  charge_rec_setup(se);

#ifdef TIMING
  Trec += time_of_day() - tstart;
#endif

  if (se->has_induced_dipole) {
#ifdef TIMING
    tstart = time_of_day();
#endif

    if (dsolver_solve(se->dsolver)) {
      fprintf(stderr, "dipole solver failed\n");
      /* best effort: save the predictor's stored vectors for post-mortem
         analysis and carry on with this step */
      if (vec_buffer_bindump(se->dsolver->predictor->old_vectors,
		             "emergency_dipoles")) {
	fprintf(stderr, "cannot dump dipoles\n");
      }
    } 

#ifdef TIMING
    Tsd += time_of_day() - tstart;
#endif 
  }


  /* Disabled diagnostic: RMS dipole per site, which appears to assume
     3-site molecules stored as (O, H, H). */
#if 0
  if (se->has_induced_dipole) {
    MD_Int nmols = natoms/3;  
    MD_Double dh1=0.0, dh2=0.0, d0=0.0; 
    MD_Double *d = dsolver_get_dipole(se->dsolver); 
    printf("sqrt(<d,d>/natoms) = %g DEBYE\n", 
	   sqrt(DOT(d, d, 3*natoms)/se->natoms) / DEBYE );
    for (i = 0; i < nmols; i++) {
      d0  += d[0]*d[0] + d[1]*d[1] + d[2]*d[2];
      dh1 += d[3]*d[3] + d[4]*d[4] + d[5]*d[5];
      dh2 += d[6]*d[6] + d[7]*d[7] + d[8]*d[8];
      d += 9;
    }
    printf("<dO>=%f, <dh1>=%f, <dh2>=%f (DEBYE)\n", sqrt(d0/nmols)/DEBYE, 
	   sqrt(dh1/nmols)/DEBYE, sqrt(dh2/nmols)/DEBYE);
  }
#endif


#ifdef TIMING
  tstart = time_of_day();
#endif

  if (se->has_induced_dipole) {
    compute_dir_dipole_force(se);
#ifdef TIMING
    tend = time_of_day(); Tdir += tend - tstart; tstart = tend;
#endif
    
    compute_rec_dipole_force(se);
    
#ifdef TIMING
    /* restart the clock so the force summation below does not charge the
       reciprocal dipole-force time to Trec a second time */
    tend = time_of_day(); Trec += tend - tstart; tstart = tend;
#endif
  }

#ifdef DEBUG_STANDEWALD
  printf("dirE=%20.15f, recE=%20.15f\n", se->dirEnergy, se->recEnergy); 
#endif

  {
    MD_Dvec *df = se->dirforce;
    MD_Dvec *rf = se->recforce;
    MD_Dvec *f = se->force;
    const MD_Double c = COULOMB_SQR;
#ifdef DEBUG_STANDEWALD
    {MD_Dvec sum_df={0,0,0}, sum_rf={0,0,0};
    for(i=0;i<natoms;i++) MD_vec_add(sum_df, df[i], sum_df);
    for(i=0;i<natoms;i++) MD_vec_add(sum_rf, rf[i], sum_rf);
    printf("sum of direct force: %g,%g,%g\n", sum_df.x, sum_df.y, sum_df.z);
    printf("sum of recip. force: %g,%g,%g\n", sum_rf.x, sum_rf.y, sum_rf.z);
    }
#endif
    for (i=0; i<natoms; i++) MD_vec_add_mul(df[i], rf[i], c, f[i]);
    se->energy = (se->dirEnergy + se->recEnergy + se->self_energy_qq) * c;
  }  

#ifdef TIMING
  tend = time_of_day(); Trec += tend-tstart; Ttotal += tend-toverallstart;
  printf("time cost: dir: %f; rec %f; solve d %f, overall: %f\n", 
          Tdir, Trec, Tsd, Ttotal);
#endif

  return OK;
}


MD_Double stdEw_get_energy(const struct standEwald_Tag* se)
{
  return se->energy;
}


/* Return the module-owned force array (natoms entries, allocated in
   stdEw_init and filled by stdEw_compute); the caller must not free it. */
const MD_Dvec*  stdEw_get_force(const struct standEwald_Tag* se)
{
  const MD_Dvec *forces = se->force;
  return forces;
}
  

/*
 * stdEw_get_dipole: return the dipole vector held by the dipole solver.
 *
 * Only meaningful when the module was initialized with induced dipoles.
 * Otherwise no solver exists: a diagnostic is written to stderr (like the
 * module's other error messages) and NULL is returned.
 */
MD_Double* stdEw_get_dipole(const struct standEwald_Tag* se) 
{
  if (NULL == se->dsolver) {
    fprintf(stderr, "no induced dipole, should not be called \n");
    return NULL;
  }
  return dsolver_get_dipole(se->dsolver);
}


/*=======================================================================*/


/*
 * stdEw_fill_diagonal: write the diagonal of the dipole-equation matrix.
 *
 * For each of the 3*natoms components k:
 *   d[k] = 1/alpha[k] - 4*beta^3 / (3*sqrt(pi))
 * where alpha is se->palpha.  d must hold at least 3*natoms entries.
 */
void stdEw_fill_diagonal(const struct standEwald_Tag* se, MD_Double *d)
{
  const MD_Double b = se->beta;
  /* Ewald self-interaction term shared by every diagonal entry */
  const MD_Double self_term = 4.0 * b * b * b * one_over_sqrtPi / 3.0;
  const MD_Double *a = se->palpha;
  const MD_Double *a_end = a + se->natoms * 3;

  while (a < a_end) {
    *d++ = 1.0 / *a++ - self_term;
  }
}


MD_Errcode stdEw_setup_dsolver(struct standEwald_Tag *se, 
				 struct Dsolver_Parameters_Type dsolver_param) 
{
  struct Dsolver_Init_Tag init_data;
  const MD_Int natoms = se->natoms;

  se->dsolver = my_calloc((size_t)1, sizeof(struct Dsolver_Tag), 
			       "dsolver");
  /* the dipole equation is 
   * (G2 + Alpha^{-1} - 4beta^3/(3*sqrt(pi))) d = - G1*q
   * where d is the total (permanent + induced) dipole
   */

  memset(&init_data, 0, sizeof(init_data));
  init_data.electro = se;
  init_data.ewaldmethod = ES_StandardEwald;
  init_data.mat_vec_mul_mod = se;
  init_data.compute_pseudores = 
    (compute_pseudores_type) stdEw_compute_pseudores;
  init_data.matrixsize = 3 * natoms;
  init_data.specified_param = dsolver_param;
  init_data.specified_param.errTol2 *= DEBYE*DEBYE/3.0 ; /* rectify */
  init_data.density = se->natoms / se->volume;
  init_data.neibrlist = se->neibrlist;
  init_data.numneibrs = se->numneibrs;
  if (dsolver_init(se->dsolver, &init_data)) {
    fprintf(stderr, "failed to init electrostatic module\n");
    return MD_FAIL;
  }

  return OK;
}


/*
 * stdEw_compute_pseudores: apply the dipole-equation operator for the
 * dipole solver (registered as its compute_pseudores callback in
 * stdEw_setup_dsolver).
 *
 * The direct-space part is computed into pseudores by compute_dir_pseudores.
 * The reciprocal-space part is then added by add_rec_pseudores.  The
 * meaning of flag is defined by those two routines.
 *
 * d and pseudores must be distinct buffers.  With TIMING defined, the two
 * phases are charged to Tdir and Trec respectively.
 */
void stdEw_compute_pseudores(struct standEwald_Tag* se,
			       const MD_Double *d, const MD_Int flag, 
			       MD_Double *pseudores)
{
#ifdef TIMING
  MD_Double tstart, tend;
#endif

  /* The operator must not work in place.  Compare the pointers directly:
     subtracting pointers that may refer to different arrays is undefined
     behavior in C. */
  ASSERT(d != pseudores);

#ifdef TIMING
  tstart = time_of_day();
#endif

  /* direct-space part */
  compute_dir_pseudores(se, d, flag, pseudores);

#ifdef TIMING
  tend = time_of_day();
  Tdir += tend - tstart;
  tstart = tend;
#endif

  /* reciprocal-space part, accumulated into pseudores */
  add_rec_pseudores(se, d, flag, pseudores);

#ifdef TIMING
  Trec += time_of_day() - tstart;
#endif
}


/*
 * stdEw_dump_dipole: write the solver's dipoles to filename via
 * dsolver_dump_dipole.  Returns that call's result, or MD_FAIL when the
 * module has no dipole solver.
 */
MD_Errcode stdEw_dump_dipole(const struct standEwald_Tag *se, 
			     const char* filename)
{
  if (NULL == se->dsolver) return MD_FAIL;
  return dsolver_dump_dipole(se->dsolver, filename);
}


