/*
 * Copyright (C) 2004-2006 by Wei Wang.  All rights reserved.
 */


#include <stdio.h>
#include <string.h>
#include <math.h>
#include <assert.h>
#include <stdlib.h>
#include "mdtypes.h"
#include "constant.h"
#include "utilities.h"
#include "rattle.h"

/*
 * rigid water molecule:
 *
 *                  O                O: Oxygen atom
 *                 / \               H: Hydrogen atom
 *             d1 /   \ d1           angle(H1OH2) = theta.
 *               /     \
 *             H1 ----- H2
 *                  d2
 */



/*
 * rattle_init: fill in the RATTLE parameters for rigid water molecules and
 * allocate the buffer that rattle_prepare() copies positions into.
 *
 *   sw        - structure to initialize; must not be NULL
 *   bondOH    - O-H bond length d1
 *   angle_HOH - H-O-H angle theta, in radians (it is passed to sin())
 *   massO     - oxygen mass
 *   massH     - hydrogen mass
 *   natoms    - number of atoms whose positions sw->old_pos must hold
 *
 * The H-H distance follows from the isosceles triangle:
 *   d2 = 2 * d1 * sin(theta / 2).
 * The buffer is owned by sw and released by rattle_destroy().  If it
 * cannot be allocated the program exits with EXIT_FAILURE.
 */
void rattle_init(struct Rattle_Water *sw, 
		 const MD_Double bondOH, const MD_Double angle_HOH, 
		 const MD_Double massO, const MD_Double massH,
		 const MD_Int natoms)
{
  assert(NULL != sw);
  sw->bondOH = bondOH;
  sw->angle_HOH = angle_HOH;
  sw->bondHH = 2.0 * sw->bondOH * sin(sw->angle_HOH * 0.5);
  sw->massO = massO;
  sw->massH = massH;
  sw->natoms = natoms;
  sw->old_pos = malloc((size_t) sw->natoms * sizeof *sw->old_pos);
  /* An assert() alone is compiled out under NDEBUG and would let a NULL
   * buffer reach memcpy() in rattle_prepare(), so fail loudly instead.
   * malloc(0) may legitimately return NULL, hence the natoms test. */
  if (NULL == sw->old_pos && sw->natoms > 0) {
    fprintf(stderr, "rattle_init: cannot allocate position buffer\n");
    exit(EXIT_FAILURE);
  }
  /* rattle1 stops iterating once the summed squared change of its three
   * Lagrange multipliers is at or below this value; 1e-14 is quite
   * accurate. */
  sw->errTol2 = 1e-14;
#if 0
  sw->bond_errTol = 1e-8;  /* tighter alternative; needed to avoid energy drift? */
#else
  printf("use less accurate rattle !\n");
  /* maximum relative error of the squared bond lengths accepted by
   * rattle1; is this accurate enough to avoid energy drift? */
  sw->bond_errTol = 1e-5;
#endif
  /* maximum |(v_j - v_i).(r_j - r_i)| accepted by rattle2 */
  sw->vel_errTol = 1e-6;
  /* cap on the number of rattle1 fixed-point iterations */
  sw->maxiter = 1000;

  printf("Use rattle method to constraint water molecule\n");
  printf("  bond error tolerance = %g, velocity error tolerance = %g\n",
	 sw->bond_errTol, sw->vel_errTol);
  printf("  control parameter: errTol2=%g, maxiter=%d\n",
	 sw->errTol2, sw->maxiter);
}

/*
 * rattle_destroy: release the position snapshot buffer and zero every
 * byte of *sw, so the structure keeps no reference to the freed memory.
 * free(NULL) is harmless, so an unallocated buffer needs no special case.
 */
void rattle_destroy(struct Rattle_Water *sw) 
{
  free(sw->old_pos);
  memset(sw, 0, sizeof *sw);
}

/*
 * rattle_prepare: copy the positions of all sw->natoms atoms into the
 * snapshot buffer sw->old_pos.  rattle1() later reads this snapshot as the
 * time-t geometry, so it should be taken before the positions are advanced.
 */
void rattle_prepare(struct Rattle_Water *sw, const MD_Dvec *pos)
{
  const size_t nbytes = sw->natoms * sizeof *sw->old_pos;

  memcpy(sw->old_pos, pos, nbytes);
}

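/* 
 * Use the RATTLE method to do constrained dynamics for one water molecule
 * (atoms O, H1, H2 stored consecutively).
 * input:  
 *        rold:  position array of the molecule at t
 *               (the snapshot saved in sw->old_pos by rattle_prepare())
 *        rnew:  unconstrained position array at t + delta_t
 * output:  
 *        rnew:  modified position array at t + delta_t,
 *               so that the bond-length constraints are satisfied
 *        vel:   velocities, corrected consistently with the position change
 * Be careful about periodic boundary conditions: bond vectors are formed
 * by coordinate subtraction (MD_pvec_substract), so the three atoms are
 * presumably expected to lie in the same periodic image -- confirm.
*/

/* Historical note: the author suspected a problem with abs() from math.h.
 * This file uses fabs() for floating-point absolute values, which is the
 * correct function for doubles (abs() takes an int). */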
/*
 * rattle1: position half of RATTLE for one rigid water molecule.
 *
 * The atoms are moved along the OLD bond vectors r01, r02, r12 (O-H1, O-H2,
 * H1-H2, taken from the snapshot sw->old_pos) so that the new squared bond
 * lengths equal d1^2, d1^2 and d2^2:
 *   O  += po * ( lmd01*r01 + lmd02*r02)
 *   H1 += ph * (-lmd01*r01 + lmd12*r12)
 *   H2 -= ph * ( lmd02*r02 + lmd12*r12)
 * with po = 2 dt^2 / mO and ph = 2 dt^2 / mH.  The same displacement
 * divided by dt is added to the velocities.
 *
 *   sw    - rattle parameters; reads the geometry, masses, tolerances,
 *           maxiter and the time-t snapshot sw->old_pos
 *   iatom - index of this molecule's O atom in sw->old_pos (order O, H1, H2)
 *   rnew  - the molecule's 3 positions at t + dt; corrected in place
 *   vel   - the molecule's 3 velocities; corrected in place
 *   dt    - time step
 *
 * Returns OK, or MD_FAIL (after printing the positions to stderr) when a
 * squared bond length differs from its target by more than
 * sw->bond_errTol relative.  rnew and vel are modified in either case.
 */
MD_Errcode rattle1(struct Rattle_Water *sw, const MD_Int iatom,
		   MD_Dvec *rnew, MD_Dvec *vel,  const MD_Double dt)
{
  /* rigid water model parameters. d1 = d(O, H1) = d(O, H2),
   * d2 = d(H1, H2).  theta = angle(H1-O-H2); phi = (pi - theta)/2 is the
   * base angle of the isosceles O-H1-H2 triangle at each H. */
  const MD_Dvec *rold = sw->old_pos + iatom;
  const MD_Double errTol = sw->bond_errTol; /* relative tol. on squared bond lengths */
  const MD_Double errTol2 =  sw->errTol2;   /* tol. on squared change of the multipliers */
  const MD_Double d1 = sw->bondOH;   /* presumably angstrom */ 
  const MD_Double theta = sw->angle_HOH;  /* radians, e.g. 109.5 degree */
  const MD_Double phi = (Pi - theta) * 0.5; 
  const MD_Double d2 = sw->bondHH;
  const MD_Double d1sqr = d1 * d1;
  const MD_Double d2sqr = d2 * d2;
  /* ideal-geometry magnitudes of r01.r12 / r02.r12 and of r01.r02, used in
   * the quadratic terms of the right-hand side */
  const MD_Double d1d2cosphi = d1 * d2 * cos(phi);
  const MD_Double d1d1costheta = d1 * d1 * cos(theta);
  const MD_Double mo = sw->massO;
  const MD_Double mh = sw->massH;
  /* displacement per unit multiplier: 2 dt^2 / mass */
  const MD_Double po = 2.0 * dt * dt / mo;
  const MD_Double ph = 2.0 * dt * dt / mh;
  const MD_Double poph = po+ph;
  const MD_Double posqr = po*po;
  const MD_Double phsqr = ph*ph;
  const MD_Double pophsqr = poph * poph;
  const MD_Int maxiter = sw->maxiter;

  /* old (time t) positions: O, H1, H2 */
  const MD_Dvec *r0 = rold;
  const MD_Dvec *r1 = rold + 1;
  const MD_Dvec *r2 = rold + 2;
  /* new (time t + dt) positions, corrected in place */
  MD_Dvec *ro = rnew;
  MD_Dvec *rh1 = rnew + 1;
  MD_Dvec *rh2 = rnew + 2;
  MD_Dvec roh1, roh2, rh1h2;   /* new bond vectors */
  MD_Dvec r01, r02, r12;       /* old bond vectors */
  MD_Dvec *velo, *velh1, *velh2;
  MD_Dvec deltav;

  MD_Double c11, c12, c13, c21, c22, c23, c31, c32, c33;
  MD_Double lmd01 = 0.0, lmd02 = 0.0, lmd12 = 0.0;
  MD_Double lmd01o, lmd02o, lmd12o;
  MD_Double rhs1, rhs2, rhs3, inv_det;
  MD_Double err01, err02, err12;
  MD_Double dlmd01, dlmd02, dlmd12;
  MD_Double lmd01sqr, lmd02sqr, lmd12sqr, tmp1, tmp2;
  MD_Int iter;
  MD_Double doh1sqr, doh2sqr, dh1h2sqr;

#ifdef DEBUG_RATTLE
  printf("po=%20.15f, ph=%20.15f\n", po, ph);
  printf("old:\n");
  printf(" O: (%f,%f, %f)\n", r0->x, r0->y, r0->z);
  printf("H1: (%f,%f, %f)\n", r1->x, r1->y, r1->z);
  printf("H2: (%f,%f, %f)\n", r2->x, r2->y, r2->z);
  printf("before:\n");
  printf(" O: (%f,%f, %f)\n", ro->x, ro->y, ro->z);
  printf("H1: (%f,%f, %f)\n", rh1->x, rh1->y, rh1->z);
  printf("H2: (%f,%f, %f)\n", rh2->x, rh2->y, rh2->z);
#endif

  /* new (unconstrained) bond vectors */
  MD_pvec_substract(ro, rh1, &roh1);
  MD_pvec_substract(ro, rh2, &roh2);
  MD_pvec_substract(rh1, rh2, &rh1h2);

  /* old bond vectors: the directions along which corrections are applied */
  MD_pvec_substract(r0, r1, &r01);
  MD_pvec_substract(r0, r2, &r02);
  MD_pvec_substract(r1, r2, &r12);

  /* Each corrected bond i (O-H1, O-H2, H1-H2) is b_i + sum_j lmd_j a_ij,
   * where a_ij is its change per unit multiplier j (see the corrections
   * in the header).  Expanding |.|^2 = target_i, the terms linear in the
   * multipliers are 2 * sum_j c_ij lmd_j with c_ij = b_i . a_ij; the
   * quadratic terms are moved to the right-hand side below. */
  c11 =   poph * MD_vec_dot(roh1,  r01);
  c12 =     po * MD_vec_dot(roh1,  r02);
  c13 =   - ph * MD_vec_dot(roh1,  r12);
  c21 =     po * MD_vec_dot(roh2,  r01);
  c22 =   poph * MD_vec_dot(roh2,  r02);
  c23 =     ph * MD_vec_dot(roh2,  r12);
  c31 =   - ph * MD_vec_dot(rh1h2, r01);
  c32 =   + ph * MD_vec_dot(rh1h2, r02);
  c33 = 2.0*ph * MD_vec_dot(rh1h2, r12);

  /* 1/det(c_ij) for Cramer's rule.  NOTE(review): there is no guard
   * against a singular (det == 0) matrix. */
  inv_det = 1.0 /  (c11*c22*c33 + c21*c32*c13 + c31*c12*c23
		  - c13*c22*c31 - c23*c32*c11 - c33*c12*c21);

#ifdef DEBUG_RATTLE
  {
    MD_Double t1, t2, t3, t, t_inv;
    MD_Double d11, d12, d13, d21, d22, d23, d31, d32, d33;
    MD_Double cond;
    /* max absolute row sum of {c_ij}, i.e. its infinity-norm. */
    t1 = fabs(c11) + fabs(c12) + fabs(c13);
    t2 = fabs(c21) + fabs(c22) + fabs(c23);
    t3 = fabs(c31) + fabs(c32) + fabs(c33);
    t = t1 > t2 ? t1:t2;
    t = t > t3 ? t:t3;
    /* d_ij = cofactor_ij / det, which is the TRANSPOSE of {c_ij}^{-1} */
    d11 =  (c22*c33 - c23*c32) * inv_det;
    d12 = -(c21*c33 - c23*c31) * inv_det;
    d13 =  (c21*c32 - c22*c31) * inv_det;
    d21 = -(c12*c33 - c13*c32) * inv_det;
    d22 =  (c11*c33 - c13*c31) * inv_det;
    d23 = -(c11*c32 - c12*c31) * inv_det;
    d31 =  (c12*c23 - c22*c13) * inv_det;
    d32 = -(c11*c23 - c13*c21) * inv_det;
    d33 =  (c11*c22 - c12*c21) * inv_det;
    /* max row sum of {d_ij} = max column sum of {c_ij}^{-1} (its 1-norm) */
    t1 = fabs(d11) + fabs(d12) + fabs(d13);
    t2 = fabs(d21) + fabs(d22) + fabs(d23);
    t3 = fabs(d31) + fabs(d32) + fabs(d33);
    t_inv = t1 > t2 ? t1:t2;
    t_inv = t_inv > t3 ? t_inv:t3;
    /* NOTE(review): mixed-norm estimate ||C||_inf * ||C^-1||_1; debug only */
    cond = t * t_inv;
    printf("inv_det=%f\n", inv_det);
    printf("condition number is: %f\n", cond);
  }
#endif

  /* residuals of the squared bond lengths before any correction */
  err01 = d1sqr - MD_vec_dot(roh1,  roh1);
  err02 = d1sqr - MD_vec_dot(roh2,  roh2);
  err12 = d2sqr - MD_vec_dot(rh1h2, rh1h2);
  iter = 0;
  do {   /* fixed-point iteration: re-evaluate the quadratic terms with the
          * current multipliers, then re-solve the linear system */
    lmd01o = lmd01;
    lmd02o = lmd02;
    lmd12o = lmd12;
    /* compute right hand side: 0.5*(err_i - quadratic terms).  The
     * quadratic terms use the ideal geometry (|r01| = |r02| = d1,
     * |r12| = d2, angles theta and phi), i.e. they assume rold already
     * satisfies the constraints. */
    lmd01sqr = lmd01*lmd01;
    lmd02sqr = lmd02*lmd02;
    lmd12sqr = lmd12*lmd12;
    tmp1 = lmd01 * lmd02 * d1d1costheta;
    tmp2 = lmd12 * d1d2cosphi;

    rhs1 = 0.5 * ( err01 - (pophsqr*lmd01sqr + posqr*lmd02sqr)*d1sqr
	         - phsqr*lmd12sqr*d2sqr )
         - po*poph*tmp1  - (poph*lmd01 - po*lmd02)*ph*tmp2;
    rhs2 = 0.5 * ( err02 - (posqr*lmd01sqr + pophsqr*lmd02sqr)*d1sqr
	         - phsqr*lmd12sqr*d2sqr )
	 - po*poph*tmp1 + (po*lmd01 - poph*lmd02)*ph*tmp2;
    rhs3 = 0.5 * ( err12 - phsqr * ((lmd01sqr + lmd02sqr)*d1sqr
	       		     + 4.0*lmd12sqr*d2sqr) )
	 + phsqr*(tmp1 - 2.0*(lmd01+lmd02)*tmp2);
    /* solve c_ij * lmd_j = rhs_i by Cramer's rule */
    lmd01 = (  rhs1*c22*c33 + rhs2*c32*c13 + rhs3*c12*c23
	     - c13*c22*rhs3 - c23*c32*rhs1 - c33*c12*rhs2) 
             * inv_det;
    lmd02 = (  c11*rhs2*c33 + c21*rhs3*c13 + c31*rhs1*c23
	     - c13*rhs2*c31 - c23*rhs3*c11 - c33*rhs1*c21) 
             * inv_det;
    lmd12 = (  c11*c22*rhs3 + c21*c32*rhs1 + c31*c12*rhs2
	     - rhs1*c22*c31 - rhs2*c32*c11 - rhs3*c12*c21) 
             * inv_det;
    dlmd01 = lmd01 - lmd01o;
    dlmd02 = lmd02 - lmd02o;
    dlmd12 = lmd12 - lmd12o;
    iter ++;
  } while (dlmd01*dlmd01+dlmd02*dlmd02+dlmd12*dlmd12 > errTol2 &&
	   iter < maxiter);
  /* NOTE: leaving the loop because iter reached maxiter is not reported
   * here; the bond-length check at the end is the only failure signal. */

#ifdef DEBUG_RATTLE
  printf("rattle: iter=%d, dlmd01=%g, dlmd02=%g, dlmd12=%g\n", iter, dlmd01, 
	 dlmd02, dlmd12);
#endif

  velo = vel;
  velh1 = vel+1;
  velh2 = vel+2;

  /* Apply the corrections atom by atom: deltav (= displacement / dt) is
   * added to the velocity, and MD_vec_mul_add presumably adds deltav*dt
   * (the displacement) to the position. */
  deltav.x = po * (lmd01 * r01.x + lmd02 * r02.x) / dt;
  deltav.y = po * (lmd01 * r01.y + lmd02 * r02.y) / dt;
  deltav.z = po * (lmd01 * r01.z + lmd02 * r02.z) / dt;
  MD_pvec_add(velo, &deltav, velo);
  MD_vec_mul_add(*ro, deltav, dt);

  deltav.x = ph * (-lmd01 * r01.x + lmd12 * r12.x) / dt;
  deltav.y = ph * (-lmd01 * r01.y + lmd12 * r12.y) / dt;
  deltav.z = ph * (-lmd01 * r01.z + lmd12 * r12.z) / dt;
  MD_pvec_add(velh1, &deltav, velh1);
  MD_vec_mul_add(*rh1, deltav, dt); 

  deltav.x = -ph * (lmd02 * r02.x + lmd12 * r12.x) / dt;
  deltav.y = -ph * (lmd02 * r02.y + lmd12 * r12.y) / dt;
  deltav.z = -ph * (lmd02 * r02.z + lmd12 * r12.z) / dt;
  MD_pvec_add(velh2, &deltav, velh2);
  MD_vec_mul_add(*rh2, deltav, dt); 

#ifdef DEBUG_RATTLE
  printf("after:\n");
  printf(" O: (%f,%f, %f)\n", ro->x, ro->y, ro->z);
  printf("H1: (%f,%f, %f)\n", rh1->x, rh1->y, rh1->z);
  printf("H2: (%f,%f, %f)\n", rh2->x, rh2->y, rh2->z);
#endif

  /* check the constraint: relative error of each squared bond length. */
  MD_pvec_substract(ro, rh1, &roh1);
  MD_pvec_substract(ro, rh2, &roh2);
  MD_pvec_substract(rh1, rh2, &rh1h2);
  doh1sqr = MD_vec_dot(roh1, roh1);
  doh2sqr = MD_vec_dot(roh2, roh2);
  dh1h2sqr = MD_vec_dot(rh1h2, rh1h2);
  if (fabs(doh1sqr - d1sqr)/d1sqr > errTol ||
      fabs(doh2sqr - d1sqr)/d1sqr > errTol ||
      fabs(dh1h2sqr - d2sqr)/d2sqr > errTol) {
    fprintf(stderr, " O: (%f,%f, %f)\n", ro->x, ro->y, ro->z);
    fprintf(stderr, "H1: (%f,%f, %f)\n", rh1->x, rh1->y, rh1->z);
    fprintf(stderr, "H2: (%f,%f, %f)\n", rh2->x, rh2->y, rh2->z);
    fprintf(stderr, "*** rattle fails: |O-H1|^2=%f, |O-H2|^2 = %f,"
	    "|H1-H2|^2 = %f\n", doh1sqr, doh2sqr, dh1h2sqr);
    return MD_FAIL;
  }

  return OK;
}



/*
 * rattle2: velocity half of the RATTLE method for one rigid water molecule.
 *
 * Corrects the velocities so that every relative velocity is orthogonal
 * to its bond:
 *   (v1 - v0).(r1 - r0) = (v2 - v0).(r2 - r0) = (v2 - v1).(r2 - r1) = 0
 * where index 0 is O and 1, 2 are H1, H2.  The corrections are
 *   v0 += dt/mO * ( u01*r10 + u02*r20)
 *   v1 += dt/mH * (-u01*r10 + u12*r21)
 *   v2 += dt/mH * (-u02*r20 - u12*r21)
 * and the multipliers u01, u02, u12 solve the symmetric 3x3 system C u = b
 * by Cramer's rule.
 *
 *   rw - rattle parameters (massO, massH, vel_errTol are read)
 *   r  - the molecule's 3 positions (read only)
 *   v  - the molecule's 3 velocities; corrected in place
 *   dt - time step; it cancels between b and the correction but must be
 *        nonzero because 1/dt is formed
 *
 * Returns OK, or FAILURE (after printing the residual dot products to
 * stderr) if any |(v_j - v_i).(r_j - r_i)| still exceeds rw->vel_errTol.
 * v is modified in either case.
 */
MD_Errcode rattle2(struct Rattle_Water *rw, 
	           const MD_Dvec *r, MD_Dvec *v, const MD_Double dt)
{
  const MD_Double errTol = rw->vel_errTol; /* tolerance on (v_j-v_i).(r_j-r_i) */
  const MD_Double inv_mo = 1.0 / rw->massO;
  const MD_Double inv_mh = 1.0 / rw->massH;
  const MD_Double inv_dt = 1.0 / dt;
  MD_Dvec v10t, v20t, v21t;
  MD_Dvec r10, r20, r21;
  MD_Double c11, c12, c13, c21, c22, c23, c31, c32, c33;
  MD_Double inv_det;
  MD_Double b1, b2, b3;
  MD_Double u01, u02, u12;

  /* relative velocities and bond vectors */
  MD_vec_substract(v[1], v[0], v10t);
  MD_vec_substract(v[2], v[0], v20t);
  MD_vec_substract(v[2], v[1], v21t);
  MD_vec_substract(r[1], r[0], r10);
  MD_vec_substract(r[2], r[0], r20);
  MD_vec_substract(r[2], r[1], r21);

  /* Must use the real bond vectors, not the desired theoretical values.
   * In particular c22 is built from |r2 - r0|^2 and must not be replaced
   * by c11 (built from |r1 - r0|^2): the two O-H lengths are only
   * approximately equal after rattle1. */
  c11 = (inv_mo + inv_mh) * MD_vec_dot(r10, r10);
  c22 = (inv_mo + inv_mh) * MD_vec_dot(r20, r20);
  c33 = (inv_mh + inv_mh) * MD_vec_dot(r21, r21);
  c12 =  inv_mo * MD_vec_dot(r10, r20);
  c13 = -inv_mh * MD_vec_dot(r10, r21);
  c23 =  inv_mh * MD_vec_dot(r20, r21);
  /* C is symmetric */
  c21 = c12;
  c31 = c13;
  c32 = c23;
  
  inv_det = 1.0 /  (c11*c22*c33 + c21*c32*c13 + c31*c12*c23
                  - c13*c22*c31 - c23*c32*c11 - c33*c12*c21);

  /* right-hand side: current bond-parallel relative velocities / dt */
  b1 = MD_vec_dot(v10t, r10) * inv_dt;
  b2 = MD_vec_dot(v20t, r20) * inv_dt;
  b3 = MD_vec_dot(v21t, r21) * inv_dt;

  /* Cramer's rule */
  u01 = (b1*c22*c33 + b2*c32*c13 + b3*c23*c12
      -  c13*c22*b3 - c23*c32*b1 - c33*b2*c12) * inv_det;
  u02 = (c11*b2*c33 + c21*b3*c13 + c31*c23*b1
      -  c13*b2*c31 - c23*b3*c11 - c33*c21*b1) * inv_det;
  u12 = (c11*c22*b3 + c21*c32*b1 + c31*b2*c12
      -  b1*c22*c31 - b2*c32*c11 - b3*c21*c12) * inv_det;

  v[0].x += dt*inv_mo * ( u01*r10.x + u02*r20.x);
  v[0].y += dt*inv_mo * ( u01*r10.y + u02*r20.y);
  v[0].z += dt*inv_mo * ( u01*r10.z + u02*r20.z);
  v[1].x += dt*inv_mh * (-u01*r10.x + u12*r21.x);
  v[1].y += dt*inv_mh * (-u01*r10.y + u12*r21.y);
  v[1].z += dt*inv_mh * (-u01*r10.z + u12*r21.z);
  v[2].x += dt*inv_mh * (-u02*r20.x - u12*r21.x);
  v[2].y += dt*inv_mh * (-u02*r20.y - u12*r21.y);
  v[2].z += dt*inv_mh * (-u02*r20.z - u12*r21.z);

  /* check orthogonality */
  MD_vec_substract(v[1], v[0], v10t);
  MD_vec_substract(v[2], v[0], v20t);
  MD_vec_substract(v[2], v[1], v21t);
  if (fabs(MD_vec_dot(v10t, r10)) > errTol ||
      fabs(MD_vec_dot(v20t, r20)) > errTol ||
      fabs(MD_vec_dot(v21t, r21)) > errTol) {
    fprintf(stderr, "<v1-v0, r1-r0> = %g\n", MD_vec_dot(v10t, r10));
    fprintf(stderr, "<v2-v0, r2-r0> = %g\n", MD_vec_dot(v20t, r20));
    fprintf(stderr, "<v2-v1, r2-r1> = %g\n", MD_vec_dot(v21t, r21));    
    fprintf(stderr, "rattle2 failed\n");
    return FAILURE; 
  }

  return OK;
}

