/*
 * Copyright (C) 2004-2005 by David J. Hardy.  All rights reserved.
 *
 * bonded.c
 *
 * Routines to evaluate force for bonded atom interactions.
 */

#include <math.h>
#include "force/force.h"
#include "debug/debug.h"


/*
 * forward declarations of the file-local kernels
 *
 * Each kernel evaluates exactly one interaction: it reads the atom
 * positions from pos, adds the resulting per-atom forces into f, and
 * returns the potential energy of that interaction.
 */

static double compute_bond(MD_Dvec *f, const MD_Dvec *pos,
    const MD_Bond *bond, const MD_BondPrm *prm);

static double compute_angle(MD_Dvec *f, const MD_Dvec *pos,
    const MD_Angle *angle, const MD_AnglePrm *prm);

static double compute_tors(MD_Dvec *f, const MD_Dvec *pos,
    const MD_Tors *tors, const MD_TorsPrm *prm);


/*
 * setup routine called externally
 *
 * The bonded terms need no preparation, so the Force object is left
 * untouched and 0 is always returned.
 */
int force_setup_bonded(Force *f)
{
  (void) f;  /* intentionally unused */
  return 0;
}


/*
 * computation routine called externally
 */
int force_compute_bonded(Force *force, const MD_Dvec *pos)
{
  ForceParam *p = force->param;
  ForceEnergy *e = force->energy;
  ForceResult *r = force->result;

  /* bonds */
  const MD_BondPrm *bondprm = p->bondprm;
  const MD_Bond *bond = p->bond;
  const int32 bond_len  = p->bond_len;
  MD_Dvec *f_bond = (r->f_bond ? r->f_bond : r->f);

  /* angles */
  const MD_AnglePrm *angleprm = p->angleprm;
  const MD_Angle *angle = p->angle;
  const int32 angle_len  = p->angle_len;
  MD_Dvec *f_angle = (r->f_angle ? r->f_angle : r->f);

  /* dihedrals */
  const MD_TorsPrm *dihedprm = p->dihedprm;
  const MD_Tors *dihed = p->dihed;
  const int32 dihed_len  = p->dihed_len;
  MD_Dvec *f_dihed = (r->f_dihed ? r->f_dihed : r->f);

  /* impropers */
  const MD_TorsPrm *imprprm = p->imprprm;
  const MD_Tors *impr = p->impr;
  const int32 impr_len  = p->impr_len;
  MD_Dvec *f_impr = (r->f_impr ? r->f_impr : r->f);

  /* total force */
  const int32 atom_len = p->atom_len;
  MD_Dvec *f = r->f;

  /* accumulate energy, list counter */
  double pe;
  int32 k;
  const int32 flags = p->flags;

  ASSERT(f != NULL);

  if (flags & FORCE_BOND) {
    /* compute bond springs */
    pe = 0.0;
    for (k = 0;  k < bond_len;  k++) {
      pe += compute_bond(f_bond, pos, bond + k, bondprm + bond[k].prm);
    }
    e->bond = pe;
    if (f_bond != f) {
      /* accumulate bond force into total */
      for (k = 0;  k < atom_len;  k++) {
        f[k].x += f_bond[k].x;
        f[k].y += f_bond[k].y;
        f[k].z += f_bond[k].z;
      }
    }
  }

  if (flags & FORCE_ANGLE) {
    /* compute bond angles */
    pe = 0.0;
    for (k = 0;  k < angle_len;  k++) {
      pe += compute_angle(f_angle, pos, angle + k, angleprm + angle[k].prm);
    }
    e->angle = pe;
    if (f_angle != f) {
      /* accumulate angle force into total */
      for (k = 0;  k < atom_len;  k++) {
        f[k].x += f_angle[k].x;
        f[k].y += f_angle[k].y;
        f[k].z += f_angle[k].z;
      }
    }
  }

  if (flags & FORCE_DIHED) {
    /* compute dihedral torsion angles */
    pe = 0.0;
    for (k = 0;  k < dihed_len;  k++) {
      pe += compute_tors(f_dihed, pos, dihed + k, dihedprm + dihed[k].prm);
    }
    e->dihed = pe;
    if (f_dihed != f) {
      /* accumulate dihedral force into total */
      for (k = 0;  k < atom_len;  k++) {
        f[k].x += f_dihed[k].x;
        f[k].y += f_dihed[k].y;
        f[k].z += f_dihed[k].z;
      }
    }
  }

  if (flags & FORCE_IMPR) {
    /* compute improper torsion angles */
    pe = 0.0;
    for (k = 0;  k < impr_len;  k++) {
      pe += compute_tors(f_impr, pos, impr + k, imprprm + impr[k].prm);
    }
    e->impr = pe;
    if (f_impr != f) {
      /* accumulate improper force into total */
      for (k = 0;  k < atom_len;  k++) {
        f[k].x += f_impr[k].x;
        f[k].y += f_impr[k].y;
        f[k].z += f_impr[k].z;
      }
    }
  }

  return 0;
}


/*
 * Evaluate one harmonic bond spring.
 *
 * With d = |pos[b] - pos[a]| - r0, the energy returned is k * d * d.
 * The matching force, directed along the bond, is subtracted from
 * f[a] and added to f[b], so the pair contribution sums to zero.
 *
 * The bond length is used as a divisor: two atoms at exactly the same
 * position make the force non-finite.
 */
double compute_bond(MD_Dvec *f, const MD_Dvec *pos,
    const MD_Bond *bond, const MD_BondPrm *prm)
{
  const int32 a = bond->atom[0];
  const int32 b = bond->atom[1];
  MD_Dvec *fa = f + a;
  MD_Dvec *fb = f + b;
  MD_Dvec sep, pull;
  double len, stretch, scale, pe;

  /* separation vector from atom a to atom b, and its length */
  sep.x = pos[b].x - pos[a].x;
  sep.y = pos[b].y - pos[a].y;
  sep.z = pos[b].z - pos[a].z;
  len = sqrt(sep.x * sep.x + sep.y * sep.y + sep.z * sep.z);

  /* displacement from rest length, force scale, and energy */
  stretch = len - prm->r0;
  scale = -2.0 * prm->k * stretch / len;
  pe = prm->k * stretch * stretch;

  pull.x = scale * sep.x;
  pull.y = scale * sep.y;
  pull.z = scale * sep.z;

  /* equal and opposite contributions on the two atoms */
  fa->x -= pull.x;
  fa->y -= pull.y;
  fa->z -= pull.z;
  fb->x += pull.x;
  fb->y += pull.y;
  fb->z += pull.z;
  return pe;
}


/*
 * Evaluate one harmonic bond-angle interaction together with its
 * optional Urey-Bradley 1-3 spring.
 *
 * Atoms are angle->atom[0..2] with atom[1] at the vertex.  With theta
 * the angle between the two bonds leaving the vertex, the angle energy
 * is k_theta * (theta - theta0)^2.  When k_ub is nonzero, a spring
 * k_ub * (|r13| - r_ub)^2 between atoms 0 and 2 is added as well.
 *
 * The forces on the three atoms are added into f; the total energy of
 * the interaction is returned.
 *
 * The bond lengths |r21|, |r23| and (for Urey-Bradley) |r13| are used
 * as divisors, so coincident atoms still give non-finite results.
 */
double compute_angle(MD_Dvec *f, const MD_Dvec *pos,
    const MD_Angle *angle, const MD_AnglePrm *prm)
{
  MD_Dvec r21, r23, f1, f2, f3, r13, f13;
  double inv_r21len, inv_r23len, cos_theta, sin_theta, delta_theta, coef;
  double energy, r13len, dist;
  const int32 i = angle->atom[0];
  const int32 j = angle->atom[1];
  const int32 k = angle->atom[2];

  r21.x = pos[i].x - pos[j].x;
  r21.y = pos[i].y - pos[j].y;
  r21.z = pos[i].z - pos[j].z;
  r23.x = pos[k].x - pos[j].x;
  r23.y = pos[k].y - pos[j].y;
  r23.z = pos[k].z - pos[j].z;
  inv_r21len = 1.0 / sqrt(r21.x * r21.x + r21.y * r21.y + r21.z * r21.z);
  inv_r23len = 1.0 / sqrt(r23.x * r23.x + r23.y * r23.y + r23.z * r23.z);
  cos_theta = (r21.x * r23.x + r21.y * r23.y + r21.z * r23.z)
              * inv_r21len * inv_r23len;
  /* cos(theta) should be in [-1,1] */
  /* however, we need to correct in case of roundoff error */
  if (cos_theta > 1.0)        cos_theta = 1.0;
  else if (cos_theta < -1.0)  cos_theta = -1.0;
  sin_theta = sqrt(1.0 - cos_theta * cos_theta);
  delta_theta = acos(cos_theta) - prm->theta0;
  if (sin_theta < 1.0e-6) {
    /*
     * The two bonds are (anti)parallel, so sin(theta) is zero or nearly
     * so and dividing by it would give inf or NaN forces.  Use the
     * limiting value of -2 k (theta - theta0) / sin(theta), which is
     * exact for theta0 equal to 0 or pi, and merely keeps the force
     * finite for any other theta0.
     */
    coef = (delta_theta < 0.0 ? 2.0 * prm->k_theta : -2.0 * prm->k_theta);
  }
  else {
    coef = -2.0 * prm->k_theta * delta_theta / sin_theta;
  }
  energy = prm->k_theta * delta_theta * delta_theta;
  /* forces on the two end atoms; each is perpendicular to its own bond */
  f1.x = coef * (cos_theta * r21.x * inv_r21len - r23.x * inv_r23len)
         * inv_r21len;
  f1.y = coef * (cos_theta * r21.y * inv_r21len - r23.y * inv_r23len)
         * inv_r21len;
  f1.z = coef * (cos_theta * r21.z * inv_r21len - r23.z * inv_r23len)
         * inv_r21len;
  f3.x = coef * (cos_theta * r23.x * inv_r23len - r21.x * inv_r21len)
         * inv_r23len;
  f3.y = coef * (cos_theta * r23.y * inv_r23len - r21.y * inv_r21len)
         * inv_r23len;
  f3.z = coef * (cos_theta * r23.z * inv_r23len - r21.z * inv_r21len)
         * inv_r23len;
  /* vertex atom takes the balancing force so the three sum to zero */
  f2.x = -(f1.x + f3.x);
  f2.y = -(f1.y + f3.y);
  f2.z = -(f1.z + f3.z);

  /* Urey-Bradley term affects only atoms 1 and 3 */
  if (prm->k_ub != 0.0) {
    r13.x = r23.x - r21.x;
    r13.y = r23.y - r21.y;
    r13.z = r23.z - r21.z;
    r13len = sqrt(r13.x * r13.x + r13.y * r13.y + r13.z * r13.z);
    dist = r13len - prm->r_ub;
    coef = -2.0 * prm->k_ub * dist / r13len;
    energy += prm->k_ub * dist * dist;
    f13.x = coef * r13.x;
    f13.y = coef * r13.y;
    f13.z = coef * r13.z;
    f1.x -= f13.x;
    f1.y -= f13.y;
    f1.z -= f13.z;
    f3.x += f13.x;
    f3.y += f13.y;
    f3.z += f13.z;
  }
  f[i].x += f1.x;
  f[i].y += f1.y;
  f[i].z += f1.z;
  f[j].x += f2.x;
  f[j].y += f2.y;
  f[j].z += f2.z;
  f[k].x += f3.x;
  f[k].y += f3.y;
  f[k].z += f3.z;
  return energy;
}


/*
 * Evaluate one torsion interaction (the caller uses this routine for
 * both dihedrals and impropers).
 *
 * The following code adapted from NAMD2 ComputeDihedrals.C
 *
 * The four atoms are tors->atom[0..3].  prm points to the first of
 * prm->mult consecutive parameter entries, all of which are applied to
 * the same torsion angle phi:
 *   n != 0:  energy term  k_tor * (1 + cos(n * phi + phi0))
 *   n == 0:  energy term  k_tor * (phi - phi0)^2  (harmonic form)
 * where phi0 is the entry's "phi" field.
 *
 * The forces on the four atoms are added into f and the summed energy
 * of all terms is returned.
 *
 * NOTE(review): rA, rB and rC are reciprocals of cross-product lengths,
 * so three collinear atoms (a zero-length A, B or C) divide by zero.
 * NOTE(review): M_PI is not part of ISO C; it must come from the
 * platform's <math.h> or a project header - confirm for strict builds.
 */
double compute_tors(MD_Dvec *f, const MD_Dvec *pos,
    const MD_Tors *tors, const MD_TorsPrm *prm)
{
  MD_Dvec r12, r23, r34, A, B, C, dcosdA, dcosdB, dsindC, dsindB, f1, f2, f3;
  double rA, rB, rC, cos_phi, sin_phi, phi, delta, diff, k, K, K1;
  double energy = 0.0;
  int32 j, mult, n, not_too_small;
  const int32 atom1 = tors->atom[0];
  const int32 atom2 = tors->atom[1];
  const int32 atom3 = tors->atom[2];
  const int32 atom4 = tors->atom[3];

  /* the three successive bond vectors along the chain of four atoms */
  r12.x = pos[atom1].x - pos[atom2].x;
  r12.y = pos[atom1].y - pos[atom2].y;
  r12.z = pos[atom1].z - pos[atom2].z;
  r23.x = pos[atom2].x - pos[atom3].x;
  r23.y = pos[atom2].y - pos[atom3].y;
  r23.z = pos[atom2].z - pos[atom3].z;
  r34.x = pos[atom3].x - pos[atom4].x;
  r34.y = pos[atom3].y - pos[atom4].y;
  r34.z = pos[atom3].z - pos[atom4].z;

  /* A = cross(r12, r23) */
  A.x = r12.y * r23.z - r12.z * r23.y;
  A.y = r12.z * r23.x - r12.x * r23.z;
  A.z = r12.x * r23.y - r12.y * r23.x;

  /* B = cross(r23, r34) */
  B.x = r23.y * r34.z - r23.z * r34.y;
  B.y = r23.z * r34.x - r23.x * r34.z;
  B.z = r23.x * r34.y - r23.y * r34.x;

  /* C = cross(r23, A) */
  C.x = r23.y * A.z - r23.z * A.y;
  C.y = r23.z * A.x - r23.x * A.z;
  C.z = r23.x * A.y - r23.y * A.x;

  /* reciprocal lengths of A, B and C */
  rA = 1.0 / sqrt(A.x * A.x + A.y * A.y + A.z * A.z);
  rB = 1.0 / sqrt(B.x * B.x + B.y * B.y + B.z * B.z);
  rC = 1.0 / sqrt(C.x * C.x + C.y * C.y + C.z * C.z);

  /* normalize B */
  B.x *= rB;
  B.y *= rB;
  B.z *= rB;

  /* B is already unit length, so only A and C still need scaling here */
  cos_phi = (A.x * B.x + A.y * B.y + A.z * B.z) * rA;
  sin_phi = (C.x * B.x + C.y * B.y + C.z * B.z) * rC;

  phi = -atan2(sin_phi, cos_phi);

  /*
   * Choose the force formulation once, for all multiplicity terms:
   * the cos-derivative form divides by sin_phi, so it is used only
   * when |sin_phi| is comfortably away from zero; otherwise the
   * sin-derivative form (which divides by cos_phi) is used.
   * Only the derivative pair belonging to the chosen form is set here,
   * and the loop below reads only that same pair.
   */
  not_too_small = (fabs(sin_phi) > 0.1);
  if (not_too_small) {
    /* normalize A */
    A.x *= rA;
    A.y *= rA;
    A.z *= rA;
    dcosdA.x = rA * (cos_phi * A.x - B.x);
    dcosdA.y = rA * (cos_phi * A.y - B.y);
    dcosdA.z = rA * (cos_phi * A.z - B.z);
    dcosdB.x = rB * (cos_phi * B.x - A.x);
    dcosdB.y = rB * (cos_phi * B.y - A.y);
    dcosdB.z = rB * (cos_phi * B.z - A.z);
  }
  else {
    /* normalize C */
    C.x *= rC;
    C.y *= rC;
    C.z *= rC;
    dsindC.x = rC * (sin_phi * C.x - B.x);
    dsindC.y = rC * (sin_phi * C.y - B.y);
    dsindC.z = rC * (sin_phi * C.z - B.z);
    dsindB.x = rB * (sin_phi * B.x - C.x);
    dsindB.y = rB * (sin_phi * B.y - C.y);
    dsindB.z = rB * (sin_phi * B.z - C.z);
  }

  /* clear f1, f2, f3 before accumulating forces */
  f1.x = f1.y = f1.z = 0.0;
  f2.x = f2.y = f2.z = 0.0;
  f3.x = f3.y = f3.z = 0.0;

  /* the first entry's mult gives the number of terms to sum */
  mult = prm->mult;
  for (j = 0;  j < mult;  j++) {
    /* debug check: entry j is expected to store the count mult - j */
    ASSERT(prm[j].mult + j == mult);
    k  =    prm[j].k_tor;
    delta = prm[j].phi;
    n  =    prm[j].n;
    if (n) {
      /* periodic term: K is the energy, K1 its derivative w.r.t. phi */
      K = k * (1.0 + cos(n * phi + delta));
      K1 = -n * k * sin(n * phi + delta);
    }
    else {
      /* harmonic term; diff is shifted by 2*pi once if outside [-pi,pi] */
      diff = phi - delta;
      if      (diff < -M_PI)  diff += 2.0 * M_PI;
      else if (diff >  M_PI)  diff -= 2.0 * M_PI;
      K = k * diff * diff;
      K1 = 2.0 * k * diff;
    }
    energy += K;

    /* forces */
    if (not_too_small) {
      /* cos-derivative form: safe because |sin_phi| > 0.1 here */
      K1 /= sin_phi;
      f1.x += K1 * (r23.y * dcosdA.z - r23.z * dcosdA.y);
      f1.y += K1 * (r23.z * dcosdA.x - r23.x * dcosdA.z);
      f1.z += K1 * (r23.x * dcosdA.y - r23.y * dcosdA.x);

      f3.x += K1 * (r23.z * dcosdB.y - r23.y * dcosdB.z);
      f3.y += K1 * (r23.x * dcosdB.z - r23.z * dcosdB.x);
      f3.z += K1 * (r23.y * dcosdB.x - r23.x * dcosdB.y);

      f2.x += K1 * (r12.z * dcosdA.y - r12.y * dcosdA.z
                  + r34.y * dcosdB.z - r34.z * dcosdB.y);
      f2.y += K1 * (r12.x * dcosdA.z - r12.z * dcosdA.x
                  + r34.z * dcosdB.x - r34.x * dcosdB.z);
      f2.z += K1 * (r12.y * dcosdA.x - r12.x * dcosdA.y
                  + r34.x * dcosdB.y - r34.y * dcosdB.x);
    }
    else {
      /* phi is too close to 0 or pi, use cos version to avoid 1/sin */
      K1 /= -cos_phi;
      f1.x += K1 * ((r23.y * r23.y + r23.z * r23.z) * dsindC.x
                   - r23.x * r23.y * dsindC.y
                   - r23.x * r23.z * dsindC.z);
      f1.y += K1 * ((r23.z * r23.z + r23.x * r23.x) * dsindC.y
                   - r23.y * r23.z * dsindC.z
                   - r23.y * r23.x * dsindC.x);
      f1.z += K1 * ((r23.x * r23.x + r23.y * r23.y) * dsindC.z
                   - r23.z * r23.x * dsindC.x
                   - r23.z * r23.y * dsindC.y);

      /* f3 += K1 * cross(dsindB, r23) */
      f3.x += K1 * (dsindB.y * r23.z - dsindB.z * r23.y);
      f3.y += K1 * (dsindB.z * r23.x - dsindB.x * r23.z);
      f3.z += K1 * (dsindB.x * r23.y - dsindB.y * r23.x);

      f2.x += K1 * (-(r23.y * r12.y + r23.z * r12.z) * dsindC.x
                   + (2.0 * r23.x * r12.y - r12.x * r23.y) * dsindC.y
                   + (2.0 * r23.x * r12.z - r12.x * r23.z) * dsindC.z
                   + dsindB.z * r34.y - dsindB.y * r34.z);
      f2.y += K1 * (-(r23.z * r12.z + r23.x * r12.x) * dsindC.y
                   + (2.0 * r23.y * r12.z - r12.y * r23.z) * dsindC.z
                   + (2.0 * r23.y * r12.x - r12.y * r23.x) * dsindC.x
                   + dsindB.x * r34.z - dsindB.z * r34.x);
      f2.z += K1 * (-(r23.x * r12.x + r23.y * r12.y) * dsindC.z
                   + (2.0 * r23.z * r12.x - r12.z * r23.x) * dsindC.x
                   + (2.0 * r23.z * r12.y - r12.z * r23.y) * dsindC.y
                   + dsindB.y * r34.x - dsindB.x * r34.y);
    }
  }  /* end loop over multiplicity */
  /*
   * Distribute the accumulated vectors over the four atoms:
   * f1, f2 - f1, f3 - f2 and -f3, which sum to zero.
   */
  f[atom1].x += f1.x;
  f[atom1].y += f1.y;
  f[atom1].z += f1.z;
  f[atom2].x += f2.x - f1.x;
  f[atom2].y += f2.y - f1.y;
  f[atom2].z += f2.z - f1.z;
  f[atom3].x += f3.x - f2.x;
  f[atom3].y += f3.y - f2.y;
  f[atom3].z += f3.z - f2.z;
  f[atom4].x += -f3.x;
  f[atom4].y += -f3.y;
  f[atom4].z += -f3.z;
  return energy;
}
