/*
 * Copyright (C) 2007 by David J. Hardy.  All rights reserved.
 *
 * step.c
 */

#include <stdlib.h>
#include <string.h>
#include "step/intdef.h"
#define DEBUG_WATCH
#undef DEBUG_WATCH
#include "debug/debug.h"

#define DEBUG_VIRIAL
#undef DEBUG_VIRIAL


/*
 * Initialize the Step object "s" for the integration method selected by
 * param->method.
 *
 * Validates the parameters, zeroes *s, allocates the per-atom work arrays
 * (scaled inverse masses and half-step velocities), seeds the random number
 * generator, installs the method's compute/done callbacks, and finally runs
 * the method-specific init routine.
 *
 * Returns the value produced by the method-specific init routine, or
 * STEP_FAIL when s or param is NULL, a parameter is invalid, an allocation
 * fails, or the method is unknown.
 *
 * On the early failure paths (bad parameter, allocation failure, unknown
 * method) no memory remains allocated and s->done is left NULL.  When the
 * method-specific init itself reports failure, the two generic arrays and
 * the callbacks stay in place.
 * NOTE(review): whether s->done() may be invoked after a failed
 * method-specific init depends on that method's code -- confirm.
 */
int step_init(Step *s, StepParam *param)
{
  int32 i, retval;

  if (s == NULL || param == NULL) return STEP_FAIL;

  /* check parameters */
  if (param->natoms <= 0) return STEP_FAIL;
  if (param->ndegfreedom <= 0) return STEP_FAIL;
  if (param->force_compute == NULL) return STEP_FAIL;
  if (param->atom == NULL) return STEP_FAIL;
  for (i = 0;  i < param->natoms;  i++) {
    /* masses are used as divisors below, so they must be positive */
    if (param->atom[i].m <= 0.0) return STEP_FAIL;
  }

  /* set internals, allocate memory */
  memset(s, 0, sizeof(Step));
  s->tempkonst = 2.0 / (param->ndegfreedom * MD_BOLTZMAN);
  s->param = param;
  s->scal_inv_mass = malloc((size_t) param->natoms * sizeof *s->scal_inv_mass);
  if (s->scal_inv_mass == NULL) return STEP_FAIL;
  for (i = 0;  i < param->natoms;  i++) {
    s->scal_inv_mass[i] = MD_FORCE_CONST / param->atom[i].m;
  }
  s->half_vel = calloc((size_t) param->natoms, sizeof *s->half_vel);
  if (s->half_vel == NULL) {
    /* do not leak the first array when the second allocation fails */
    free(s->scal_inv_mass);
    s->scal_inv_mass = NULL;
    return STEP_FAIL;
  }
  random_initseed(&(s->random), param->random_seed);

  /* install callbacks and run the method-specific initialization */
  switch (param->method) {
    case STEP_VERLET:
      s->compute = step_compute_verlet;
      s->done = step_done_verlet;
      retval = step_init_verlet(s, param);
      break;
    case STEP_SHADOW:
      s->compute = step_compute_shadow;
      s->done = step_done_shadow;
      retval = step_init_shadow(s, param);
      break;
    case STEP_TEMPBATH:
      s->compute = step_compute_tempbath;
      s->done = step_done_tempbath;
      retval = step_init_tempbath(s, param);
      break;
    case STEP_NHEXP:
      s->compute = step_compute_nosehoover_explicit;
      s->done = step_done_nosehoover_explicit;
      retval = step_init_nosehoover_explicit(s, param);
      break;
    case STEP_CGMIN:
      s->compute = step_compute_cgmin;
      s->done = step_done_cgmin;
      retval = step_init_cgmin(s, param);
      break;
    case STEP_DRUDE:
      s->compute = step_compute_drude_thermal;
      s->done = step_done_drude_thermal;
      retval = step_init_drude_thermal(s, param);
      break;
    case STEP_DRUTEST:
      s->compute = step_compute_drude_test;
      s->done = step_done_drude_test;
      retval = step_init_drude_test(s, param);
      break;
    default:
      /* unknown method: no callbacks were installed, so release here */
      free(s->scal_inv_mass);
      s->scal_inv_mass = NULL;
      free(s->half_vel);
      s->half_vel = NULL;
      retval = STEP_FAIL;
      break;
  }

  /* propagate the outcome instead of unconditionally reporting success */
  return retval;
}


/*
 * Release everything owned by the Step object "s".
 *
 * Runs the method-specific cleanup callback when one is installed, then
 * frees the generic per-atom arrays.  step_init() zeroes *s and can fail
 * before it assigns s->done, so the callback is only invoked when it is
 * non-NULL.  The callback and the freed pointers are cleared afterwards so
 * that a repeated call is harmless.
 */
void step_done(Step *s)
{
  if (s == NULL) return;

  if (s->done != NULL) {
    s->done(s);  /* cleanup particular integration method */
    s->done = NULL;
  }

  free(s->scal_inv_mass);
  s->scal_inv_mass = NULL;
  free(s->half_vel);
  s->half_vel = NULL;
}


/*
 * Advance the system by delegating to the compute callback stored in "s",
 * passing along "sys" and "numsteps" unchanged.  The callback's return
 * value is handed back to the caller as is.
 */
int step_compute(Step *s, StepSystem *sys, int32 numsteps)
{
  const int status = s->compute(s, sys, numsteps);

  return status;
}


/*
 * Compute the system-wide reductions and store them in "sys":
 *
 *   kinetic_energy   - from the velocities in sys->vel
 *   temperature      - s->tempkonst times the kinetic energy
 *   linear_momentum  - sum of mass times velocity
 *   kinetic_virial   - full 3x3 symmetric tensor
 *
 * Unless the method is STEP_DRUDE, the kinetic virial averages the
 * contribution of the stored half-step velocities (s->half_vel) with that
 * of half-step velocities extrapolated forward from sys->vel and
 * sys->force, like NAMD does.  For STEP_DRUDE the accumulators are never
 * touched, so the tensor is stored as all zeros.
 *
 * Always returns 0.
 */
int step_find_reductions(Step *s, StepSystem *sys)
{
  /* slots for the six independent entries of a symmetric 3x3 tensor */
  enum { SYM_XX, SYM_XY, SYM_XZ, SYM_YY, SYM_YZ, SYM_ZZ, SYM_LEN };

#ifdef DEBUG_VIRIAL
  double v[9] = { 0.0 };
#endif

  const MD_Atom *atom = s->param->atom;
  const MD_Dvec *vel = sys->vel;
  const MD_Dvec *half_vel = s->half_vel;
  const MD_Dvec *f = sys->force;
  const double *scal_inv_mass = s->scal_inv_mass;
  const double half_dt = 0.5 * s->param->timestep;
  const double half_kcal = 0.5 * MD_KCAL_MOL;
  const int32 natoms = s->param->natoms;
  const int want_virial = (s->param->method != STEP_DRUDE);
  double back[SYM_LEN] = { 0.0 };  /* from stored (backward) half velocities */
  double fwd[SYM_LEN] = { 0.0 };   /* from extrapolated (forward) ones */
  double twice_ke = 0.0;           /* sum of m * |v|^2 */
  MD_Dvec momentum = { 0.0, 0.0, 0.0 };
  int32 n;

  for (n = 0;  n < natoms;  n++) {
    const double m = atom[n].m;
    const MD_Dvec vn = vel[n];

    /* kinetic energy */
    twice_ke += (vn.x * vn.x + vn.y * vn.y + vn.z * vn.z) * m;

    /* linear momentum */
    momentum.x += m * vn.x;
    momentum.y += m * vn.y;
    momentum.z += m * vn.z;

    if (want_virial) {
      const MD_Dvec hb = half_vel[n];
      const double konst = half_dt * scal_inv_mass[n];
      MD_Dvec hf;

      /* push the current velocity half a step ahead using the force */
      hf.x = vn.x + konst * f[n].x;
      hf.y = vn.y + konst * f[n].y;
      hf.z = vn.z + konst * f[n].z;

      back[SYM_XX] += m * hb.x * hb.x;
      back[SYM_XY] += m * hb.x * hb.y;
      back[SYM_XZ] += m * hb.x * hb.z;
      back[SYM_YY] += m * hb.y * hb.y;
      back[SYM_YZ] += m * hb.y * hb.z;
      back[SYM_ZZ] += m * hb.z * hb.z;

      fwd[SYM_XX] += m * hf.x * hf.x;
      fwd[SYM_XY] += m * hf.x * hf.y;
      fwd[SYM_XZ] += m * hf.x * hf.z;
      fwd[SYM_YY] += m * hf.y * hf.y;
      fwd[SYM_YZ] += m * hf.y * hf.z;
      fwd[SYM_ZZ] += m * hf.z * hf.z;
    }
  }

  sys->kinetic_energy = half_kcal * twice_ke;
  sys->temperature = s->tempkonst * sys->kinetic_energy;

  sys->linear_momentum = momentum;

  /* expand the six accumulated entries into the full symmetric tensor */
  sys->kinetic_virial[VIRIAL_XX] = half_kcal * (back[SYM_XX] + fwd[SYM_XX]);
  sys->kinetic_virial[VIRIAL_YY] = half_kcal * (back[SYM_YY] + fwd[SYM_YY]);
  sys->kinetic_virial[VIRIAL_ZZ] = half_kcal * (back[SYM_ZZ] + fwd[SYM_ZZ]);
  sys->kinetic_virial[VIRIAL_XY] = half_kcal * (back[SYM_XY] + fwd[SYM_XY]);
  sys->kinetic_virial[VIRIAL_YX] = half_kcal * (back[SYM_XY] + fwd[SYM_XY]);
  sys->kinetic_virial[VIRIAL_XZ] = half_kcal * (back[SYM_XZ] + fwd[SYM_XZ]);
  sys->kinetic_virial[VIRIAL_ZX] = half_kcal * (back[SYM_XZ] + fwd[SYM_XZ]);
  sys->kinetic_virial[VIRIAL_YZ] = half_kcal * (back[SYM_YZ] + fwd[SYM_YZ]);
  sys->kinetic_virial[VIRIAL_ZY] = half_kcal * (back[SYM_YZ] + fwd[SYM_YZ]);

#ifdef DEBUG_VIRIAL
  /* print each half of the average separately, row by row */
  v[0] = half_kcal * back[SYM_XX];
  v[1] = half_kcal * back[SYM_XY];
  v[2] = half_kcal * back[SYM_XZ];
  v[3] = half_kcal * back[SYM_XY];
  v[4] = half_kcal * back[SYM_YY];
  v[5] = half_kcal * back[SYM_YZ];
  v[6] = half_kcal * back[SYM_XZ];
  v[7] = half_kcal * back[SYM_YZ];
  v[8] = half_kcal * back[SYM_ZZ];
  printf("half-kinetic virial: %g %g %g %g %g %g %g %g %g\n",
      v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7], v[8]);
  v[0] = half_kcal * fwd[SYM_XX];
  v[1] = half_kcal * fwd[SYM_XY];
  v[2] = half_kcal * fwd[SYM_XZ];
  v[3] = half_kcal * fwd[SYM_XY];
  v[4] = half_kcal * fwd[SYM_YY];
  v[5] = half_kcal * fwd[SYM_YZ];
  v[6] = half_kcal * fwd[SYM_XZ];
  v[7] = half_kcal * fwd[SYM_YZ];
  v[8] = half_kcal * fwd[SYM_ZZ];
  printf("half-kinetic virial: %g %g %g %g %g %g %g %g %g\n",
      v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7], v[8]);
#endif

  return 0;
}


/*
 * Store in *e_sum the mass-weighted sum of squared speeds,
 *   \sum_i m_i v_i^T v_i,
 * over the first "natoms" entries of "v" and "atom".  This is twice the
 * kinetic energy with no unit conversion applied; the result is 0.0 when
 * natoms is not positive.
 */
void step_find_reductions_kenergy(double *e_sum,
    const MD_Dvec *v, const MD_Atom *atom, int32 natoms)
{
  double total = 0.0;
  int32 n;

  for (n = 0;  n < natoms;  n++) {
    const MD_Dvec *vn = v + n;

    total += (vn->x * vn->x + vn->y * vn->y + vn->z * vn->z) * atom[n].m;
  }
  *e_sum = total;
}


/*
 * Store in *linmo the linear momentum of the system,
 *   \sum_i m_i v_i,
 * summed over the first "natoms" entries of "v" and "atom".  All three
 * components are zero when natoms is not positive.
 */
void step_find_reductions_linmo(MD_Dvec *linmo,
    const MD_Dvec *v, const MD_Atom *atom, int32 natoms)
{
  MD_Dvec p = { 0.0, 0.0, 0.0 };
  int32 n;

  for (n = 0;  n < natoms;  n++) {
    const MD_Dvec *vn = v + n;

    p.x += atom[n].m * vn->x;
    p.y += atom[n].m * vn->y;
    p.z += atom[n].m * vn->z;
  }

  linmo->x = p.x;
  linmo->y = p.y;
  linmo->z = p.z;
}


/*
 * Store in *angmo the angular momentum of the system,
 *   \sum_i r_i \cross m_i v_i,
 * summed over the first "natoms" entries of "r", "v" and "atom".  All
 * three components are zero when natoms is not positive.
 */
void step_find_reductions_angmo(MD_Dvec *angmo,
    const MD_Dvec *r, const MD_Dvec *v, const MD_Atom *atom, int32 natoms)
{
  MD_Dvec l = { 0.0, 0.0, 0.0 };
  int32 n;

  for (n = 0;  n < natoms;  n++) {
    const MD_Dvec *rn = r + n;
    const MD_Dvec *vn = v + n;

    /* components of the cross product r x v, weighted by the mass */
    l.x += atom[n].m * (rn->y*vn->z - rn->z*vn->y);
    l.y += atom[n].m * (rn->z*vn->x - rn->x*vn->z);
    l.z += atom[n].m * (rn->x*vn->y - rn->y*vn->x);
  }

  angmo->x = l.x;
  angmo->y = l.y;
  angmo->z = l.z;
}


/*
 * Store in u_kv the six upper-triangular entries of
 *   \sum_i m_i v_i v_i^T
 * (kinetic contribution to the virial, no unit conversion), summed over
 * the first "natoms" entries of "v" and "atom".  Entries are addressed by
 * the VIRIAL_UPPER_* indices and are zero when natoms is not positive.
 */
void step_find_reductions_kvirial(double u_kv[VIRIAL_UPPER_LEN],
    const MD_Dvec *v, const MD_Atom *atom, int32 natoms)
{
  double xx = 0.0, xy = 0.0, xz = 0.0;
  double yy = 0.0, yz = 0.0;
  double zz = 0.0;
  int32 n;

  for (n = 0;  n < natoms;  n++) {
    const MD_Dvec *vn = v + n;

    xx += atom[n].m * vn->x * vn->x;
    xy += atom[n].m * vn->x * vn->y;
    xz += atom[n].m * vn->x * vn->z;
    yy += atom[n].m * vn->y * vn->y;
    yz += atom[n].m * vn->y * vn->z;
    zz += atom[n].m * vn->z * vn->z;
  }

  u_kv[VIRIAL_UPPER_XX] = xx;
  u_kv[VIRIAL_UPPER_XY] = xy;
  u_kv[VIRIAL_UPPER_XZ] = xz;
  u_kv[VIRIAL_UPPER_YY] = yy;
  u_kv[VIRIAL_UPPER_YZ] = yz;
  u_kv[VIRIAL_UPPER_ZZ] = zz;
}


int step_set_random_velocities(Step *s, StepSystem *sys, double init_temp)
{
  const double kbtemp = MD_BOLTZMAN * MD_ENERGY_CONST * init_temp;
  double sqrt_kbtemp_div_mass;
#ifndef STEP_ALT_INITTEMP
  double rnum;
#endif
  const MD_Atom *atom = s->param->atom;
  MD_Dvec *vel = sys->vel;
  Random *r = &(s->random);
  const int32 natoms = s->param->natoms;
  int32 n, k;

  /* make sure initial temperature is valid */
  if (init_temp < 0.0) return STEP_FAIL;

  for (n = 0;  n < natoms;  n++) {
    sqrt_kbtemp_div_mass = sqrt(kbtemp / atom[n].m);

#ifndef STEP_ALT_INITTEMP
    /*
     * The following method and comments taken from NAMD WorkDistrib.C:
     *
     * //  The following comment was stolen from X-PLOR where
     * //  the following section of code was adapted from.
     *
     * //  This section generates a Gaussian random
     * //  deviate of 0.0 mean and standard deviation RFD for
     * //  each of the three spatial dimensions.
     * //  The algorithm is a "sum of uniform deviates algorithm"
     * //  which may be found in Abramowitz and Stegun,
     * //  "Handbook of Mathematical Functions", pg 952.
     */
    rnum = -6.0;
    for (k = 0;  k < 12;  k++) {
      rnum += random_uniform(r);
    }
    vel[n].x = sqrt_kbtemp_div_mass * rnum;

    rnum = -6.0;
    for (k = 0;  k < 12;  k++) {
      rnum += random_uniform(r);
    }
    vel[n].y = sqrt_kbtemp_div_mass * rnum;

    rnum = -6.0;
    for (k = 0;  k < 12;  k++) {
      rnum += random_uniform(r);
    }
    vel[n].z = sqrt_kbtemp_div_mass * rnum;
#else
    /*
     * Alternate method from NAMD Sequencer.C:
     */
    vel[n].x = sqrt_kbtemp_div_mass * random_gaussian(r);
    vel[n].y = sqrt_kbtemp_div_mass * random_gaussian(r);
    vel[n].z = sqrt_kbtemp_div_mass * random_gaussian(r);
#endif
  }
  return 0;
}


/*
 * Remove the center-of-mass motion from the velocities in sys->vel.
 *
 * The net momentum of all atoms is divided by the total mass, and the
 * resulting velocity is subtracted from every atom's velocity.
 *
 * Always returns 0.
 */
int step_remove_com_motion(Step *s, StepSystem *sys)
{
  const MD_Atom *atom = s->param->atom;
  const int32 natoms = s->param->natoms;
  MD_Dvec *vel = sys->vel;
  MD_Dvec drift = { 0.0, 0.0, 0.0 };  /* net momentum, then COM velocity */
  double total_mass = 0.0;
  int32 n;

  /* first pass: accumulate net momentum and total mass */
  for (n = 0;  n < natoms;  n++) {
    const MD_Dvec *vn = vel + n;

    drift.x += atom[n].m * vn->x;
    drift.y += atom[n].m * vn->y;
    drift.z += atom[n].m * vn->z;
    total_mass += atom[n].m;
  }

  /* momentum divided by mass gives the velocity to subtract */
  drift.x /= total_mass;
  drift.y /= total_mass;
  drift.z /= total_mass;

  /* second pass: shift every atom by that velocity */
  for (n = 0;  n < natoms;  n++) {
    MD_Dvec *vn = vel + n;

    vn->x -= drift.x;
    vn->y -= drift.y;
    vn->z -= drift.z;
  }
  return 0;
}
