/*
 * Copyright (C) 2004-2006 by Wei Wang.  All rights reserved.
 */

/*
 * constant energy molecular dynamics simulation
 */

#include <stdio.h>
#include <stdlib.h>
#include <assert.h>
#include <math.h>
#include <errno.h>
#include "data.h"
#include "fnonbond.h"
#include "dsolvers.h"
#include "standEwald.h"
#include "force.h"
#include "utilities.h"
#include "helper.h"
#include "shake.h"
#include "rattle.h"
#include "settle.h"
#include "linearfit.h"
#include "avgvar.h"
#include "data_collector.h"
#include "constEnergy.h"
#include "constant.h"
#include "unit.h"

/*
#define SHAKE
*/

MD_Errcode constEnergy_MD_run(struct Force_Tag *force, const MD_Int nsteps,
			      const enum RUN_TYPE run_type)
{
  struct Data_Tag *data = force->data;
  MD_Atom *atom = data->atom;
  MD_Dvec *vel = data->vel;
  MD_Dvec *pos = data->pos;
  MD_Dvec *f = force->f;

  MD_Double dt = data->timestep; 
  MD_Double half_dt = 0.5 * dt; 

  const MD_Int natoms = data->natoms;

#if 0
  const MD_Int output_freq = 100;
  const MD_Int dump_period = 100000;  
  const MD_Int bkpoint_period = 1000000; /* 400ps */
#endif

  /* DJH- for output */
  const MD_Int outputfreq = (data->outputEnergies > 0 ?
      data->outputEnergies : nsteps);
  const MD_Int restartfreq = (data->restartfreq > 0 ?
      data->restartfreq : nsteps);
  const MD_Int dcdfreq = (data->dcdfreq > 0 ?
      data->dcdfreq : nsteps);
  /* DJH- end for output */

  MD_Int istep, k;
  MD_Double tmp;
  struct Data_Collector_Tag* data_collector = NULL;
  const MD_Int nmols = natoms / 3;
#ifdef SHAKE
  struct Shake_Water *sw = NULL;
#else
  struct Settle_Water *rw = NULL;
  /*
  struct Rattle_Water *rw = NULL;
  */
#endif

  if (0 == nsteps) return OK;

  /* sanity check */
  assert(force->model_id >= 0 && force->model_id < NMODELS); 
  if (POL1==data->model_id || POL3==force->model_id || 
      RPOL==force->model_id || SPC==force->model_id) {
    assert(nmols * 3 == natoms); /* water molecule systems */
  }

  fprintf(stderr, "run MD\n");
  printf("\n\n-------------------------------------------------------\n");
  printf("Constant energy simulation (velocity Verlet): to %s for %d steps ...\n", 
	 (THERMALIZATION == run_type ? "thermalize": "run"), nsteps);
  printf("  timestep is %f femto second \n", dt);
  printf("  output frequency is %d (output energy)\n", outputfreq);
  printf("  restart frequency is %d (checkpoint simulation) \n", restartfreq);
  printf("  DCD frequency is %d (save trajectory file frame)\n", dcdfreq);
  printf("-------------------------------------------------------\n");

  /*dump_pos_vel(force, "posvel.dat", 0); */
  if (PRODUCTION == run_type) {
    data_collector = my_malloc(sizeof(struct Data_Collector_Tag), 
			       "data collector");
    data_collector_init(data_collector, data, force, ConstEnergy);
  }

  if (POL1==data->model_id || POL3==force->model_id || 
      RPOL==force->model_id || SPC==force->model_id) {
#ifdef SHAKE
    sw = malloc(sizeof(struct Shake_Water));
    assert(NULL != sw);
    shake_init(sw, data->bond_len, data->bond_angle, O_MASS, H_MASS, 
	       natoms); 
#else
    /*
    rw = my_malloc(sizeof(struct Rattle_Water), "rattle");
    rattle_init(rw, data->bond_len, data->bond_angle, O_MASS, H_MASS, 
		natoms);
    */
    { MD_Double ohdist = data->bond_len;
      MD_Double hhdist = sqrt(2.0*ohdist*ohdist*(1.0-cos(data->bond_angle)));
      rw = my_malloc(sizeof(struct Settle_Water), "settle");
      settle_init(rw, O_MASS, H_MASS, hhdist, ohdist, natoms);
    }
#endif
  } 

  if (force_compute(force)) return MD_FAIL;
  compute_KETE(data);   

#if 0
  /* ad hoc, to make the total energy as desired value ============ */
  {
    const MD_Double desired_tot_e = -1723.0 * KCAL_PER_MOL;
    MD_Double fac = (desired_tot_e - data->energy[TE]) 
                              / data->energy[KE] + 1.0;
    MD_Int i;

    assert(fac > 0.0);
    fac = sqrt(fac);
    printf("fac=%f\n", fac);
    for(i = 0; i < data->natoms; i++) {
      MD_vec_mul(vel[i], fac, vel[i]);
    }
  }
  compute_KETE(data); 
#endif  

  /*force_output(force);*/ 

  data_print_energy(data, data->firststepnum); 

  /*force_dump(force); */

#ifndef PME_CONSERVE_ENERGY 
  if (!conserve_linear_momentum(data)) return MD_FAIL; 
#endif

  /*
  exit(1);
  */

/*
  if (NULL != data_collector) {
    data_collector_update(data_collector, data->firststepnum);
  }
*/

  /*
   * DJH- output dcd file header
   */
  if (output_dcd_header(data, run_type)) {
    fprintf(stderr, "cannot write dcd file header\n"); /* soft fail */
  }

  /* 
   *start the MD run, velocity Verlet integrator
   */
  for (istep = data->firststepnum + 1; istep <= data->firststepnum + nsteps; 
       istep++) {     
    /* printf("step %d\n", istep); */
    /* fprintf(stderr, "half kick\n"); */
    /*
    printf("ke before 1stkick: %20.15f\n",
    compute_KE(vel, data->atom, natoms)/KCAL_PER_MOL);
    */
    for (k = 0;  k < natoms;  k++) {
      tmp = half_dt / atom[k].m;
      MD_vec_mul_add(vel[k], f[k], tmp);
    }
    /*
    printf("ke after  1st kick: %20.15f\n", 	   
	   compute_KE(vel, data->atom, natoms)/KCAL_PER_MOL);
    */
#ifndef PME_CONSERVE_ENERGY
    if (!conserve_linear_momentum(data)) return MD_FAIL; 
#endif

    /* fprintf(stderr, "drift\n"); */
#ifdef SHAKE
    if (NULL != sw)  shake_prepare(sw, pos);
#else
    /*
    if (NULL != rw) rattle_prepare(rw, pos);
    */
    if (NULL != rw) settle_prepare(rw, pos);
#endif

    for (k=0; k<natoms; k++) MD_vec_mul_add(pos[k], vel[k], dt);

#ifdef SHAKE
    if (NULL != sw) {
      for (k = 0; k < nmols; k++) {
	if (shake(sw, 3*k, pos+3*k, vel+3*k, dt)) {
	  return MD_FAIL;
	}
	/*
	  printf("post-shake, pos[%d]: %f, %f, %f\n", k, 
               pos[k].x, pos[k].y, pos[k].z);
	*/
      } 
    }
#else
    /*
    if (NULL != rw) {
      for (k = 0; k < nmols; k++) {
	if (rattle1(rw, 3*k, pos+3*k, vel+3*k, dt)) {
	  return MD_FAIL;
	}   
      }   
    }
    */
    if (NULL != rw) {
      for (k = 0; k < nmols; k++) {
	if (settle1(rw, 3*k, pos+3*k, vel+3*k, dt)) {
	  return MD_FAIL;
	}   
      }   
    }
#endif

    /*
    printf("ke after rattle1/shake: %20.15f\n", 
	   compute_KE(vel, data->atom, natoms)/KCAL_PER_MOL);
    */

    /*find_nearest_neibr(force);*/

    /* fprintf(stderr, "compute force\n"); */
    if (force_compute(force))  return MD_FAIL;
    /*force_output(force); */

    /*
    printf("ke before 2nd kick: %20.15f\n", 
	   compute_KE(vel, data->atom, natoms)/KCAL_PER_MOL);
    */
    /* fprintf(stderr, "half kick\n"); */
    for (k = 0;  k < natoms;  k++) {
      tmp = half_dt / atom[k].m;
      MD_vec_mul_add(vel[k], f[k], tmp);
    }

    /*
    fprintf(stderr, "before rattle2, ke = %20.15f\n", 
	    compute_KE(vel, atom, natoms)/KCAL_PER_MOL);
    */

#ifndef SHAKE
    /*
    if (NULL != rw) {
      for (k = 0; k < nmols; k++) {
	if (rattle2(rw, pos+3*k, vel+3*k, dt)) {
	  fprintf(stderr, "rattle2 failed for molecule %d\n", k); 
	  return MD_FAIL;
	}
      }
    }
    */
    if (NULL != rw) {
      for (k = 0; k < nmols; k++) {
	if (settle2(rw, pos+3*k, vel+3*k, dt)) {
	  fprintf(stderr, "settle2 failed for molecule %d\n", k); 
	  return MD_FAIL;
	}
      }
    }
#endif

    /*
    fprintf(stderr, "after  rattle2, ke = %20.15f\n", 
	    compute_KE(vel, atom, natoms)/KCAL_PER_MOL);
    */

    compute_KETE(data);

    /*
    {
      MD_Double vol = force->systemsize.x * force->systemsize.y *
	force->systemsize.z;
      printf("pressure = %f\n", 
	     compute_pressure(f, force->wrapped_pos, vel, atom, 
			      vol, natoms));
    }
    */


    /*
     * DJH- check for energy output
     */
    if (0 == istep % outputfreq) {    
      data_print_energy(data, istep); 
      fflush(stdout); 
#ifndef PME_CONSERVE_ENERGY
      if (!conserve_linear_momentum(data)) return MD_FAIL; 
#endif

      /* 
	for (k=0; k<N_mda_av_components; k++) {
	avgvar_dump_data(data_collector->avgvar_components[k], stdout);
	}
      */
      if (NULL != data_collector) data_collector_update(data_collector, istep);
    }

    /* DJH- check for dcd file output
     *
     * data collection output corresponds with DCD frames
     */
    if ((istep - data->firststepnum) % dcdfreq == 0) {
      if (NULL != data_collector) data_collector_output(data_collector);
      if (output_dcd_frame(data, istep)) {
        fprintf(stderr, "cannot write dcd frame\n");  /* soft fail */
      }

#if 0
    /* if ((istep - data->firststepnum) % dump_period == 0) { */
      if (NULL != data_collector) data_collector_output(data_collector);
      if (sprintf(filename, "output/pos_%d.dat", istep) + 1 >
          (MD_Int)sizeof(MD_String)) {
        fprintf(stdout, "filename is too small\n");
        return MD_FAIL;
      }
      if (bindump_vec_array(pos, natoms, filename)) {
	fprintf(stderr, "cannot write file %s, disk full ?? \n", filename);
      }
#endif
#ifdef PME_CONSERVE_ENERGY
      {
	MD_Dvec ptot = {0.0, 0.0, 0.0};
	for(k=0; k<natoms; k++) MD_vec_mul_add(ptot, vel[k], atom[k].m);
	printf("total linear momentum: %g, %g, %g\n", ptot.x, ptot.y, ptot.z);
      }
#endif
    }

    /*
     * DJH- check for restart file output
     */
    if ((istep - data->firststepnum) % restartfreq == 0) {
      if (data_bindump_image(data, istep)) {
	fprintf(stderr, "cannot dump system image \n"); /* soft fail */
      }
    }

    if (0 != errno) {
      perror("something wrong in force_compute");
    }
    /*fflush(NULL);*/
  } /* end of istep */

  /*
   * DJH- done writing dcd file
   */
  if (output_dcd_done(data)) {
    fprintf(stderr, "cannot finish with dcd file writing\n"); /* soft fail */
  }

  data->firststepnum += nsteps;

  if (NULL != data_collector) {
    data_collector_destroy(data_collector);
    free(data_collector);
  }

#ifdef SHAKE
  if (NULL != sw) {
    shake_destroy(sw);
    free(sw);
  }
#else
  /*
  if (NULL != rw) {
    rattle_destroy(rw);
    free(rw);
  }
  */
  if (NULL != rw) {
    settle_destroy(rw);
    free(rw);
  }
#endif

  return OK;
}

