/*
 * Copyright (C) 2004-2005 by David J. Hardy.  All rights reserved.
 *
 * force.c
 */

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include "mdapi/mdengine.h"
#include "force/force.h"
#include "mgrid/mgrid.h"
#include "deven/engine.h"
#undef DEBUG_WATCH
#include "debug/debug.h"

#undef TEST_FSELECT


/* file-local helpers used by deven_force(); defined at the end of this file */
static void fix_linmo(Engine *e, MD_Dvec f[]);
static void conserve_linmo(Engine *e, MD_Dvec f[]);
static void wt_com_disp(Engine *e);


/*
 * deven_force() -- evaluate the force on every atom and the total
 * potential energy for the positions pos; called from step_compute()
 *
 * input:
 *   void *vfront -- object handed in by the step method
 *     (in this case it's actually MD_Front *; the Engine is recovered
 *     from it with MD_engine_data())
 *   const MD_Dvec *pos -- array of atomic positions (eng->natoms entries)
 *
 * output:
 *   double *pe -- total scalar potential energy; it is written only after
 *     all evaluation steps succeed, so the early error returns leave it
 *     untouched
 *   MD_Dvec *f -- array of atomic forces
 *
 * returns:
 *   0 on success, otherwise the value of MD_error(front, eng->err_force)
 *   when force_compute(), the mgrid domain check, mgrid_force(), or
 *   pmetest_compute() fails
 *
 * evaluation order (later stages see the results of earlier ones):
 *   1. force_compute() evaluates all force library terms
 *   2. optional long-range electrostatics: mgrid (MSM) when eng->ismgrid,
 *      else pmetest (PME) when eng->ispmetest; the extra forces are read
 *      from the f_elec buffer and added into f
 *   3. forces on fixed atoms are reset to zero
 *   4. position restraints add energy to u_total and force into f
 *   5. optional linear momentum correction selected by eng->forceopts
 *
 * side effects:
 *   internals of Force object
 *   force arrays, potentials, and virial in ForceResult are updated,
 *     some things implicitly through force_compute(), others explicitly
 *   eng->poswrap and eng->f_elec buffers are overwritten when mgrid or
 *     pmetest is active
 *   eng->result.fcorr and eng->result.wcomd may be set (step 5)
 *
 * post conditions:
 *   pe will be identical to ForceResult::u_total
 *   all energies, forces, and virial in ForceResult object will "agree"
 *     (except that neither mgrid nor the restraints add to the virial)
 *   pe will be sum of other ForceResult energies, plus whatever other
 *     additions made here to the force evaluation (restraint energy is
 *     added only to u_total)
 */
int32 deven_force(void *vfront, double *pe, MD_Dvec *f, const MD_Dvec *pos)
{
  MD_Front *front = (MD_Front *) vfront;
  Engine *eng = (Engine *) MD_engine_data(front);
  Force *force = eng->force;
  ForceResult *f_res = &(eng->f_result);

/*** data for mgrid and pmetest ***/
  /*
   * poswrap and f_elec are scratch buffers owned by the engine; per the
   * comments at the compute calls below, mgrid_system / pt_system read
   * positions from poswrap and write electrostatic forces into f_elec
   */
  Mgrid *mgrid = &(eng->mgrid);
  MgridSystem *mgrid_system = &(eng->mgrid_system);
  Pmetest *pmetest = eng->pmetest;
  PmetestSystem *pt_system = &(eng->pt_system);
  const MD_Dvec *wrap;
  MD_Dvec *poswrap = (MD_Dvec *) (eng->poswrap->buf);
  MD_Dvec *f_elec = (MD_Dvec *) (eng->f_elec->buf);
  double *virial = f_res->virial;
#if 0
  MD_Dvec *force_elec = (MD_Dvec *) (eng->force_elec->buf);
#endif
/*** end data for mgrid and pmetest ***/

  int32 natoms = eng->natoms;
  int32 k;


#ifdef TEST_FSELECT
  /*
   * test-only: install a force selection for this evaluation and remove
   * it again before returning
   * NOTE(review): fs is not destroyed on the MD_error() returns between
   * here and the cleanup at the end of the function (leak in test builds)
   */
  ForceParam *fp = &(eng->f_param);
  ForceSelect *fs;

  printf("### setting force selection\n");
  if ((fs = force_select_create(fp,
          NULL, FORCE_SELECT_ALL,
          NULL, FORCE_SELECT_ALL,
          NULL, FORCE_SELECT_ALL,
          NULL, FORCE_SELECT_ALL,
          NULL, FORCE_SELECT_ALL,
          NULL, FORCE_SELECT_NONE,
          NULL, FORCE_SELECT_ALL)) == NULL) {
    fprintf(stderr, "NULL returned by force_select_create()\n");
    return MD_error(front, eng->err_force);
  }
  if (force_setup_selection(force, fs)) {
    fprintf(stderr, "nonzero returned by force_setup_selection()\n");
    return MD_error(front, eng->err_force);
  }
#endif

  TEXT("starting deven_force");

  /*
   * evaluate all forces using force library
   * (array f already given to f_res, updated implicitly)
   */
  if (force_compute(force, f_res, pos)) {
    return MD_error(front, eng->err_force);
  }

#if 0
  printf("force 0:  %g %g %g\n", force_elec[0].x, force_elec[0].y,
      force_elec[0].z);
#endif

  /*** mgrid - compute MSM long-range electrostatics ***/
  if (eng->ismgrid) {

    /* need to wrap positions back into periodic cell */
    wrap = force_get_poswrap(force);
    for (k = 0;  k < natoms;  k++) {
      poswrap[k].x = pos[k].x + wrap[k].x;
      poswrap[k].y = pos[k].y + wrap[k].y;
      poswrap[k].z = pos[k].z + wrap[k].z;
    }

    /*
     * be paranoid, make sure that atoms are inside mgrid domain
     * (the check treats a return value < natoms as the index of an
     * offending atom -- presumably natoms means "all inside"; confirm
     * against mgrid_system_validate())
     */
    if ((k = mgrid_system_validate(mgrid, mgrid_system)) < natoms) {
      printf("# mgrid system validation failed for atom k=%d\n", k);
      printf("# poswrap[%d] =  %g %g %g\n", k, poswrap[k].x, poswrap[k].y,
          poswrap[k].z);
      printf("# pos[%d] =  %g %g %g\n", k, pos[k].x, pos[k].y, pos[k].z);
      printf("# wrap[%d] =  %g %g %g\n", k, wrap[k].x, wrap[k].y, wrap[k].z);
      return MD_error(front, eng->err_force);
    }

    /*
     * compute mgrid
     * (mgrid_system has poswrap and sets f_elec)
     */
    if (mgrid_force(mgrid, mgrid_system)) {
      return MD_error(front, eng->err_force);
    }

    /* add energy contribution */
    f_res->u_elec += mgrid_system->u_elec;
    f_res->u_total += mgrid_system->u_elec;

    /*** no contribution to virial - mgrid does not compute it! ***/

    /* add contribution to total force */
    for (k = 0;  k < natoms;  k++) {
      f[k].x += f_elec[k].x;
      f[k].y += f_elec[k].y;
      f[k].z += f_elec[k].z;
    }

#if 0
    printf("mgrid 0:  %g %g %g\n", f_elec[0].x, f_elec[0].y, f_elec[0].z);
    printf("force 0:  %g %g %g\n", f_elec[0].x + force_elec[0].x,
        f_elec[0].y + force_elec[0].y, f_elec[0].z + force_elec[0].z);
#endif
  } /*** end mgrid ***/

  /*** pmetest - compute PME long-range electrostatics ***/
  /* (only reached when mgrid is off; no domain check is done here) */
  else if (eng->ispmetest) {

    /* need to wrap positions back into periodic cell */
    wrap = force_get_poswrap(force);
    for (k = 0;  k < natoms;  k++) {
      poswrap[k].x = pos[k].x + wrap[k].x;
      poswrap[k].y = pos[k].y + wrap[k].y;
      poswrap[k].z = pos[k].z + wrap[k].z;
    }

    /*
     * compute pmetest
     * (pt_system has poswrap and sets f_elec)
     */
    TEXT("calling pmetest_compute");
    if (pmetest_compute(pmetest, pt_system)) {
      return MD_error(front, eng->err_force);
    }
    TEXT("done pmetest_compute");

    /* add energy contribution */
    f_res->u_elec += pt_system->u_elec;
    f_res->u_total += pt_system->u_elec;

    /* add contribution to virial (all 9 entries of the 3x3 tensor) */
    for (k = 0;  k < 9;  k++) {
      virial[k] += pt_system->virial_recip[k];
    }

    /* add contribution to total force */
    for (k = 0;  k < natoms;  k++) {
      f[k].x += f_elec[k].x;
      f[k].y += f_elec[k].y;
      f[k].z += f_elec[k].z;
    }
  } /*** end pmetest ***/

  /*** FIXED ATOMS - (kludge) reset their total force to zero ***/
  /*
   * NOTE(review): the restraints and the linmo correction below run after
   * this reset, so a fixed atom that is also restrained (or any fixed atom
   * when forceopts is set) can still end up with a nonzero force
   */
  if (eng->nfixedatoms > 0) {
    int32 *atomindex = eng->fixedatom;
    int32 i;

    for (i = 0;  i < eng->nfixedatoms;  i++) {
      f[atomindex[i]].x = 0.0;
      f[atomindex[i]].y = 0.0;
      f[atomindex[i]].z = 0.0;
    }
  }

  /*** HARMONIC RESTRAINTS - adjust force and potential energy ***/
  /*
   * each restraint i pulls atom cons[i].index toward cons[i].refpos with
   *   U = k * scaling * r^expo   (harmonic when expo == 2)
   *   F = expo * k * scaling * r^(expo-2) * (refpos - pos)
   * where r is the minimum-image distance, optionally restricted to the
   * selected Cartesian components
   */
  if (eng->nconstraints > 0) {
    const DevenConstraint *cons = eng->constraint;
    int32 i, j;
    MD_Dvec refpos;
    double kk;
    int32 expo, index;
    const double scaling = eng->param.constraintScaling;
    MD_Dvec diff, rvec;
    double val, dot, r, r2;

    /* a: cell basis vectors; b: row transform used for fractional coords */
    const MD_Dvec *a = force_get_cell_vectors(force);
    const MD_Dvec *b = force_get_row_transform(force);
    const int32 is_periodic = force_get_cell_boundary(force);

    const int32 is_select_components = eng->is_select_components;
    const int32 is_omit_xcoord = ! eng->is_select_xcoord;
    const int32 is_omit_ycoord = ! eng->is_select_ycoord;
    const int32 is_omit_zcoord = ! eng->is_select_zcoord;

    for (i = 0;  i < eng->nconstraints;  i++) {
      refpos = cons[i].refpos;
      kk = cons[i].k;
      expo = cons[i].expo;
      index = cons[i].index;

      diff.x = refpos.x - pos[index].x;
      diff.y = refpos.y - pos[index].y;
      diff.z = refpos.z - pos[index].z;
      rvec = diff;

      /*
       * find shortest distance between refpos and pos[index]
       * (i.e. shortest refpos-pos[index], where pos[index]
       * is wrapped to closest periodic image)
       *
       * every dot product uses the unwrapped diff; this is only right
       * if b[i] . a[j] == (i == j) -- assumed from the force library's
       * row transform, confirm there
       */
      if (is_periodic & FORCE_X_PERIODIC) {
        dot = b[0].x * diff.x + b[0].y * diff.y + b[0].z * diff.z;
        val = floor(dot + 0.5);
        rvec.x -= a[0].x * val;
        rvec.y -= a[0].y * val;
        rvec.z -= a[0].z * val;
      }
      if (is_periodic & FORCE_Y_PERIODIC) {
        dot = b[1].x * diff.x + b[1].y * diff.y + b[1].z * diff.z;
        val = floor(dot + 0.5);
        rvec.x -= a[1].x * val;
        rvec.y -= a[1].y * val;
        rvec.z -= a[1].z * val;
      }
      if (is_periodic & FORCE_Z_PERIODIC) {
        dot = b[2].x * diff.x + b[2].y * diff.y + b[2].z * diff.z;
        val = floor(dot + 0.5);
        rvec.x -= a[2].x * val;
        rvec.y -= a[2].y * val;
        rvec.z -= a[2].z * val;
      }

      if (is_select_components) {
        /*
         * zero the components that were NOT selected, so the restraint
         * acts only along the selected Cartesian components
         */
        if (is_omit_xcoord) rvec.x = 0.0;
        if (is_omit_ycoord) rvec.y = 0.0;
        if (is_omit_zcoord) rvec.z = 0.0;
      }

      r2 = rvec.x * rvec.x + rvec.y * rvec.y + rvec.z * rvec.z;
      r = sqrt(r2);

      /*
       * calculate energy and force only for nonzero distance
       * (removeable singularity at zero)
       */
      if (r > 0.0) {
        /* val = k * scaling * r^expo (expo <= 0 leaves val = k * scaling) */
        val = kk * scaling;
        for (j = 0;  j < expo;  j++) {
          val *= r;
        }
#ifdef DEBUG_WATCH
        fprintf(stderr, "atom id = %d\n", index);
        fprintf(stderr, "energy value = %.15g\n", val);
#endif
        /* don't have a "misc" energy, just add to total energy */
        f_res->u_total += val;

        /* force factor: expo * k * scaling * r^(expo-2), times rvec */
        val *= expo;
        val /= r2;
        rvec.x *= val;
        rvec.y *= val;
        rvec.z *= val;
        f[index].x += rvec.x;
        f[index].y += rvec.y;
        f[index].z += rvec.z;
      }

      /*** does this also contribute to the virial? ***/
    }
  } /* end HARMONIC RESTRAINTS */

  /*
   * compute force correction for conserving linear momentum
   * (don't use this with fixed atoms or harmonic restraints)
   * forceopts values other than the two cases below apply no force
   * correction but still update the weighted COM displacement
   */
  if (eng->forceopts) {
    switch (eng->forceopts) {
      case ENGINE_FIX_LINMO:
        fix_linmo(eng, f);
        break;
      case ENGINE_CONS_LINMO:
        conserve_linmo(eng, f);
        break;
    }
    wt_com_disp(eng);
  }

  /* return total potential energy */
  *pe = f_res->u_total;

#ifdef TEST_FSELECT
  force_select_destroy(fs);
  if (force_setup_selection(force, NULL)) {
    fprintf(stderr, "nonzero returned by force_setup_selection()\n");
    return MD_error(front, eng->err_force);
  }
#endif

  return 0;
}


/*
 * standard approach to conserve linear momentum when
 * using a potential which is not translationally invariant,
 * however it causes energy drift in long simulations
 *
 * Computes the mean force  fcorr = e->inv_natoms * sum_i f[i],
 * records it in e->result.fcorr, and subtracts it from every f[i].
 * When e->inv_natoms == 1.0 / e->natoms (assumed -- set elsewhere;
 * confirm), the resulting net force is zero.
 *
 * e -- engine supplying natoms, inv_natoms, and the result record
 * f -- forces for e->natoms atoms, corrected in place
 *
 * static: file-local helper, matching its forward declaration at the
 * top of this file.
 */
static void fix_linmo(Engine *e, MD_Dvec f[])
{
  MD_Dvec f_sum = { 0.0, 0.0, 0.0 };
  const double inv_natoms = e->inv_natoms;
  const int32 natoms = e->natoms;
  int32 i;

  TEXT("fix");
  FLT(inv_natoms);

  /* total force on the system */
  for (i = 0;  i < natoms;  i++) {
    f_sum.x += f[i].x;
    f_sum.y += f[i].y;
    f_sum.z += f[i].z;
  }
  VEC(f_sum);

  /* mean force per atom -- the uniform correction to remove */
  f_sum.x *= inv_natoms;
  f_sum.y *= inv_natoms;
  f_sum.z *= inv_natoms;
  VEC(f_sum);
  e->result.fcorr = f_sum;

  for (i = 0;  i < natoms;  i++) {
    f[i].x -= f_sum.x;
    f[i].y -= f_sum.y;
    f[i].z -= f_sum.z;
  }
}


/*
 * conserve linear momentum without energy drift?
 *
 * Subtracts a mass-weighted share of the total force,
 *   f[i] -= scaled_mass[i] * sum_j f[j],
 * which zeroes the net force when the scaled_mass entries sum to 1
 * (presumably scaled_mass[i] = m_i / total mass -- confirm where
 * e->scaled_mass is set).
 *
 * e->result.fcorr receives the *average* correction
 * (e->inv_natoms * total force) so it is comparable to fix_linmo().
 *
 * e -- engine supplying natoms, inv_natoms, scaled_mass, result record
 * f -- forces for e->natoms atoms, corrected in place
 *
 * static: file-local helper, matching its forward declaration at the
 * top of this file.
 */
static void conserve_linmo(Engine *e, MD_Dvec f[])
{
  MD_Dvec f_sum = { 0.0, 0.0, 0.0 };
  const double inv_natoms = e->inv_natoms;
  const double *scaled_mass = e->scaled_mass;
  const int32 natoms = e->natoms;
  int32 i;

  /* dave's attempt: total force on the system */
  for (i = 0;  i < natoms;  i++) {
    f_sum.x += f[i].x;
    f_sum.y += f[i].y;
    f_sum.z += f[i].z;
  }

  /* use average force correction to make comparable to other method */
  e->result.fcorr.x = inv_natoms * f_sum.x;
  e->result.fcorr.y = inv_natoms * f_sum.y;
  e->result.fcorr.z = inv_natoms * f_sum.z;

  /* remove the total force in proportion to each atom's scaled mass */
  for (i = 0;  i < natoms;  i++) {
    f[i].x -= scaled_mass[i] * f_sum.x;
    f[i].y -= scaled_mass[i] * f_sum.y;
    f[i].z -= scaled_mass[i] * f_sum.z;
  }
}


/*
 * weighted center-of-mass displacement
 *
 * (this should be a conserved quantity)
 *
 * Stores  sum_i scaled_mass[i] * (r[i] - r_init[i])  in e->result.wcomd,
 * where r is the engine's current position buffer (e->pos->buf) and
 * r_init is e->init_pos.
 * NOTE(review): this reads e->pos->buf, not the pos array passed to
 * deven_force(); confirm the two are the same buffer when it is called.
 *
 * static: file-local helper, matching its forward declaration at the
 * top of this file.
 */
static void wt_com_disp(Engine *e)
{
  MD_Dvec g = { 0.0, 0.0, 0.0 };
  const double *scaled_mass = e->scaled_mass;
  const MD_Dvec *r = (const MD_Dvec *)(e->pos->buf);
  const MD_Dvec *r_init = e->init_pos;
  const int32 natoms = e->natoms;
  int32 i;

  for (i = 0;  i < natoms;  i++) {
    g.x += scaled_mass[i] * (r[i].x - r_init[i].x);
    g.y += scaled_mass[i] * (r[i].y - r_init[i].y);
    g.z += scaled_mass[i] * (r[i].z - r_init[i].z);
  }
  e->result.wcomd = g;
}
