/*
 * Copyright (C) 2004-2006 by David J. Hardy.  All rights reserved.
 *
 * bonded.c
 *
 * Routines to evaluate force for bonded atom interactions.
 */

#include <math.h>
#include "force/force.h"
#undef DEBUG_WATCH
#include "debug/debug.h"

#define DEBUG_VIRIAL
#undef DEBUG_VIRIAL

/*
 * evaluate selected bonds
 */
/*
 * evaluate selected bonds
 *
 * For every index b listed in bond_sel[0 .. bond_sel_len-1], the bond
 * fobj->param->bond[b] is evaluated with its parameter entry
 * fobj->param->bondprm[bond[b].prm]:
 *   - forces are added into f_bond[]
 *   - the virial contribution is added into virial[]
 *   - the bond's potential is stored in e_bond[b]
 * The potential summed over the selection is written to *u_bond.
 * Always returns 0.
 */
int force_compute_bonds(Force *fobj, double *u_bond, MD_Dvec f_bond[],
    double e_bond[], double virial[], const MD_Dvec pos[],
    const int32 bond_sel[], int32 bond_sel_len)
{
  const MD_Bond *bond_list = fobj->param->bond;       /* bond connections */
  const MD_BondPrm *bond_prm = fobj->param->bondprm;  /* bond parameters */
  double total = 0.0;  /* running sum of bond potentials */
  int32 n;
#ifdef DEBUG_VIRIAL
  double vsum[9] = { 0.0 };
  double *vacc = vsum;    /* collect locally so it can be reported */
#else
  double *vacc = virial;  /* accumulate directly into caller's virial */
#endif

  for (n = 0;  n < bond_sel_len;  n++) {
    const int32 b = bond_sel[n];
    const MD_Bond *bnd = &bond_list[b];
    double u = force_compute_bond_interaction(f_bond, vacc, pos,
        bnd, &bond_prm[bnd->prm]);
    e_bond[b] = u;
    total += u;
  }

#ifdef DEBUG_VIRIAL
  for (n = 0;  n < 9;  n++) {
    virial[n] += vsum[n];
  }
  printf("bond virial: %g %g %g  %g %g %g  %g %g %g\n",
      vsum[0], vsum[1], vsum[2], vsum[3], vsum[4],
      vsum[5], vsum[6], vsum[7], vsum[8]);
  printf("summed virial: %g %g %g  %g %g %g  %g %g %g\n",
      virial[0], virial[1], virial[2], virial[3], virial[4],
      virial[5], virial[6], virial[7], virial[8]);
#endif

  *u_bond = total;
  return 0;
}


/*
 * evaluate selected angles
 */
/*
 * evaluate selected angles
 *
 * For every index a listed in angle_sel[0 .. angle_sel_len-1], the angle
 * fobj->param->angle[a] is evaluated with its parameter entry
 * fobj->param->angleprm[angle[a].prm]:
 *   - forces are added into f_angle[]
 *   - the virial contribution is added into virial[]
 *   - the angle's potential is stored in e_angle[a]
 * The potential summed over the selection is written to *u_angle.
 * Always returns 0.
 */
int force_compute_angles(Force *fobj, double *u_angle, MD_Dvec f_angle[],
    double e_angle[], double virial[], const MD_Dvec pos[],
    const int32 angle_sel[], int32 angle_sel_len)
{
  const MD_Angle *angle_list = fobj->param->angle;       /* angle connections */
  const MD_AnglePrm *angle_prm = fobj->param->angleprm;  /* angle parameters */
  double total = 0.0;  /* running sum of angle potentials */
  int32 n;
#ifdef DEBUG_VIRIAL
  double vsum[9] = { 0.0 };
  double *vacc = vsum;    /* collect locally so it can be reported */
#else
  double *vacc = virial;  /* accumulate directly into caller's virial */
#endif

  for (n = 0;  n < angle_sel_len;  n++) {
    const int32 a = angle_sel[n];
    const MD_Angle *ang = &angle_list[a];
    double u = force_compute_angle_interaction(f_angle, vacc, pos,
        ang, &angle_prm[ang->prm]);
    e_angle[a] = u;
    total += u;
  }

#ifdef DEBUG_VIRIAL
  for (n = 0;  n < 9;  n++) {
    virial[n] += vsum[n];
  }
  printf("angle virial: %g %g %g  %g %g %g  %g %g %g\n",
      vsum[0], vsum[1], vsum[2], vsum[3], vsum[4],
      vsum[5], vsum[6], vsum[7], vsum[8]);
  printf("summed virial: %g %g %g  %g %g %g  %g %g %g\n",
      virial[0], virial[1], virial[2], virial[3], virial[4],
      virial[5], virial[6], virial[7], virial[8]);
#endif

  *u_angle = total;
  return 0;
}


/*
 * evaluate selected dihedrals
 */
/*
 * evaluate selected dihedrals
 *
 * For every index d listed in dihed_sel[0 .. dihed_sel_len-1], the
 * dihedral fobj->param->dihed[d] is evaluated by
 * force_compute_torsion_interaction() with parameter pointer
 * &fobj->param->dihedprm[dihed[d].prm]:
 *   - forces are added into f_dihed[]
 *   - the virial contribution is added into virial[]
 *   - the dihedral's potential is stored in e_dihed[d]
 * The potential summed over the selection is written to *u_dihed.
 * Always returns 0.
 */
int force_compute_dihedrals(Force *fobj, double *u_dihed, MD_Dvec f_dihed[],
    double e_dihed[], double virial[], const MD_Dvec pos[],
    const int32 dihed_sel[], int32 dihed_sel_len)
{
  const MD_Tors *dihed_list = fobj->param->dihed;       /* dihed connections */
  const MD_TorsPrm *dihed_prm = fobj->param->dihedprm;  /* dihed parameters */
  double total = 0.0;  /* running sum of dihedral potentials */
  int32 n;
#ifdef DEBUG_VIRIAL
  double vsum[9] = { 0.0 };
  double *vacc = vsum;    /* collect locally so it can be reported */
#else
  double *vacc = virial;  /* accumulate directly into caller's virial */
#endif

  for (n = 0;  n < dihed_sel_len;  n++) {
    const int32 d = dihed_sel[n];
    const MD_Tors *tor = &dihed_list[d];
    double u = force_compute_torsion_interaction(f_dihed, vacc, pos,
        tor, &dihed_prm[tor->prm]);
    e_dihed[d] = u;
    total += u;
  }

#ifdef DEBUG_VIRIAL
  for (n = 0;  n < 9;  n++) {
    virial[n] += vsum[n];
  }
  printf("dihed virial: %g %g %g  %g %g %g  %g %g %g\n",
      vsum[0], vsum[1], vsum[2], vsum[3], vsum[4],
      vsum[5], vsum[6], vsum[7], vsum[8]);
  printf("summed virial: %g %g %g  %g %g %g  %g %g %g\n",
      virial[0], virial[1], virial[2], virial[3], virial[4],
      virial[5], virial[6], virial[7], virial[8]);
#endif

  *u_dihed = total;
  return 0;
}


/*
 * evaluate selected impropers
 */
/*
 * evaluate selected impropers
 *
 * For every index m listed in impr_sel[0 .. impr_sel_len-1], the improper
 * fobj->param->impr[m] is evaluated by force_compute_torsion_interaction()
 * with parameter pointer &fobj->param->imprprm[impr[m].prm]:
 *   - forces are added into f_impr[]
 *   - the virial contribution is added into virial[]
 *   - the improper's potential is stored in e_impr[m]
 * The potential summed over the selection is written to *u_impr.
 * Always returns 0.
 */
int force_compute_impropers(Force *fobj, double *u_impr, MD_Dvec f_impr[],
    double e_impr[], double virial[], const MD_Dvec pos[],
    const int32 impr_sel[], int32 impr_sel_len)
{
  const MD_Tors *impr_list = fobj->param->impr;       /* impr connections */
  const MD_TorsPrm *impr_prm = fobj->param->imprprm;  /* impr parameters */
  double total = 0.0;  /* running sum of improper potentials */
  int32 n;
#ifdef DEBUG_VIRIAL
  double vsum[9] = { 0.0 };
  double *vacc = vsum;    /* collect locally so it can be reported */
#else
  double *vacc = virial;  /* accumulate directly into caller's virial */
#endif

  for (n = 0;  n < impr_sel_len;  n++) {
    const int32 m = impr_sel[n];
    const MD_Tors *tor = &impr_list[m];
    double u = force_compute_torsion_interaction(f_impr, vacc, pos,
        tor, &impr_prm[tor->prm]);
    e_impr[m] = u;
    total += u;
  }

#ifdef DEBUG_VIRIAL
  for (n = 0;  n < 9;  n++) {
    virial[n] += vsum[n];
  }
  printf("impr virial: %g %g %g  %g %g %g  %g %g %g\n",
      vsum[0], vsum[1], vsum[2], vsum[3], vsum[4],
      vsum[5], vsum[6], vsum[7], vsum[8]);
  printf("summed virial: %g %g %g  %g %g %g  %g %g %g\n",
      virial[0], virial[1], virial[2], virial[3], virial[4],
      virial[5], virial[6], virial[7], virial[8]);
#endif

  *u_impr = total;
  return 0;
}


/*
 * compute single bond interaction
 *
 * use positions from pos[]
 * update corresponding forces in f[]
 * return interaction potential
 */
double force_compute_bond_interaction(MD_Dvec f[], double virial[],
    const MD_Dvec pos[], const MD_Bond *bond, const MD_BondPrm *prm)
{
  /*
   * Harmonic bond U = k*(r - r0)^2 between atoms a = bond->atom[0] and
   * b = bond->atom[1], with displacement d = pos[b] - pos[a], r = |d|.
   * The force on b is F = scale*d (and -F on a), where
   * scale = -2k(r - r0)/r.  A zero rest length is special-cased so that
   * coincident atoms never cause a division by zero.
   * The outer product F (x) d is accumulated into virial[];
   * the bond's potential is returned.
   */
  const int32 a = bond->atom[0];
  const int32 b = bond->atom[1];
  const double dx = pos[b].x - pos[a].x;
  const double dy = pos[b].y - pos[a].y;
  const double dz = pos[b].z - pos[a].z;
  const double rsq = dx * dx + dy * dy + dz * dz;
  double scale, u, fx, fy, fz;

  if (prm->r0 != 0) {
    const double r = sqrt(rsq);
    const double stretch = r - prm->r0;
    scale = -2.0 * prm->k * stretch / r;
    u = prm->k * stretch * stretch;
  }
  else {
    /* r0 == 0:  U = k*r^2, F = -2k*d, no sqrt or 1/r needed */
    scale = -2.0 * prm->k;
    u = prm->k * rsq;
  }

  fx = scale * dx;
  fy = scale * dy;
  fz = scale * dz;

  /* equal and opposite forces; atom a is updated first */
  f[a].x -= fx;
  f[a].y -= fy;
  f[a].z -= fz;
  f[b].x += fx;
  f[b].y += fy;
  f[b].z += fz;

  virial[FORCE_VIRIAL_XX] += fx * dx;
  virial[FORCE_VIRIAL_XY] += fx * dy;
  virial[FORCE_VIRIAL_XZ] += fx * dz;
  virial[FORCE_VIRIAL_YX] += fy * dx;
  virial[FORCE_VIRIAL_YY] += fy * dy;
  virial[FORCE_VIRIAL_YZ] += fy * dz;
  virial[FORCE_VIRIAL_ZX] += fz * dx;
  virial[FORCE_VIRIAL_ZY] += fz * dy;
  virial[FORCE_VIRIAL_ZZ] += fz * dz;

  return u;
}


/*
 * compute single angle interaction
 *
 * use positions from pos[]
 * update corresponding forces in f[]
 * return interaction potential
 */
double force_compute_angle_interaction(MD_Dvec f[], double virial[],
    const MD_Dvec pos[], const MD_Angle *angle, const MD_AnglePrm *prm)
{
  /*
   * Potential is k_theta*(theta - theta0)^2, where theta is the angle at
   * the vertex atom angle->atom[1] formed with angle->atom[0] and
   * angle->atom[2].  When k_ub != 0, a Urey-Bradley term
   * k_ub*(r13 - r_ub)^2 between atom[0] and atom[2] is added.
   *
   * use positions from pos[]
   * update corresponding forces in f[]
   * accumulate the interaction's contribution to virial[]
   * return interaction potential
   *
   * Near-linear geometry: the angular force factor is
   * (theta - theta0)/sin(theta).  It is 0/0 or x/0 when the two bonds
   * are parallel, and it used to yield inf/NaN forces.  When sin(theta)
   * drops below sin_min, the factor is replaced by its limit (+1 or -1),
   * which is exact for theta0 of 0 or pi.  NAMD applies the same
   * safeguard in its angle computation.  The force direction vector
   * (cos_theta*u21 - u23) has magnitude sin(theta), so the resulting
   * force stays finite and small.
   */
  const double sin_min = 1.e-6;  /* below this, treat the bonds as parallel */
  MD_Dvec r21, r23, f1, f2, f3, r13, f13;
  double inv_r21len, inv_r23len, cos_theta, sin_theta, delta_theta, coef;
  double energy, r13len, dist;
  const int32 i = angle->atom[0];
  const int32 j = angle->atom[1];
  const int32 k = angle->atom[2];

  /* bond vectors pointing away from the vertex atom j */
  r21.x = pos[i].x - pos[j].x;
  r21.y = pos[i].y - pos[j].y;
  r21.z = pos[i].z - pos[j].z;
  r23.x = pos[k].x - pos[j].x;
  r23.y = pos[k].y - pos[j].y;
  r23.z = pos[k].z - pos[j].z;
  inv_r21len = 1.0 / sqrt(r21.x * r21.x + r21.y * r21.y + r21.z * r21.z);
  inv_r23len = 1.0 / sqrt(r23.x * r23.x + r23.y * r23.y + r23.z * r23.z);
  cos_theta = (r21.x * r23.x + r21.y * r23.y + r21.z * r23.z)
              * inv_r21len * inv_r23len;
  /* cos(theta) should be in [-1,1] */
  /* however, we need to correct in case of roundoff error */
  if (cos_theta > 1.0)        cos_theta = 1.0;
  else if (cos_theta < -1.0)  cos_theta = -1.0;
  sin_theta = sqrt(1.0 - cos_theta * cos_theta);
  delta_theta = acos(cos_theta) - prm->theta0;
  if (sin_theta < sin_min) {
    /* bonds (nearly) parallel: use limit of delta_theta/sin_theta = -/+1 */
    coef = (delta_theta < 0.0 ? 2.0 : -2.0) * prm->k_theta;
  }
  else {
    coef = -2.0 * prm->k_theta * delta_theta / sin_theta;
  }
  energy = prm->k_theta * delta_theta * delta_theta;
  f1.x = coef * (cos_theta * r21.x * inv_r21len - r23.x * inv_r23len)
         * inv_r21len;
  f1.y = coef * (cos_theta * r21.y * inv_r21len - r23.y * inv_r23len)
         * inv_r21len;
  f1.z = coef * (cos_theta * r21.z * inv_r21len - r23.z * inv_r23len)
         * inv_r21len;
  f3.x = coef * (cos_theta * r23.x * inv_r23len - r21.x * inv_r21len)
         * inv_r23len;
  f3.y = coef * (cos_theta * r23.y * inv_r23len - r21.y * inv_r21len)
         * inv_r23len;
  f3.z = coef * (cos_theta * r23.z * inv_r23len - r21.z * inv_r21len)
         * inv_r23len;
  /* vertex force balances the two end forces (net force is zero) */
  f2.x = -(f1.x + f3.x);
  f2.y = -(f1.y + f3.y);
  f2.z = -(f1.z + f3.z);

  /* Urey-Bradley term affects only atoms 1 and 3 */
  if (prm->k_ub != 0.0) {
    r13.x = r23.x - r21.x;
    r13.y = r23.y - r21.y;
    r13.z = r23.z - r21.z;
    r13len = sqrt(r13.x * r13.x + r13.y * r13.y + r13.z * r13.z);
    dist = r13len - prm->r_ub;
    coef = -2.0 * prm->k_ub * dist / r13len;
    energy += prm->k_ub * dist * dist;
    f13.x = coef * r13.x;
    f13.y = coef * r13.y;
    f13.z = coef * r13.z;
    f1.x -= f13.x;
    f1.y -= f13.y;
    f1.z -= f13.z;
    f3.x += f13.x;
    f3.y += f13.y;
    f3.z += f13.z;
  }
  f[i].x += f1.x;
  f[i].y += f1.y;
  f[i].z += f1.z;
  f[j].x += f2.x;
  f[j].y += f2.y;
  f[j].z += f2.z;
  f[k].x += f3.x;
  f[k].y += f3.y;
  f[k].z += f3.z;
  /* forces sum to zero, so the virial is r21 (x) f1 + r23 (x) f3 */
  virial[FORCE_VIRIAL_XX] += (f1.x * r21.x + f3.x * r23.x);
  virial[FORCE_VIRIAL_XY] += (f1.x * r21.y + f3.x * r23.y);
  virial[FORCE_VIRIAL_XZ] += (f1.x * r21.z + f3.x * r23.z);
  virial[FORCE_VIRIAL_YX] += (f1.y * r21.x + f3.y * r23.x);
  virial[FORCE_VIRIAL_YY] += (f1.y * r21.y + f3.y * r23.y);
  virial[FORCE_VIRIAL_YZ] += (f1.y * r21.z + f3.y * r23.z);
  virial[FORCE_VIRIAL_ZX] += (f1.z * r21.x + f3.z * r23.x);
  virial[FORCE_VIRIAL_ZY] += (f1.z * r21.y + f3.z * r23.y);
  virial[FORCE_VIRIAL_ZZ] += (f1.z * r21.z + f3.z * r23.z);
  return energy;
}


/*
 * compute single torsion interaction
 *
 * use positions from pos[]
 * update corresponding forces in f[]
 * return interaction potential
 *
 * (the following code adapted from NAMD2 ComputeDihedrals.C)
 */
double force_compute_torsion_interaction(MD_Dvec f[], double virial[],
    const MD_Dvec pos[], const MD_Tors *tors, const MD_TorsPrm *prm)
{
  /*
   * Shared by force_compute_dihedrals() and force_compute_impropers().
   * Atoms are tors->atom[0..3].  The angle is computed from
   * r12 = p1-p2, r23 = p2-p3 and r34 = p3-p4, using the convention
   * phi = -atan2(sin_phi, cos_phi).
   *
   * Each parameter term contributes:
   *   k*(1 + cos(n*phi - delta))   if n != 0 (periodic form)
   *   k*(phi - delta)^2            if n == 0 (harmonic form; the
   *                                difference is shifted by 2*pi once
   *                                if it lies outside [-pi, pi])
   *
   * prm is read at prm[0], prm[-1], prm[-2], ... until an entry with
   * mult <= 1 is reached.  The caller must therefore pass a pointer to
   * the last term of a run of entries, and the first entry of the whole
   * MD_TorsPrm array must have mult == 1 (see the NOTE below).
   */
  MD_Dvec r12, r23, r34, A, B, C, dcosdA, dcosdB, dsindC, dsindB, f1, f2, f3;
  double rA, rB, rC, cos_phi, sin_phi, phi, delta, diff, k;
  double K = 0.0;    /* energy */
  double K1 = 0.0;   /* force */
  int32 j, n;
  const int32 atom1 = tors->atom[0];
  const int32 atom2 = tors->atom[1];
  const int32 atom3 = tors->atom[2];
  const int32 atom4 = tors->atom[3];

  r12.x = pos[atom1].x - pos[atom2].x;
  r12.y = pos[atom1].y - pos[atom2].y;
  r12.z = pos[atom1].z - pos[atom2].z;
  r23.x = pos[atom2].x - pos[atom3].x;
  r23.y = pos[atom2].y - pos[atom3].y;
  r23.z = pos[atom2].z - pos[atom3].z;
  r34.x = pos[atom3].x - pos[atom4].x;
  r34.y = pos[atom3].y - pos[atom4].y;
  r34.z = pos[atom3].z - pos[atom4].z;

  /* A = cross(r12, r23) */
  A.x = r12.y * r23.z - r12.z * r23.y;
  A.y = r12.z * r23.x - r12.x * r23.z;
  A.z = r12.x * r23.y - r12.y * r23.x;

  /* B = cross(r23, r34) */
  B.x = r23.y * r34.z - r23.z * r34.y;
  B.y = r23.z * r34.x - r23.x * r34.z;
  B.z = r23.x * r34.y - r23.y * r34.x;

  /* C = cross(r23, A) */
  C.x = r23.y * A.z - r23.z * A.y;
  C.y = r23.z * A.x - r23.x * A.z;
  C.z = r23.x * A.y - r23.y * A.x;

  /* reciprocal lengths of A, B, C (no guard against zero length) */
  rA = 1.0 / sqrt(A.x * A.x + A.y * A.y + A.z * A.z);
  rB = 1.0 / sqrt(B.x * B.x + B.y * B.y + B.z * B.z);
  rC = 1.0 / sqrt(C.x * C.x + C.y * C.y + C.z * C.z);

  /* normalize B */
  B.x *= rB;
  B.y *= rB;
  B.z *= rB;

  /* cos/sin of phi from projections of A and C onto unit vector B */
  cos_phi = (A.x * B.x + A.y * B.y + A.z * B.z) * rA;
  sin_phi = (C.x * B.x + C.y * B.y + C.z * B.z) * rC;

  phi = -atan2(sin_phi, cos_phi);

  /* loop backwards through MD_TorsPrm list to pick up multiplicities */
  /* K accumulates the energy U(phi); K1 accumulates dU/dphi */
  j = 1;
  do {
    j--;
    k  =    prm[j].k_tor;
    delta = prm[j].phi;
    n  =    prm[j].n;
    if (n) {
      /* as defined by CHARMM dihedral potential specification */
      K += k * (1.0 + cos(n * phi - delta));
      K1 += -n * k * sin(n * phi - delta);
    }
    else {
      /* harmonic term: bring phi - delta back into [-pi, pi] */
      diff = phi - delta;
      if      (diff < -M_PI)  diff += 2.0 * M_PI;
      else if (diff >  M_PI)  diff -= 2.0 * M_PI;
      K += k * diff * diff;
      K1 += 2.0 * k * diff;
    }
  } while (prm[j].mult > 1);
  /*
   * NOTE:  the "first" term (reached last) will have field mult==1
   *   and the first MD_TorsPrm array element must have field mult==1.
   */

  if (fabs(sin_phi) > 0.1) {
    /* use sine version to avoid 1/cos terms */

    /* normalize A */
    A.x *= rA;
    A.y *= rA;
    A.z *= rA;
    dcosdA.x = rA * (cos_phi * A.x - B.x);
    dcosdA.y = rA * (cos_phi * A.y - B.y);
    dcosdA.z = rA * (cos_phi * A.z - B.z);
    dcosdB.x = rB * (cos_phi * B.x - A.x);
    dcosdB.y = rB * (cos_phi * B.y - A.y);
    dcosdB.z = rB * (cos_phi * B.z - A.z);

    K1 /= sin_phi;
    f1.x = K1 * (r23.y * dcosdA.z - r23.z * dcosdA.y);
    f1.y = K1 * (r23.z * dcosdA.x - r23.x * dcosdA.z);
    f1.z = K1 * (r23.x * dcosdA.y - r23.y * dcosdA.x);

    f3.x = K1 * (r23.z * dcosdB.y - r23.y * dcosdB.z);
    f3.y = K1 * (r23.x * dcosdB.z - r23.z * dcosdB.x);
    f3.z = K1 * (r23.y * dcosdB.x - r23.x * dcosdB.y);

    f2.x = K1 * (r12.z * dcosdA.y - r12.y * dcosdA.z
                + r34.y * dcosdB.z - r34.z * dcosdB.y);
    f2.y = K1 * (r12.x * dcosdA.z - r12.z * dcosdA.x
                + r34.z * dcosdB.x - r34.x * dcosdB.z);
    f2.z = K1 * (r12.y * dcosdA.x - r12.x * dcosdA.y
                + r34.x * dcosdB.y - r34.y * dcosdB.x);
  }
  else {
    /* phi is too close to 0 or pi, use cos version to avoid 1/sin */
    /* (|sin_phi| <= 0.1 here, so |cos_phi| is near 1 and safe to divide) */

    /* normalize C */
    C.x *= rC;
    C.y *= rC;
    C.z *= rC;
    dsindC.x = rC * (sin_phi * C.x - B.x);
    dsindC.y = rC * (sin_phi * C.y - B.y);
    dsindC.z = rC * (sin_phi * C.z - B.z);
    dsindB.x = rB * (sin_phi * B.x - C.x);
    dsindB.y = rB * (sin_phi * B.y - C.y);
    dsindB.z = rB * (sin_phi * B.z - C.z);

    K1 /= -cos_phi;
    f1.x = K1 * ((r23.y * r23.y + r23.z * r23.z) * dsindC.x
                 - r23.x * r23.y * dsindC.y
                 - r23.x * r23.z * dsindC.z);
    f1.y = K1 * ((r23.z * r23.z + r23.x * r23.x) * dsindC.y
                 - r23.y * r23.z * dsindC.z
                 - r23.y * r23.x * dsindC.x);
    f1.z = K1 * ((r23.x * r23.x + r23.y * r23.y) * dsindC.z
                 - r23.z * r23.x * dsindC.x
                 - r23.z * r23.y * dsindC.y);

    /* f3 = K1 * cross(dsindB, r23) */
    f3.x = K1 * (dsindB.y * r23.z - dsindB.z * r23.y);
    f3.y = K1 * (dsindB.z * r23.x - dsindB.x * r23.z);
    f3.z = K1 * (dsindB.x * r23.y - dsindB.y * r23.x);

    f2.x = K1 * (-(r23.y * r12.y + r23.z * r12.z) * dsindC.x
                 + (2.0 * r23.x * r12.y - r12.x * r23.y) * dsindC.y
                 + (2.0 * r23.x * r12.z - r12.x * r23.z) * dsindC.z
                 + dsindB.z * r34.y - dsindB.y * r34.z);
    f2.y = K1 * (-(r23.z * r12.z + r23.x * r12.x) * dsindC.y
                 + (2.0 * r23.y * r12.z - r12.y * r23.z) * dsindC.z
                 + (2.0 * r23.y * r12.x - r12.y * r23.x) * dsindC.x
                 + dsindB.x * r34.z - dsindB.z * r34.x);
    f2.z = K1 * (-(r23.x * r12.x + r23.y * r12.y) * dsindC.z
                 + (2.0 * r23.z * r12.x - r12.z * r23.x) * dsindC.x
                 + (2.0 * r23.z * r12.y - r12.z * r23.y) * dsindC.y
                 + dsindB.y * r34.x - dsindB.x * r34.y);
  }

  /*
   * Distribute forces: atom1 gets f1, atom2 f2-f1, atom3 f3-f2, atom4 -f3,
   * so the four forces sum to zero.
   */
  f[atom1].x += f1.x;
  f[atom1].y += f1.y;
  f[atom1].z += f1.z;
  f[atom2].x += f2.x - f1.x;
  f[atom2].y += f2.y - f1.y;
  f[atom2].z += f2.z - f1.z;
  f[atom3].x += f3.x - f2.x;
  f[atom3].y += f3.y - f2.y;
  f[atom3].z += f3.z - f2.z;
  f[atom4].x += -f3.x;
  f[atom4].y += -f3.y;
  f[atom4].z += -f3.z;
  /*
   * With that distribution, sum_i r_i (x) F_i telescopes to
   * r12 (x) f1 + r23 (x) f2 + r34 (x) f3.
   */
  virial[FORCE_VIRIAL_XX] += (f1.x * r12.x + f2.x * r23.x + f3.x * r34.x);
  virial[FORCE_VIRIAL_XY] += (f1.x * r12.y + f2.x * r23.y + f3.x * r34.y);
  virial[FORCE_VIRIAL_XZ] += (f1.x * r12.z + f2.x * r23.z + f3.x * r34.z);
  virial[FORCE_VIRIAL_YX] += (f1.y * r12.x + f2.y * r23.x + f3.y * r34.x);
  virial[FORCE_VIRIAL_YY] += (f1.y * r12.y + f2.y * r23.y + f3.y * r34.y);
  virial[FORCE_VIRIAL_YZ] += (f1.y * r12.z + f2.y * r23.z + f3.y * r34.z);
  virial[FORCE_VIRIAL_ZX] += (f1.z * r12.x + f2.z * r23.x + f3.z * r34.x);
  virial[FORCE_VIRIAL_ZY] += (f1.z * r12.y + f2.z * r23.y + f3.z * r34.y);
  virial[FORCE_VIRIAL_ZZ] += (f1.z * r12.z + f2.z * r23.z + f3.z * r34.z);
  return K;
}
