/*
 * Copyright (C) 2004-2006 by Wei Wang.  All rights reserved.
 */

/*
 * canonical ensemble simulation using (implicit) Nose-Hoover algorithm.
 * the half-kick is carried out through solving a cubic equation, as is
 * suggested in the following paper:
 *
 * The Nose-Poincare Method for Constant Temperature Molecular Dynamics
 * Stephen D. Bond, Benedict J. Leimkuhler, and Brian B. Laird
 * Journal of Computational Physics, Vol 151, 114-134 (1999)
 *
 * see my note: Oct-15-2003
 *
 * problem: Nose-Hoover method does not work with Rattle yet: at the end
 * of its timestep, it is hard to satisfy the rattle constraint and
 * solve for \psi (so that the pseudo energy is conserved). In the current
 * implementation, pseudoenergy drifts downwards if rattle is used.
 * however, Nose-Hoover can work with shake, which gives wrong energies
 * at integer timesteps. The good thing about it is that the wrong energy
 * does not drift, it fluctuates around the exact value.
 *
 * Note:
 *   The total linear momentum MUST be set to zero before running; see
 *   "Understanding Molecular Simulation" by Frenkel and Smit.
 *
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <assert.h>
#include "unit.h"
#include "constant.h"
#include "utilities.h"
#include "helper.h"
#include "shake.h"
#include "linearfit.h"
#include "avgvar.h"
#include "data_collector.h"
#include "rattle.h"
#include "settle.h"
#include "nose_hoover_dynamics.h"
#include "data.h"

/*
 * Debug tracing.  When DH_DEBUG is defined, FLT(x) prints the source file,
 * line number, the text of expression x and its value converted to double.
 * DH_DEBUG is #undef'ed immediately before it is tested, so in this file
 * the tracing branch is never selected and FLT(x) expands to nothing.
 * (The tracing variant already ends in ';', so "FLT(x);" would expand to
 * an extra empty statement.)
 */
#undef DH_DEBUG
#ifdef DH_DEBUG
#define FLT(x) printf("%s, %d: (%s)=%g\n", __FILE__, __LINE__, #x, (double)(x));
#else
#define FLT(x)
#endif

/*
 * Constraint solver selection for nosehoover_run():
 *   SETTLE defined                 -> settle code path
 *   SETTLE undefined               -> SHAKE is defined, shake code path
 *   neither defined                -> rattle code path, which
 *                                     nosehoover_run() rejects with exit(1)
 * SETTLE is defined unconditionally here, so SHAKE is never defined below.
 */
/* try using settle with Nose-Hoover */
#define SETTLE

#ifndef SETTLE
#define SHAKE   
#endif


/* Forward declaration (definition is further down in this file).
 *
 * Solves the cubic equation (a + b*x)^2 * x = c for x using Newton's
 * method, starting from the initial guess passed in x, and returns the
 * final iterate.
 * A closed-form solution is not used because the cubic has 3 roots;
 * Newton iteration from the supplied guess is used instead.
 */
static 
MD_Double cubic_solve(MD_Double x, MD_Double a, MD_Double b, MD_Double c);




/*
 * Initialize a Nose-Hoover thermostat object.
 *
 *   nh          - thermostat to fill in; must not be NULL
 *   force       - force object; only the pointer is stored here
 *   data        - simulation data; the pointer is stored and
 *                 data->degree_freedom is read, so it must not be NULL
 *   temperature - target temperature in Kelvin
 *
 * Sets the target value of twice the kinetic energy, derives the inverse
 * thermostat mass from a fixed time constant, zeroes the friction
 * coefficient and ln(s), sets the output frequency to 100, prints the
 * parameters to stdout and returns OK.
 */
MD_Errcode nosehoover_init(struct NoseHoover_Tag *nh, 
			   struct Force_Tag *force, 
			   struct Data_Tag  *data, 
			   const MD_Double temperature /* Kelvin */ )
{
  const MD_Double tau_over_twopi = 50.0;  /* femto second */

  assert(NULL != nh);
  assert(NULL != data);   /* data->degree_freedom is dereferenced below */

  nh->force = force;
  nh->data = data;
  nh->temperature = temperature;

  nh->degree_freedom = (MD_Double) data->degree_freedom; 
  /* 2 * (0.5 * Nf * k * T): twice the kinetic energy at the target T */
  nh->twice_desired_kinE = 2.0 * 0.5 * nh->degree_freedom * 
    nh->temperature * KELVIN;
  FLT(KELVIN);
  FLT(nh->degree_freedom);
  FLT(nh->temperature);
  FLT(nh->twice_desired_kinE);
  FLT(KJOULE_PER_MOL);

  /* inv_Qmass = 1 / (Nf * k * T * (tau/2pi)^2); formula read from the
   * GROMACS manual.
   * NOTE(review): the original author marked this expression as
   * uncertain -- confirm against the reference before relying on the
   * coupling strength. */
  nh->inv_Qmass = 1.0 / (nh->degree_freedom * KELVIN * nh->temperature *  
			 tau_over_twopi * tau_over_twopi);
  nh->fricoeff = 0.0;      /* initial value, perturbation should be small */
  nh->log_s = 0.0;         /* ln(s) starts at zero */
  nh->output_freq = 100;

  printf("Nose-Hoover thermostat parameters:\n");
  printf("  inv_Qmass = %f,  friction_coefficient = %f, ln(s) = %f\n",
         nh->inv_Qmass, nh->fricoeff, nh->log_s);
  printf("  degree of freedom = %f, temperature = %f\n", 
	 nh->degree_freedom, nh->temperature);
  printf("  output frequency: %d\n", nh->output_freq);
  return OK;
}


/*
 * Release a thermostat object by zeroing all of its fields.
 *
 * The structure itself and the force/data objects it points to are not
 * freed here.  A NULL nh is accepted and treated as a no-op (passing NULL
 * to memset would be undefined behaviour).
 *
 * Always returns OK.
 */
MD_Errcode nosehoover_destroy(struct NoseHoover_Tag *nh) 
{
  if (NULL != nh) {
    memset(nh, 0, sizeof *nh);
  }
  return OK;
}


/*
 * Integrate the system for nsteps timesteps at constant temperature using
 * the Nose-Hoover thermostat stored in nh.
 *
 *   nh       - thermostat state; supplies the data and force objects,
 *              which are advanced in place
 *   nsteps   - number of timesteps to take; 0 returns OK immediately
 *   run_type - THERMALIZATION or PRODUCTION; a data collector is created
 *              and updated only when run_type is PRODUCTION
 *
 * Each timestep performs, in order:
 *   1. explicit half-kick of the velocities and the friction coefficient
 *   2. drift of the positions and of ln(s)
 *   3. position constraints (settle / shake / rattle, chosen at compile
 *      time)
 *   4. force evaluation
 *   5. implicit half-kick: the new kinetic energy is obtained from
 *      cubic_solve(), then the velocities are rescaled
 *   6. velocity constraints (settle / rattle only)
 *   7. optional energy, DCD-frame and restart output
 *
 * Returns OK on success and MD_FAIL when the force computation, a
 * constraint routine or the linear-momentum check reports failure.
 * On success data->firststepnum is advanced by nsteps.
 *
 * NOTE(review): every "return MD_FAIL" taken after the data collector and
 * the constraint solver have been allocated bypasses the cleanup at the
 * end of this function, so those allocations are not released and
 * output_dcd_done() is not called on failure.
 */
MD_Errcode nosehoover_run(struct NoseHoover_Tag *nh, const MD_Int nsteps,
			  enum RUN_TYPE run_type)
{
  struct Data_Tag *data = nh->data;
  struct Force_Tag *force = nh->force;
  MD_Atom *atom = data->atom;
  MD_Dvec *vel = data->vel;
  MD_Dvec *pos = data->pos;
  MD_Dvec *f = force->f;
  const MD_Double dt = data->timestep;   /* 1 fs */
  const MD_Double half_dt = 0.5 * dt;
  const MD_Double inv_Qmass = nh->inv_Qmass;
  const MD_Double inv_Kcal_per_mol = 1.0 / KCAL_PER_MOL;
  const MD_Int natoms = data->natoms;
  /* molecule count assuming 3 atoms per molecule; only used when a
   * constraint solver exists (water models, see the sanity check below) */
  const MD_Int nmols = natoms / 3;

#if 0
  const MD_Int output_freq = nh->output_freq;
  const MD_Int dump_period = 10000; 
  const MD_Int bkpoint_period = 100000; /* 100ps */
#endif

  /* DJH- for output */
  /* Each period falls back to nsteps when the configured value is not
   * positive.  nsteps is non-zero by the time these are used as modulo
   * operands (see the early return below). */
  const MD_Int outputfreq = (data->outputEnergies > 0 ?
      data->outputEnergies : nsteps);
  const MD_Int restartfreq = (data->restartfreq > 0 ?
      data->restartfreq : nsteps);
  const MD_Int dcdfreq = (data->dcdfreq > 0 ?
      data->dcdfreq : nsteps);
  /* DJH- end for output */

  MD_Int istep, k;
  struct Data_Collector_Tag* data_collector = NULL;  /* PRODUCTION only */
  /* exactly one constraint-solver handle is compiled in; it stays NULL
   * unless the force model is one of the water models below */
#if defined(SETTLE)
  struct Settle_Water *tw = NULL;
#elif defined(SHAKE)
  struct Shake_Water *sw = NULL;
#else
  struct Rattle_Water *rw = NULL;
#endif
  MD_Double dielectric_constant;
  MD_Double twoke, twoke_tilde;  /* 2*KE after / before the implicit rescale */
  MD_Double inv_mass, tmp;
  MD_Double ke;


  if (0 == nsteps) return OK;

  /* sanity check */
  assert(force->model_id >= 0 && force->model_id < NMODELS); 
  if (POL1==force->model_id || POL3 == force->model_id || 
      RPOL == force->model_id || SPC == force->model_id) {
    assert(nmols * 3 == natoms); /* water */
  }

  printf("\n\n\n-------------------------------------------------------\n");
  printf("Nose-Hoover dynamics for constant temperature:\n"
         "  %s for %d steps ...\n", 
	 (THERMALIZATION == run_type ? "thermalize": "production run"), 
	 nsteps);
  printf("  timestep = %f femto second \n", dt);
  printf("  output frequency is %d (output energy)\n", outputfreq);
  printf("  restart frequency is %d (checkpoint simulation) \n", restartfreq);
  printf("  DCD frequency is %d (save trajectory file frame)\n", dcdfreq);
  printf("-------------------------------------------------------\n");

  /*
#ifndef SHAKE
  printf("the current implementation of Nose Hoover dynamics does not "
	 " work with RATTLE\n");
  exit(1);
#endif
  */

  /* the rattle code path is compiled only when neither SHAKE nor SETTLE
   * is defined, and in that configuration the run is refused outright */
#if !defined(SHAKE) && !defined(SETTLE)
  printf("  Note: rattle does not work with Nose-Hoover dynamics yet !!!\n");
  printf("        pseudo-energy drift downwards when using rattle\n");
  exit(1);
#endif

  if (PRODUCTION == run_type) {
    data_collector = my_malloc(sizeof(struct Data_Collector_Tag), 
			     "data collector");
    data_collector_init(data_collector, data, force, ConstTemp);
  }

  /* create the constraint solver for the water models only */
  if (POL1==force->model_id || POL3 == force->model_id || 
      RPOL == force->model_id || SPC == force->model_id) {
#if defined(SETTLE)
    {
      MD_Double ohdist = data->bond_len;
      /* H-H distance from the O-H length and the H-O-H angle (law of
       * cosines); bond_angle is passed straight to cos() */
      MD_Double hhdist = sqrt(2.0*ohdist*ohdist*(1.0-cos(data->bond_angle)));
      tw = my_malloc(sizeof(struct Settle_Water), "settle");
      settle_init(tw, O_MASS, H_MASS, hhdist, ohdist, natoms);
    }
#elif defined(SHAKE)
    sw = my_calloc((size_t)1, sizeof(struct Shake_Water), "shake");
    shake_init(sw, data->bond_len, data->bond_angle, O_MASS, H_MASS, 
	       natoms);
#else
    rw = my_calloc((size_t)1, sizeof(struct Rattle_Water), "rattle");
    rattle_init(rw, data->bond_len, data->bond_angle, O_MASS, H_MASS,
                natoms);
#endif
  } 

  /* forces for the starting configuration */
  if (force_compute(force)) return MD_FAIL;

  /* report energies and the conserved pseudo energy for the starting
   * step; the pseudo energy is printed divided by KCAL_PER_MOL */
  compute_KETE(data);
  data_print_energy(data, data->firststepnum); 
  nosehoover_computeH(nh);
  printf("pseudoE: %7d %15.9f\n", 
	  data->firststepnum, nh->conserved_psuedoE * inv_Kcal_per_mol);
  /* a zero return from conserve_linear_momentum() aborts the run;
   * NOTE(review): presumably zero means the check failed -- confirm
   * against its definition */
  if (!conserve_linear_momentum(data)) return MD_FAIL;

  /* force_dump(force); */ 
  if (NULL != data_collector) {
    data_collector_update(data_collector, data->firststepnum);
  }

  /*
   * DJH- output dcd file header
   */
  if (output_dcd_header(data, run_type)) {
    fprintf(stderr, "cannot write dcd file header\n"); /* soft fail */
  }

  /* 
   *start the MD run, leapfrog method + Nose-Hoover algorithm
   */
  for (istep = data->firststepnum + 1; istep <= data->firststepnum + nsteps; 
       istep++) {  
    /* fprintf(stderr, "half kick\n"); */
    /*
    printf("ke before 1st kick: %20.15f\n",
	   compute_KE(vel, data->atom, natoms)/KCAL_PER_MOL);
    */
    /* --- explicit half-kick ---
     * Uses the friction coefficient and kinetic energy from the start of
     * the step: ke is evaluated before the velocities are changed, and
     * the friction coefficient is advanced only after the velocity loop. */
    tmp = nh->fricoeff;
    ke = compute_KE(vel, atom, natoms);
    for (k = 0;  k < natoms;  k++) {
      inv_mass = 1.0 / atom[k].m;
      vel[k].x += half_dt * (f[k].x * inv_mass - tmp * vel[k].x);
      vel[k].y += half_dt * (f[k].y * inv_mass - tmp * vel[k].y);
      vel[k].z += half_dt * (f[k].z * inv_mass - tmp * vel[k].z);
    }
    nh->fricoeff += half_dt * inv_Qmass * (ke * 2.0 - 
					   nh->twice_desired_kinE);
    /*
    printf("ke after 1st kick: %20.15f\n",
	   compute_KE(vel, data->atom, natoms)/KCAL_PER_MOL);
    */						  
    /* fprintf(stderr, "drift\n"); */
    /* let the constraint solver see the positions before the drift */
#if defined(SETTLE)
    if (NULL != tw) settle_prepare(tw, pos);
#elif defined(SHAKE)
    if (NULL != sw) shake_prepare(sw, pos);
#else
    if (NULL != rw) rattle_prepare(rw, pos);
#endif

    /* fprintf(stderr, "drift\n"); */
    /* --- drift: full-step update of positions and of ln(s) --- */
    for (k = 0;  k < natoms;  k++) {
      MD_vec_mul_add(pos[k], vel[k], dt);
    }
    nh->log_s += nh->fricoeff * dt;

    /* --- position constraints, one molecule (3 consecutive atoms) at a
     * time; any non-zero return aborts the run --- */
#if defined(SETTLE)
    if (NULL != tw) {
      for (k = 0; k < nmols; k++) {
	if (settle1(tw, 3*k, pos+3*k, vel+3*k, dt)) {
	  return MD_FAIL;
	}   
      }   
    }
#elif defined(SHAKE)
    if (NULL != sw) {
      for (k = 0; k < nmols; k++) {
        if (shake(sw, 3*k, pos+3*k, vel+3*k, dt)) {
          return MD_FAIL;
        }
      }
    }
#else
    if (NULL != rw) {
      for (k = 0; k < nmols; k++) {
        if (rattle1(rw, 3*k, pos+3*k, vel+3*k, dt)) {
          return MD_FAIL;
        }
      }
    }
#endif
    /*
    printf("ke after shake  : %20.15f\n",
	   compute_KE(vel, data->atom, natoms)/KCAL_PER_MOL);
    */

    /* forces at the new positions */
    if (force_compute(force))  return MD_FAIL;
    /* force_output(force); */
    /* fprintf(stderr, "half kick\n"); */

    /* --- implicit half-kick ---
     * First apply the force part only and accumulate
     *   twoke_tilde = sum_k m_k |v_k|^2
     * of the resulting velocities. */
    twoke_tilde = 0.0;
    for (k = 0; k < natoms; k++) {
      tmp = half_dt / atom[k].m;
      MD_vec_mul_add(vel[k], f[k], tmp);
      twoke_tilde += atom[k].m * MD_vec_dot(vel[k], vel[k]);
    }
    /* The final velocities are the ones above divided by
     * (1 + half_dt * fricoeff_new), where fricoeff_new itself depends on
     * the final 2*KE (twoke).  Substituting gives the cubic
     *   (a + b*twoke)^2 * twoke = twoke_tilde
     * with a = 1 + half_dt * fricoeff (after the target-KE part below has
     * been subtracted) and b = half_dt^2 * inv_Qmass.  twoke_tilde is also
     * used as the initial guess. */
    nh->fricoeff -= half_dt * inv_Qmass * nh->twice_desired_kinE;
    twoke = cubic_solve(twoke_tilde, 1.0 + half_dt * nh->fricoeff, 
			half_dt * half_dt * inv_Qmass, twoke_tilde); 
    nh->fricoeff += half_dt * inv_Qmass * twoke;
    /* rescale all velocities by 1 / (1 + half_dt * fricoeff_new) */
    tmp = 1.0 / (1.0 + half_dt * nh->fricoeff);
    for (k = 0; k < natoms; k++) {
      MD_vec_mul(vel[k], tmp, vel[k]);
    } 

    /* --- velocity constraints; there is no second stage in the SHAKE
     * configuration --- */
#if defined(SETTLE)
    if (NULL != tw) {
      for (k = 0; k < nmols; k++) {
	if (settle2(tw, pos+3*k, vel+3*k, dt)) {
	  fprintf(stderr, "settle2 failed for molecule %d\n", k); 
	  return MD_FAIL;
	}
      }
    }
#elif !defined(SHAKE)
    if (NULL != rw) {
      for (k = 0; k < nmols; k++) {
        if (rattle2(rw, pos+3*k, vel+3*k, dt)) {
	  fprintf(stderr, "rattle2 failed for molecule %d\n", k); 
	  return MD_FAIL;
	}
      }
    }
#endif
    
    /*
     * DJH- check for energy output
     *
     * This test uses the absolute step number, whereas the DCD and
     * restart tests below use the step count relative to firststepnum.
     */
    if (istep % outputfreq == 0) {
      /* NOTE(review): the original remark here said KE must be computed
       * at every step, but this call only runs on energy-output steps;
       * confirm whether data_collector_update() below depends on it */
      compute_KETE(data);
      data_print_energy(data, istep); 
      nosehoover_computeH(nh);
      printf("pseudoE: %7d %15.9f\n", 
	     istep, nh->conserved_psuedoE * inv_Kcal_per_mol);
      fflush(stdout);
      if (!conserve_linear_momentum(data)) return MD_FAIL;
    }

    /* the data collector (PRODUCTION only) is updated on every step */
    if (NULL != data_collector) data_collector_update(data_collector, istep);

    /* DJH- check for dcd file output
     *
     * data collection output corresponds with DCD frames
     */
    if ((istep - data->firststepnum) % dcdfreq == 0) {
      if (NULL != data_collector) {  
	data_collector_output(data_collector);
	if (POL1==force->model_id || POL3==force->model_id || 
	    RPOL==force->model_id || SPC==force->model_id) {
	  /* dielectric constant from the collected average of AV_Msqr:
	   *   1 + 4*Pi/(3*V) * <Msqr> / (k*T), with unit conversions */
	  MD_Double msqr = avgvar_get_avg(data_collector->avgvar[
					data_collector->av_index[AV_Msqr]]);
	  dielectric_constant = 1.0 + 4.0*Pi/(3.0*force_compute_volume(force))
	    * (msqr*DEBYE*DEBYE) / (nh->temperature*KELVIN) * COULOMB_SQR;
	  printf("dielectric constant: %f\n", dielectric_constant);
	}
        fflush(stdout);
      }
      if (output_dcd_frame(data, istep)) {
        fprintf(stderr, "cannot write dcd frame\n");  /* soft fail */
      }
#if 0
      /* for computation of diffusion constant */
      if (sprintf(filename, "output/pos_%d.dat", istep) + 1 >
          (MD_Int)sizeof(MD_String)) {
        fprintf(stdout, "filename is too small\n");
        return MD_FAIL;
      }
      if (bindump_vec_array(pos, natoms, filename)) {
        fprintf(stdout, "cannot output %s, disk full ?? \n", filename);
      }      
#endif
    }

    /*
     * DJH- check for restart file output
     * (a failed dump is reported but does not stop the run)
     */
    if ((istep - data->firststepnum)  % restartfreq == 0) {
      if (data_bindump_image(data, istep)) {
	fprintf(stderr, "cannot dump system image\n");
      }
    }

#if 0
    /* ------- dump velocity at every timestep (to compute VAF) ------- */
    if (istep % 1 == 0) {
      if (sprintf(filename, "output/vel_%d.dat", istep) + 1 >
          (MD_Int)sizeof(MD_String)) {
        fprintf(stdout, "filename is too small\n");
        return MD_FAIL;
      }
      if (bindump_vec_array(data->vel, data->natoms, filename)) {
        fprintf(stdout, "error output velocity\n");
      }
    }	
#endif

    /* fflush(NULL); */
  } /* end of istep */

  /*
   * DJH- done writing dcd file
   */
  if (output_dcd_done(data)) {
    fprintf(stderr, "cannot finish with dcd file writing\n"); /* soft fail */
  }

  /* a following run continues from the step this one ended on */
  data->firststepnum += nsteps;

#if 0
  /* (this is already done by data collection output above) */
  if ( (POL1==force->model_id || POL3==force->model_id ||
	RPOL==force->model_id || SPC==force->model_id) && 
       (NULL != data_collector)) {
    MD_Double msqr = avgvar_get_avg(data_collector->avgvar[
                                    data_collector->av_index[AV_Msqr]]);
    dielectric_constant = 1.0 + 4.0*Pi/(3.0*force_compute_volume(force)) *
      (msqr * (DEBYE * DEBYE)) /  (nh->temperature * KELVIN) * COULOMB_SQR;
    printf("dielectric constant: %f\n", dielectric_constant);
    fflush(stdout);
  }
#endif

  /* release what was allocated above (reached on the success path only) */
  if (NULL != data_collector) {
    data_collector_destroy(data_collector);
    free(data_collector);
  }

#if defined(SETTLE)
  if (NULL != tw) {
    settle_destroy(tw);
    free(tw);
  }
#elif defined(SHAKE)
  if (NULL != sw) {
    shake_destroy(sw);
    free(sw);
  }
#else
  if (NULL != rw) {
    rattle_destroy(rw);
    free(rw);
  }
#endif

  return OK;
}

void nosehoover_computeH(struct NoseHoover_Tag* nh)
{
  const struct Data_Tag *data = nh->data;
  nh->conserved_psuedoE = data->energy[TE] +
    0.5 * nh->inv_Qmass * nh->fricoeff * nh->fricoeff +
    nh->twice_desired_kinE * nh->log_s;
}



/* solve cubic equation (a + b*x)^2 * x = c using Newton's method
 * with initial guess x.
 * the exact solution is not favored because there are 3 roots. 
 */
MD_Double cubic_solve(MD_Double x, MD_Double a, MD_Double b, MD_Double c)
{
  MD_Double f, df;
  MD_Double dx;
  MD_Double tmp;
  static const MD_Double ErrTol = 1e-10;
  static const MD_Int MaxIter = 20;
  MD_Int iter = 0;

  do {
    iter ++;
    tmp = a + b*x;
    f = x*tmp*tmp - c;
    df = tmp * tmp + 2.0*b*tmp*x;
    dx = f/df;
    x -= dx;
    if (iter > MaxIter) {
      fprintf(stderr, "implicit nose-hoover cubic solver does not converge\n");
      fprintf(stderr, "x=%f, f=%f\n", x, x*tmp*tmp - c);
      break;
    }
  } while (fabs(dx)/(1.0 + fabs(x)) > ErrTol && iter < MaxIter);

  return x;
}


/*
 * Return 1 + W / (3 * natoms * KELVIN * T), where
 *   W = (1/3) * sum_i  r_i . f_i
 * is accumulated over all atoms from the current positions and forces,
 * and T is nh->temperature.
 *
 * A NULL nh is reported on stderr and yields 0.0.
 *
 * NOTE(review): the sum is divided by 3 and the denominator also carries
 * a factor 3*natoms -- confirm this is the intended normalization.
 */
MD_Double nosehoover_get_pv_over_nkT(const struct  NoseHoover_Tag*nh)
{
  const MD_Dvec *coords;
  const MD_Dvec *forces;
  MD_Double rdotf_sum = 0.0;
  MD_Int atom_count;
  MD_Int idx;

  /* reject a missing thermostat up front */
  if (NULL == nh) {
    fprintf(stderr, "wrong function call !\n");
    return 0.0;
  }

  coords = nh->data->pos;
  forces = nh->force->f;
  atom_count = nh->data->natoms;

  /* accumulate sum_i r_i . f_i in atom order */
  idx = 0;
  while (idx < atom_count) {
    rdotf_sum += MD_vec_dot(coords[idx], forces[idx]);
    idx++;
  }
  rdotf_sum /= 3.0;

  return 1.0 + rdotf_sum / ((MD_Double)(atom_count * 3) * KELVIN *
                            nh->temperature);
}
