/*
 * Copyright (C) 2004-2005 by David J. Hardy.  All rights reserved.
 *
 * update.c
 */

#include <stdlib.h>
#include <string.h>
#include <strings.h>
#include <math.h>
#include "mdapi/mdengine.h"
#include "force/force.h"
#include "simen/engine.h"
#include "debug/debug.h"


/* prototypes for internal functions */
static int check(Engine *);
static void remove_com_motion(Engine *);
static void random_velocities(Engine *);


/*
 * Bring the engine's internal state up to date with any engine data that
 * the front end has modified since the last update.
 *
 * In order, this routine:
 *   1. checks that the atom, position, and velocity arrays share the same
 *      positive length, and resizes the force array to match;
 *   2. validates system values (see check());
 *   3. if the initial temperature was modified, (re)generates random
 *      velocities and flags the velocities as modified;
 *   4. if the velocities were modified, optionally removes center-of-mass
 *      motion and recomputes the degrees of freedom and the constant used
 *      to reduce kinetic energy to temperature;
 *   5. acknowledges a position modification (nothing else to do);
 *   6. if any force-field parameter, topology, or simulation-parameter
 *      engine data was modified, destroys and re-initializes the force
 *      library with freshly built ForceParam settings.
 *
 * Each handled modification is acknowledged via MD_engdata_ackmod().
 *
 * Returns 0 on success.  On failure returns MD_FAIL, or the value of
 * MD_error() for the relevant engine error code.
 */
int32 simen_update(MD_Front *front)
{
  Engine *eng = MD_engine_data(front);
  Param *param = &(eng->param);
  Force *force = &(eng->force);
  ForceParam *p = &(eng->f_param);
  ForceResult *r = &(eng->f_result);
  ForceEnergy *e = &(eng->result.fe);

  /* make sure number of atoms is consistent */
  eng->natoms = eng->atom->attrib.len;
  if (eng->natoms <= 0
      || eng->natoms != eng->pos->attrib.len
      || eng->natoms != eng->vel->attrib.len) {
    return MD_error(front, eng->err_param);
  }

  /* make sure force array is correct length */
  if (eng->natoms != eng->force_engdata->attrib.len) {

    /* resize force array to correct length */
    if (MD_engdata_setlen(front, eng->force_engdata, eng->natoms)) {
      return MD_FAIL;
    }

    /* make sure force result uses force array as its buffer space
     * (resizing may have moved the buffer) */
    memset(r, 0, sizeof(ForceResult));
    r->f = (MD_Dvec *) (eng->force_engdata->buf);
  }

  /* check system values */
  if (check(eng)) {
    return MD_error(front, eng->err_param);
  }

  /* deal with temperature update */
  if (eng->init_temp_engdata->attrib.access & MD_MODIFY) {

    /* seed random number generator if not already done */
    if ( ! eng->is_prng_seeded ) {
      random_initseed(&(eng->random), eng->seed);
      eng->is_prng_seeded = 1;
    }

    /* set initial velocities using inittemp */
    random_velocities(eng);

    /* acknowledge temperature modification to interface */
    if (MD_engdata_ackmod(front, eng->init_temp_engdata)) {
      return MD_FAIL;
    }

    /* set velocity modification flag to trip next section init */
    eng->vel->attrib.access |= MD_MODIFY;
  }

  /* deal with velocity update */
  if (eng->vel->attrib.access & MD_MODIFY) {

    /* compute number of degrees of freedom for system */
    eng->ndegfreedom = 3 * eng->natoms;

    /* allow center-of-mass motion? */
    if (strcasecmp(eng->com_motion, "no") == 0
        || strcasecmp(eng->com_motion, "off") == 0) {

      /* remove center of mass motion from initial velocities */
      remove_com_motion(eng);

      /* this reduces number of degrees of freedom (by one particle) */
      eng->ndegfreedom -= 3;
    }
    else if (strcasecmp(eng->com_motion, "yes") != 0
        && strcasecmp(eng->com_motion, "on") != 0) {
      /* NOTE(review): an invalid com_motion value is reported as a
       * force-init error rather than a parameter error */
      return MD_error(front, eng->err_force_init);
    }

    /* compute constants needed for reductions */
    if (eng->ndegfreedom > 0) {
      eng->tempkonst = 2.0 / (eng->ndegfreedom * MD_BOLTZMAN);
    }
    else {
      /*
       * A single atom with center-of-mass motion removed has no
       * remaining degrees of freedom.  Avoid dividing by zero, which
       * would yield an infinite constant and a NaN temperature.
       */
      eng->tempkonst = 0.0;
    }

    /* acknowledge velocity modification to interface */
    if (MD_engdata_ackmod(front, eng->vel)) {
      return MD_FAIL;
    }
  }

  /* deal with position update */
  if (eng->pos->attrib.access & MD_MODIFY) {

    /*** nothing to do ***/

    /* acknowledge position modification to interface */
    if (MD_engdata_ackmod(front, eng->pos)) {
      return MD_FAIL;
    }
  }

  /* see if we need to (re-)init force library */
  if ((eng->atomprm->attrib.access & MD_MODIFY)
      || (eng->bondprm->attrib.access & MD_MODIFY)
      || (eng->angleprm->attrib.access & MD_MODIFY)
      || (eng->dihedprm->attrib.access & MD_MODIFY)
      || (eng->imprprm->attrib.access & MD_MODIFY)
      || (eng->nbfixprm->attrib.access & MD_MODIFY)
      || (eng->atom->attrib.access & MD_MODIFY)
      || (eng->bond->attrib.access & MD_MODIFY)
      || (eng->angle->attrib.access & MD_MODIFY)
      || (eng->dihed->attrib.access & MD_MODIFY)
      || (eng->impr->attrib.access & MD_MODIFY)
      || (eng->excl->attrib.access & MD_MODIFY)
      || (eng->param_engdata->attrib.access & MD_MODIFY)) {

    /* must first destroy previous force */
    force_done(force);
    if (force_init(force)) {
      return MD_error(front, eng->err_force_init);
    }

    /* zero force param */
    memset(p, 0, sizeof(ForceParam));

    /* set flags */
    p->flags = FORCE_ALL;

    /* setup force params */
    p->atomprm = (MD_AtomPrm *) (eng->atomprm->buf);
    p->bondprm = (MD_BondPrm *) (eng->bondprm->buf);
    p->angleprm = (MD_AnglePrm *) (eng->angleprm->buf);
    p->dihedprm = (MD_TorsPrm *) (eng->dihedprm->buf);
    p->imprprm = (MD_TorsPrm *) (eng->imprprm->buf);
    p->nbfixprm = (MD_NbfixPrm *) (eng->nbfixprm->buf);
    p->atomprm_len = eng->atomprm->attrib.len;
    p->bondprm_len = eng->bondprm->attrib.len;
    p->angleprm_len = eng->angleprm->attrib.len;
    p->dihedprm_len = eng->dihedprm->attrib.len;
    p->imprprm_len = eng->imprprm->attrib.len;
    p->nbfixprm_len = eng->nbfixprm->attrib.len;

    p->atom = (MD_Atom *) (eng->atom->buf);
    p->bond = (MD_Bond *) (eng->bond->buf);
    p->angle = (MD_Angle *) (eng->angle->buf);
    p->dihed = (MD_Tors *) (eng->dihed->buf);
    p->impr = (MD_Tors *) (eng->impr->buf);
    p->excl = (MD_Excl *) (eng->excl->buf);
    p->atom_len = eng->atom->attrib.len;
    p->bond_len = eng->bond->attrib.len;
    p->angle_len = eng->angle->attrib.len;
    p->dihed_len = eng->dihed->attrib.len;
    p->impr_len = eng->impr->attrib.len;
    p->excl_len = eng->excl->attrib.len;

    /* can't handle non-orthogonal cells */
    if (param->cellBasisVector1.y != 0.0
        || param->cellBasisVector1.z != 0.0
        || param->cellBasisVector2.x != 0.0
        || param->cellBasisVector2.z != 0.0
        || param->cellBasisVector3.x != 0.0
        || param->cellBasisVector3.y != 0.0) {
      return MD_error(front, eng->err_force_init);
    }
    /* a nonzero diagonal basis entry enables periodicity along that axis */
    if (param->cellBasisVector1.x != 0.0) {
      p->flags |= FORCE_X_PERIODIC;
      p->xlen = param->cellBasisVector1.x;
      p->center = param->cellOrigin;
    }
    if (param->cellBasisVector2.y != 0.0) {
      p->flags |= FORCE_Y_PERIODIC;
      p->ylen = param->cellBasisVector2.y;
      p->center = param->cellOrigin;
    }
    if (param->cellBasisVector3.z != 0.0) {
      p->flags |= FORCE_Z_PERIODIC;
      p->zlen = param->cellBasisVector3.z;
      p->center = param->cellOrigin;
    }

    /* setup spherical boundary conditions */
    if (strcasecmp(param->sphericalBC, "on") == 0
        || strcasecmp(param->sphericalBC, "yes") == 0) {
      /* turn periodicity off */
      p->flags &= ~FORCE_PERIODIC;
      /* turn spherical boundary conditions on */
      p->flags |= FORCE_SPHERE;
      /*
       * in this case, xlen, ylen, zlen are still used by grid cell
       * algorithm to determine spatial region of molecule
       * (make sure they correspond to extent of sphere)
       */
      p->center = param->sphericalBCCenter;
      p->radius1 = param->sphericalBCr1;
      p->radius2 = param->sphericalBCr2;
      p->konst1 = param->sphericalBCk1;
      p->konst2 = param->sphericalBCk2;
      p->exp1 = param->sphericalBCexp1;
      p->exp2 = param->sphericalBCexp2;
    }
    else if (strcasecmp(param->sphericalBC, "off") != 0
        && strcasecmp(param->sphericalBC, "no") != 0) {
      return MD_error(front, eng->err_force_init);
    }

    /* setup cylindrical boundary conditions */
    if (strcasecmp(param->cylindricalBC, "on") == 0
        || strcasecmp(param->cylindricalBC, "yes") == 0) {
      /* make sure spherical boundaries were not also set */
      if (p->flags & FORCE_SPHERE) {
        return MD_error(front, eng->err_force_init);
      }
      /* turn periodicity off */
      p->flags &= ~FORCE_PERIODIC;
      /* flags depend on axis value */
      if (strcasecmp(param->cylindricalBCAxis, "x") == 0) {
        p->flags |= FORCE_X_CYLINDER;
      }
      else if (strcasecmp(param->cylindricalBCAxis, "y") == 0) {
        p->flags |= FORCE_Y_CYLINDER;
      }
      else if (strcasecmp(param->cylindricalBCAxis, "z") == 0) {
        p->flags |= FORCE_Z_CYLINDER;
      }
      else {
        return MD_error(front, eng->err_force_init);
      }
      /*
       * in this case, xlen, ylen, zlen are still used by grid cell
       * algorithm to determine spatial region of molecule
       * (make sure they correspond to extent of cylinder)
       */
      p->center = param->cylindricalBCCenter;
      p->radius1 = param->cylindricalBCr1;
      p->radius2 = param->cylindricalBCr2;
      p->length1 = param->cylindricalBCl1;
      p->length2 = param->cylindricalBCl2;
      p->konst1 = param->cylindricalBCk1;
      p->konst2 = param->cylindricalBCk2;
      p->exp1 = param->cylindricalBCexp1;
      p->exp2 = param->cylindricalBCexp2;
    }
    else if (strcasecmp(param->cylindricalBC, "off") != 0
        && strcasecmp(param->cylindricalBC, "no") != 0) {
      return MD_error(front, eng->err_force_init);
    }

    /* full direct electrostatics */
    if (strcasecmp(param->fullDirect, "on") == 0
        || strcasecmp(param->fullDirect, "yes") == 0) {
      p->flags |= FORCE_ELEC_DIRECT;
    }
    else if (strcasecmp(param->fullDirect, "off") != 0
        && strcasecmp(param->fullDirect, "no") != 0) {
      return MD_error(front, eng->err_force_init);
    }

    /* set exclusion policy */
    if (strcasecmp(param->exclude, "none") == 0) {
      p->flags |= FORCE_EXCL_NONE;
    }
    else if (strcasecmp(param->exclude, "1-2") == 0) {
      p->flags |= FORCE_EXCL_12;
    }
    else if (strcasecmp(param->exclude, "1-3") == 0) {
      p->flags |= FORCE_EXCL_13;
    }
    else if (strcasecmp(param->exclude, "1-4") == 0) {
      p->flags |= FORCE_EXCL_14;
    }
    else if (strcasecmp(param->exclude, "scaled1-4") == 0) {
      p->flags |= FORCE_EXCL_SCAL14;
    }
    else {
      return MD_error(front, eng->err_force_init);
    }

    /* switching and smoothing */
    if (strcasecmp(param->switching, "on") == 0
        || strcasecmp(param->switching, "yes") == 0) {
      if (p->flags & FORCE_ELEC_DIRECT) {
        p->flags |= FORCE_SWITCH;
      }
      else {
        p->flags |= FORCE_CONTINUOUS;  /* i.e. SWITCH + SMOOTH */
      }
      p->switchdist = param->switchDist;
    }
    else if (strcasecmp(param->switching, "off") != 0
        && strcasecmp(param->switching, "no") != 0) {
      return MD_error(front, eng->err_force_init);
    }

    /* nonbonded params */
    p->cutoff = param->cutoff;
    p->elec_const = MD_COULOMB;
    p->dielectric = param->dielectric;
    p->scaling14 = param->scaling;  /* value of "1-4scaling" */

    /* setup force library */
    if (force_setup(force, p, e, r)) {
      return MD_error(front, eng->err_force_init);
    }

    /*
     * acknowledge engdata modification to interface
     * (short-circuit: only modified items are acknowledged, and the
     * first failure stops the chain)
     *
     * NOTE(review): other ackmod failures in this routine return MD_FAIL;
     * here the failure is reported as err_force_init -- confirm intended
     */
    if (((eng->atomprm->attrib.access & MD_MODIFY)
          && MD_engdata_ackmod(front, eng->atomprm))
        || ((eng->bondprm->attrib.access & MD_MODIFY)
          && MD_engdata_ackmod(front, eng->bondprm))
        || ((eng->angleprm->attrib.access & MD_MODIFY)
          && MD_engdata_ackmod(front, eng->angleprm))
        || ((eng->dihedprm->attrib.access & MD_MODIFY)
          && MD_engdata_ackmod(front, eng->dihedprm))
        || ((eng->imprprm->attrib.access & MD_MODIFY)
          && MD_engdata_ackmod(front, eng->imprprm))
        || ((eng->nbfixprm->attrib.access & MD_MODIFY)
          && MD_engdata_ackmod(front, eng->nbfixprm))
        || ((eng->atom->attrib.access & MD_MODIFY)
          && MD_engdata_ackmod(front, eng->atom))
        || ((eng->bond->attrib.access & MD_MODIFY)
          && MD_engdata_ackmod(front, eng->bond))
        || ((eng->angle->attrib.access & MD_MODIFY)
          && MD_engdata_ackmod(front, eng->angle))
        || ((eng->dihed->attrib.access & MD_MODIFY)
          && MD_engdata_ackmod(front, eng->dihed))
        || ((eng->impr->attrib.access & MD_MODIFY)
          && MD_engdata_ackmod(front, eng->impr))
        || ((eng->excl->attrib.access & MD_MODIFY)
          && MD_engdata_ackmod(front, eng->excl))
        || ((eng->param_engdata->attrib.access & MD_MODIFY)
          && MD_engdata_ackmod(front, eng->param_engdata))) {
      return MD_error(front, eng->err_force_init);
    }
  } /* done with force library init */

  return 0;
}


/*
 * Validate system values before they are used by the update routines.
 *
 * Fails if the initial temperature is negative or NaN, or if any atom
 * mass is non-positive or NaN.  The comparisons are written in negated
 * form on purpose: every ordered comparison involving NaN is false, so
 * "!(x > 0)" rejects NaN whereas "x <= 0" would silently accept it.
 * These tests mirror the ASSERTs in random_velocities().
 *
 * Returns 0 if all values are acceptable, MD_FAIL otherwise.
 */
static int check(Engine *eng)
{
  const MD_Atom *atom = (MD_Atom *) (eng->atom->buf);
  const int32 n = eng->natoms;
  int32 k;

  if ( ! (eng->init_temp >= 0) ) return MD_FAIL;
  for (k = 0;  k < n;  k++) {
    if ( ! (atom[k].m > 0) ) return MD_FAIL;
  }
  return 0;
}


/* remove center of mass motion from initial velocities */
/*
 * Zero the system's net linear momentum by subtracting the
 * center-of-mass velocity, sum(m_k * v_k) / sum(m_k), from every atom
 * velocity.
 *
 * The total mass is divided into directly, so it must be nonzero; in
 * simen_update() this routine runs only after check() has rejected
 * non-positive masses.
 */
static void remove_com_motion(Engine *eng)
{
  const MD_Atom *atom = (MD_Atom *) (eng->atom->buf);
  MD_Dvec *vel = (MD_Dvec *) (eng->vel->buf);
  const int32 natoms = eng->natoms;
  MD_Dvec momentum = { 0.0, 0.0, 0.0 };  /* running sum of m * v */
  MD_Dvec vcom;                          /* center-of-mass velocity */
  double total_mass = 0.0;
  MD_Dvec *v;
  int32 i;

  /* accumulate momentum and mass, atom by atom in index order */
  for (i = 0;  i < natoms;  i++) {
    const double m = atom[i].m;
    momentum.x += m * vel[i].x;
    momentum.y += m * vel[i].y;
    momentum.z += m * vel[i].z;
    total_mass += m;
  }

  vcom.x = momentum.x / total_mass;
  vcom.y = momentum.y / total_mass;
  vcom.z = momentum.z / total_mass;

  /* shift every velocity into the center-of-mass frame */
  for (v = vel;  v < vel + natoms;  v++) {
    v->x -= vcom.x;
    v->y -= vcom.y;
    v->z -= vcom.z;
  }
}


void random_velocities(Engine *eng)
{
  double kbtemp = MD_BOLTZMAN * MD_ENERGY_CONST * eng->init_temp;
  double sqrt_kbtemp_div_mass;
#ifndef SIMEN_ALT_INITTEMP
  double rnum;
#endif
  const MD_Atom *atom = (MD_Atom *) (eng->atom->buf);
  MD_Dvec *vel = (MD_Dvec *) (eng->vel->buf);
  Random *r = &(eng->random);
  const int32 natoms = eng->natoms;
  int32 n, k;

  ASSERT(eng->init_temp >= 0.0);

  for (n = 0;  n < natoms;  n++) {
    ASSERT(atom[n].m > 0.0);
    sqrt_kbtemp_div_mass = sqrt(kbtemp / atom[n].m);

#ifndef SIMEN_ALT_INITTEMP
    /*
     * The following method and comments taken from NAMD WorkDistrib.C:
     *
     * //  The following comment was stolen from X-PLOR where
     * //  the following section of code was adapted from.
     *
     * //  This section generates a Gaussian random
     * //  deviate of 0.0 mean and standard deviation RFD for
     * //  each of the three spatial dimensions.
     * //  The algorithm is a "sum of uniform deviates algorithm"
     * //  which may be found in Abramowitz and Stegun,
     * //  "Handbook of Mathematical Functions", pg 952.
     */
    rnum = -6.0;
    for (k = 0;  k < 12;  k++) {
      rnum += random_uniform(r);
    }
    vel[n].x = sqrt_kbtemp_div_mass * rnum;

    rnum = -6.0;
    for (k = 0;  k < 12;  k++) {
      rnum += random_uniform(r);
    }
    vel[n].y = sqrt_kbtemp_div_mass * rnum;

    rnum = -6.0;
    for (k = 0;  k < 12;  k++) {
      rnum += random_uniform(r);
    }
    vel[n].z = sqrt_kbtemp_div_mass * rnum;
#else
    /*
     * Alternate method from NAMD Sequencer.C:
     */
    vel[n].x = sqrt_kbtemp_div_mass * random_gaussian(r);
    vel[n].y = sqrt_kbtemp_div_mass * random_gaussian(r);
    vel[n].z = sqrt_kbtemp_div_mass * random_gaussian(r);
#endif
  }
}
