/*
 * Copyright (C) 2004-2006 by Wei Wang.  All rights reserved.
 */

/*
 * Berendsen weak-coupling algorithm for constant temperature simulation
 * the implementation is based on 
 * Molecular dynamics with coupling to an external bath
 * J. Chem. Phys. Vol 81 3684--3690
 *  H. J. C. Berendsen, J. P. M. Postma, W. F. van Gunsteren, A. Dinola
 *  and J. R. Haak
 * The algorithm is described on page 3687. Note that
 *  1. the current code implements no pressure coupling (temperature only).
 *  2. velocity is defined as v(n+1/2) = (x(n+1)-x(n))/delta_t
 *     where x(n) are the positions after the constraint step (shake;
 *     the current code uses SETTLE for the water models).
 *
 * The algorithm is also mentioned on page 163 of the book
 *   "Understanding Molecular Simulation From Algorithms to Applications"
 *   Daan Frenkel & Berend Smit   2nd Ed. 
 * 
 * the MD process is carried out like this:
 * for n = 1.... {  [each MD step]
 *     scale_factor = sqrt(1+dt/tau_t (T0/T(n-1/2) - 1));
 *     v(n+1/2) = scale_factor * ( v(n-1/2) + F(n)*deltat/M )
 *     x(n+1)   = x(n) + v(n+1/2) * deltat
 *     shake
 *     output;  
 * }  
 * Question: it is not clear, however, how the kinetic energy at integer
 * steps is computed.
 * One implementation uses v = (v(n-1/2)+v(n+1/2))/2 to compute kinetic energy.
 * This is suggested by 
 *   Tetsuya Morishita, Jour. Chem. Phys. Vol 113 (2976--2982) 2000
 * An alternative is from NAMD:
 *           Ek(n) = (Ek(n-1/2) + Ek(n+1/2))/2
 * The code uses the NAMD approach, which seems to give a slightly better
 * kinetic energy value (averaging velocities gives 149.7 Kelvin, averaging
 * Ek gives 150.0 Kelvin; the target temperature is 150 Kelvin per degree
 * of freedom).
 *
 *
 * pros: the energy is obtained after shake, which is physically more 
 *       meaningful.
 * cons: the time is inconsistent.
 *
 * Wei Wang,    April 09, 2004
 */

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <assert.h>
#include <errno.h>
#include <string.h>
#include "unit.h"
#include "constant.h"
#include "utilities.h"
#include "helper.h"
#include "shake.h"
#include "settle.h"
#include "linearfit.h"
#include "avgvar.h"
#include "Berendsen_wc.h"
#include "data_collector.h"


MD_Errcode ber_init(struct Berendsen_wc_Tag *ber, 
		    struct Force_Tag *force, 
		    struct Data_Tag *data, 
		    const MD_Double temperature /* Kelvin */ )
{
  ber->force = force;
  ber->data = data;
  ber->temperature = temperature;
  ber->relaxation_time = 1000.0;  /* 1 pico-second */
  ber->desired_kineticE = 0.5 * temperature * KELVIN * 
                         ((MD_Double) data->degree_freedom);
  printf("use Berendsen weak-coupling thermostat (JCP vol 81, 3684)\n");
  printf("  relaxation_time   = %f,\n"
	 "  degree of freedom = %d,\n"
         "  temperature       = %6.3f Kelvin,\n"
         "  desired_kineticE  = %f kcal/mol\n", 
	 ber->relaxation_time, 
	 data->degree_freedom, ber->temperature, 
	 ber->desired_kineticE / KCAL_PER_MOL);
#ifdef NAMD_KINETIC_E
  printf("  -- use NAMD style to compute kinetic energy at integer steps\n");
#endif

  if(ber->relaxation_time <= 10.0 * data->timestep) {
    printf("relaxation time is probably too small, thermostate may "
           "become unstable\n");
    ber->relaxation_time = 10.0 * data->timestep;
    printf("relaxtion time is modified to %f\n", ber->relaxation_time);
  }

  ber->output_freq = data->outputEnergies; 

  return OK;
} 


/*
 * Detach a Berendsen thermostat from the data and force objects it was
 * given in ber_init.  Those objects are not freed here; only the two
 * stored pointers are cleared.  Always returns OK.
 */
MD_Errcode ber_destroy(struct Berendsen_wc_Tag *ber)
{
  ber->data = NULL;   /* drop reference to simulation data */
  ber->force = NULL;  /* drop reference to force object */
  return OK;
}


MD_Errcode ber_run(struct Berendsen_wc_Tag *ber, const MD_Int nsteps,
		   const enum RUN_TYPE run_type)
{
  struct Data_Tag *data = ber->data;
  struct Force_Tag *force = ber->force;
  MD_Atom *atom = data->atom;
  MD_Dvec *vel = data->vel;
  MD_Dvec *pos = data->pos;
  MD_Dvec *f = force->f;
  MD_Double ke = -1.0;
  const MD_Double dt = data->timestep; 
  const MD_Int natoms = data->natoms;
  const MD_Int nmols = natoms / 3;

#if 0
  const MD_Int output_freq = ber->output_freq;
  const MD_Int dump_period = 10000;
  const MD_Int bkpoint_period = 100000; 
#endif

  /* DJH- for output */
  const MD_Int outputfreq = (ber->output_freq > 0 ?
      ber->output_freq : nsteps);
  const MD_Int restartfreq = (data->restartfreq > 0 ?
      data->restartfreq : nsteps);
  const MD_Int dcdfreq = (data->dcdfreq > 0 ?
      data->dcdfreq : nsteps);
  /* DJH- end for output */

  MD_Double tmp;
  MD_Double scale_factor = -1.0;
  MD_Int istep, k;
  struct Data_Collector_Tag* data_collector = NULL;
/*
  struct Shake_Water *sw = NULL;
*/
  struct Settle_Water *rw = NULL;
  MD_Double dielectric_constant;

  if (0 == nsteps) return OK;

  /* sanity check */
  assert(force->model_id >= 0 && force->model_id < NMODELS); 
  if (POL1==force->model_id || POL3==force->model_id || 
      RPOL==force->model_id || SPC==force->model_id) {
    assert(nmols * 3 == natoms);
  }

  printf("\n\n\n-------------------------------------------------------\n");
  printf("Berendsen weak coupling method for constant temperature simulation:\n"
         "  to %s for %d steps ...\n",
         (THERMALIZATION == run_type ? "thermalize": "run"), nsteps);
  printf("  timestep is %f femto second \n", dt);
  printf("  output frequency is %d (output energy)\n", outputfreq);
  printf("  restart frequency is %d (checkpoint simulation) \n", restartfreq);
  printf("  DCD frequency is %d (save trajectory file frame)\n", dcdfreq);
#if 0
  printf("  dump period is %d (dump data) \n", dump_period);
  printf("  break point period is %d (dump image for restart)\n",
         bkpoint_period);
#endif
  printf("-------------------------------------------------------\n");

  if (PRODUCTION == run_type) {
    data_collector = my_malloc(sizeof(struct Data_Collector_Tag), 
			       "data collector");
    data_collector_init(data_collector, data, force, ConstTemp);
    data_collector_assign_pmd(data_collector, ber);
  }

  if (POL1==force->model_id || POL3==force->model_id || 
      RPOL==force->model_id || SPC==force->model_id) {
/*
    sw = my_malloc(sizeof(struct Shake_Water));
    shake_init(sw, data->bond_len, data->bond_angle, O_MASS, H_MASS, 
	       natoms);
*/
    MD_Double ohdist = data->bond_len;
    MD_Double hhdist = sqrt(2.0*ohdist*ohdist*(1.0-cos(data->bond_angle)));
    rw = my_malloc(sizeof(struct Settle_Water), "settle");
    settle_init(rw, O_MASS, H_MASS, hhdist, ohdist, natoms);
  } 

  fflush(NULL);

  /*
   * DJH- output dcd file header
   */
  if (output_dcd_header(data, run_type)) {
    fprintf(stderr, "cannot write dcd file header\n"); /* soft fail */
  }

  /* 
   *start the MD run, leapfrog method + Berendsen weak-coupling
   */
  for (istep = data->firststepnum + 1; 
       istep<= data->firststepnum + nsteps + 1; 
       istep++) {  

    /* compute scale factor */
    ke = compute_KE(vel, atom, natoms);
    scale_factor = sqrt(1.0 + (dt/ber->relaxation_time) * 
			(ber->desired_kineticE / ke - 1.0));
    if (scale_factor < 0.8) scale_factor = 0.8;
    if (scale_factor > 1.25) scale_factor = 1.25;

    /* fprintf(stderr, "ke before kick: %f\n", ke/KCAL_PER_MOL);  */

    ke = 0.0;
    for (k=0; k<natoms; k++) ke += MD_vec_dot(vel[k],vel[k])*atom[k].m;
    data->energy[KE] = 0.25 * ke;  

    if (force_compute(force)) return MD_FAIL;
    if (0 != errno) {
      printf("errno=%d\n", errno);
      perror("something wrong in force_compute");
    }

    /* 1. kick and scale */
    for (k = 0;  k < natoms;  k++) {
      tmp = dt / atom[k].m;
      MD_vec_mul_add(vel[k], f[k], tmp);
      MD_vec_mul(vel[k], scale_factor, vel[k]);
    }
    /*
    printf("ke after  kick: %f\n", compute_KE(vel, atom, natoms)/KCAL_PER_MOL);
    */
	
    /* 2. drift and shake */			  
/*
    if (NULL != sw) shake_prepare(sw, pos);
*/
    if (NULL != rw) settle_prepare(rw, pos);
    for (k = 0;  k < natoms;  k++) {
      MD_vec_mul_add(pos[k], vel[k], dt);
    }
/*
    if (NULL != sw) {
      for (k = 0; k < nmols; k++) {
	if (shake_position(sw, 3*k, pos+3*k, vel+3*k, dt)) {
	  return MD_FAIL;
	}
      } 
    }
*/
   if (NULL != rw) {
      for (k = 0; k < nmols; k++) {
        if (settle1(rw, 3*k, pos+3*k, vel+3*k, dt)) {
          return MD_FAIL;
        }
      }
    }

    if (0 != errno) {
      printf("errno=%d\n", errno);
      perror("something wrong in force_compute");
      errno = 0;
    }
    /* 3. output, get the energy at previous step (istep-1) */
    if (0 == (istep-1) % outputfreq) {
      ke = 0.0;
      for (k=0; k<natoms; k++) ke += MD_vec_dot(vel[k],vel[k])*atom[k].m;
      data->energy[KE] += 0.25 * ke;  
      data->energy[TE] = data->energy[KE] + data->energy[PE];
      data_print_energy(data, istep-1);
      if (!conserve_linear_momentum(data)) return MD_FAIL;

      if (NULL != data_collector) {
	data_collector_update(data_collector, istep-1); /* collect less frenquently */
      }
     
#if 0
      if (0 != errno) {
        printf("errno=%d\n", errno);
        perror("something wrong in force_compute");
      }
#endif
    }

    /* DJH- check for dcd file output
     *
     * data collection output corresponds with DCD frames
     */
    if (istep - 1 - data->firststepnum > 0 &&
        (istep - 1 - data->firststepnum) % dcdfreq == 0) {
      if (NULL != data_collector) {
        data_collector_output(data_collector);
        if (POL1==force->model_id || POL3==force->model_id || 
            RPOL==force->model_id || SPC==force->model_id) {
          MD_Double msqr=avgvar_get_avg(data_collector->avgvar[
              data_collector->av_index[AV_Msqr]]);
          dielectric_constant= 1.0 + 4.0*Pi/(3.0*force_compute_volume(force)) *
            (msqr*DEBYE*DEBYE) / (ber->temperature*KELVIN) * COULOMB_SQR;
          printf("dielectric constant: %f\n", dielectric_constant);
        }
        fflush(stdout);
      }
#if 0
      /* for the computation of diffusion constant */
      if (sprintf(filename, "output/pos_%d.dat", istep-1) + 1 > 
          (int)sizeof(MD_String)) { 
        fprintf(stdout, "filename is too small\n");
        return MD_FAIL;
      }
      if (bindump_vec_array(pos, natoms, filename)) {
        fprintf(stdout, "cannot output %s, disk full ?? \n", filename);
      }
#endif
    }

    /*
     * DJH- check for restart file output
     */
    if (istep - 1 - data->firststepnum > 0 &&
        (istep - 1 - data->firststepnum)  % restartfreq == 0) {
      if (data_bindump_image(data, istep-1)) {
        fprintf(stderr, "cannot dump system image \n"); /* soft fail */
      }
    }
#if 0
    /* ------- dump velocity at every timestep (to compute VAF) ------- */
    if (istep % 1 == 0) {
      if (sprintf(filename, "output/vel_%d.dat", istep) + 1 > 
          (int)sizeof(MD_String)) { 
        fprintf(stdout, "filename is too small\n");
        return MD_FAIL;
      }
      if (bindump_vec_array(data->vel, data->natoms, filename)) {
        fprintf(stdout, "error output velocity\n");
      }
    }	
#endif

    /* fflush(NULL); */

    /* 4. check */
    if (0 != errno) {
      printf("errno=%d\n", errno);
      perror("something wrong in force_compute");
    }
  } /* end of istep */

  /*
   * DJH- done writing dcd file
   */
  if (output_dcd_done(data)) {
    fprintf(stderr, "cannot finish with dcd file writing\n"); /* soft fail */
  }

  data->firststepnum += nsteps;

#if 0
  /* summarizing output, in case # of steps is not a multiplier of 
     a output frenquency */
  if ( (POL1==force->model_id || POL3 == force->model_id || 
	RPOL == force->model_id || SPC == force->model_id) && 
       (NULL != data_collector) ){
    MD_Double msqr = avgvar_get_avg(data_collector->avgvar[
                                    data_collector->av_index[AV_Msqr]]);
    dielectric_constant = 1.0 + 4.0*Pi/(3.0*force_compute_volume(force)) *
      (msqr * (DEBYE * DEBYE)) /  (ber->temperature * KELVIN) * COULOMB_SQR;
    printf("dielectric constant: %f\n", dielectric_constant);
  }
#endif

  if (NULL != data_collector) {
    data_collector_destroy(data_collector);
    free(data_collector);
  }

/*
  if (NULL != sw) {
    shake_destroy(sw);
    free(sw);
  }
*/
  if (NULL != rw) {
    settle_destroy(rw);
    free(rw);
  }

  return OK;
}


MD_Double ber_get_pv_over_nkT(const struct Berendsen_wc_Tag *ber)
{
  if (NULL == ber) {
    fprintf(stderr, "wrong function call\n");
    return 0.0;
  } else {
    const MD_Dvec *r = ber->data->pos; /* position */
    const MD_Dvec *f = ber->force->f;   /* force */
    MD_Double virial;
    const MD_Int natoms = ber->data->natoms;
    MD_Int i;

    virial = 0.0;
    for (i = 0; i < natoms; i++) {
      virial += MD_vec_dot(r[i], f[i]);
    }
    virial /= 3.0;

    return 1.0 + virial / ((MD_Double)(natoms * 3) * KELVIN * 
			   ber->temperature);
  }
}
