/*
 * Copyright (C) 2004-2006 by Wei Wang.  All rights reserved.
 */

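/*
 * SETTLE constraint routines for rigid three-site (O, H, H) water:
 * analytic position constraints (settle1) and bond-velocity constraints
 * (settle2), plus setup/teardown of the reference-position buffer.
 */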

#include "settle.h"
#include "utilities.h"
#include "helper.h"
#include <string.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#define EPSILON 1e-8

/*
 * settle_init: fill in the SETTLE water parameters and allocate the
 * reference-position buffer.
 *
 *   sw      - structure to initialise.
 *   mO, mH  - oxygen and hydrogen masses (mH must be nonzero).
 *   hhdist  - rigid H-H distance.
 *   ohdist  - rigid O-H distance; must exceed hhdist/2 for a real triangle.
 *   natoms  - number of entries allocated for sw->old_pos.
 *
 * Derived geometry of the canonical water (centre of mass at the origin):
 * O at (0, ra), hydrogens at (-rc, -rb) and (rc, -rb), with
 * rc = hhdist/2, ra + rb = sqrt(ohdist^2 - rc^2) and mO*ra = 2*mH*rb.
 *
 * Returns OK, or FAILURE if the buffer cannot be allocated or the distances
 * do not describe a non-degenerate triangle.  On failure sw->old_pos is
 * either NULL or a valid allocation, so settle_destroy is still safe.
 */
MD_Errcode settle_init(struct Settle_Water* sw, MD_Double mO, MD_Double mH,
		       MD_Double hhdist, MD_Double ohdist, MD_Int natoms)
{
  const MD_Double t1 = 0.5*mO/mH;                    /* = rb / ra */
  const MD_Double rc = 0.5*hhdist;
  const MD_Double height2 = ohdist*ohdist - rc*rc;   /* = (ra + rb)^2 */

  sw->mO = mO;
  sw->mH = mH;
  sw->natoms = natoms;
  sw->rc = rc;
  sw->ra = 0.0;
  sw->rb = 0.0;

  sw->old_pos = my_calloc((size_t)natoms, sizeof(MD_Dvec), "old_pos");
  /* calloc of zero elements may legitimately return NULL */
  if (natoms > 0 && NULL == sw->old_pos) {
    fprintf(stderr, "settle_init: failed to allocate old_pos\n");
    return FAILURE;
  }

  /* ohdist <= hhdist/2 would make the sqrt NaN (or ra == 0, which settle1
     divides by).  Written as !(x > 0) so NaN inputs are rejected too. */
  if (!(height2 > 0.0)) {
    fprintf(stderr, "settle_init: O-H distance must exceed half the H-H distance\n");
    return FAILURE;
  }
  sw->ra = sqrt(height2)/(1.0+t1);
  sw->rb = t1 * sw->ra;

  printf("use SETTLE method\n");

  return OK;
}


MD_Errcode settle_destroy(struct Settle_Water* sw)
{
  free(sw->old_pos);
  memset(sw, 0, sizeof(struct Settle_Water));
  return OK;
}


/*
 * settle1: analytic SETTLE position constraint for one rigid water.
 *
 * rnew[0..2] holds the unconstrained new positions of the oxygen (rnew[0],
 * weighted by mO) and the two hydrogens (rnew[1], rnew[2], weighted by mH).
 * They are replaced in place by positions that:
 *   - keep the same centre of mass d0;
 *   - have the rigid geometry encoded by ra, rb, rc;
 *   - are obtained by rotating the canonical water about d0.  The canonical
 *     water has O at (0, ra, 0) and H at (-rc, -rb, 0) and (rc, -rb, 0).
 *
 * The rotation is expressed in a frame built from the reference molecular
 * plane, ref[k] = sw->old_pos[iatom + k] (k = 0: O, 1/2: H), which
 * settle_prepare copies in.  It is composed of:
 *   - out-of-plane angles phi and psi, from the z-heights of the new atoms;
 *   - an in-plane angle theta.
 * See Miyamoto & Kollman, J. Comput. Chem. 13 (1992) 952.
 *
 * vel[0..2] is overwritten with (rnew[k] - ref[k]) / dt.  There is no guard
 * for dt == 0.
 *
 * Always returns OK.  Out-of-range sines are caught only by ASSERT.
 * NOTE(review): if ASSERT is compiled out, they would produce NaN via sqrt;
 * confirm the ASSERT definition.
 */
MD_Errcode settle1(const struct Settle_Water* sw, const MD_Int iatom,
		   MD_Dvec rnew[3], MD_Dvec vel[3], const MD_Double dt)
{
  const MD_Double mO = sw->mO;
  const MD_Double mH = sw->mH;
  const MD_Double ra = sw->ra;
  const MD_Double rb = sw->rb;
  const MD_Double rc = sw->rc;
  /* reference O, H, H positions of this water (snapshot from settle_prepare) */
  const MD_Dvec *ref = sw->old_pos + iatom;
  MD_Dvec b0, c0, d0, a1, b1, c1, a2, b2, c2, a3, b3, c3;
  MD_Dvec n1, n2, n0, m1, m2, m0;
  MD_Dvec v;
  MD_Double sinphi, cosphi, sinpsi, cospsi;
  MD_Double rbsphi, rbcphi;
  MD_Double alpha, beta, gamma;
  MD_Double sintheta, costheta, a2b2;
  MD_Double tmp, tmp1, tmp2;

  /* vectors in the plane of the original positions: the two H-O bond
     vectors of the reference water, taken relative to ref[0] */
  MD_vec_substract(ref[1], ref[0], b0);
  MD_vec_substract(ref[2], ref[0], c0);

  /* d0: mass-weighted centre of the unconstrained new positions; the
     constrained positions are rebuilt around this same point */
  tmp = 1.0 / (mO+mH+mH);
  d0.x = (rnew[0].x*mO + rnew[1].x*mH + rnew[2].x*mH) * tmp;
  d0.y = (rnew[0].y*mO + rnew[1].y*mH + rnew[2].y*mH) * tmp;
  d0.z = (rnew[0].z*mO + rnew[1].z*mH + rnew[2].z*mH) * tmp;
  
  /* new positions relative to the centre of mass */
  MD_vec_substract(rnew[0], d0, a1);
  MD_vec_substract(rnew[1], d0, b1);
  MD_vec_substract(rnew[2], d0, c1);

  /* Vectors describing transformation from original coordinate system to
     the 'primed' coordinate system as in the diagram.
     n0 is normal to the reference plane.  n1 is perpendicular to both n0
     and the new oxygen offset a1.  n2 completes the frame.  All three are
     normalised to unit length below. */
  MD_vec_cross(b0, c0, n0);
  MD_vec_cross(a1, n0, n1);
  MD_vec_cross(n0, n1, n2);
  tmp = 1.0/MD_vecLen(n0); MD_vec_mul(n0, tmp, n0);
  tmp = 1.0/MD_vecLen(n1); MD_vec_mul(n1, tmp, n1);
  tmp = 1.0/MD_vecLen(n2); MD_vec_mul(n2, tmp, n2);

  /* Rewrite b0, c0, a1, b1, c1 in the primed frame:
     (x', y', z') = (n1 . u, n2 . u, n0 . u).
     In this frame the reference plane is z' = 0. */
  v.x = MD_vec_dot(n1,b0); 
  v.y = MD_vec_dot(n2,b0);
  v.z = MD_vec_dot(n0,b0);
  b0.x=v.x; b0.y=v.y; b0.z=v.z;
  v.x = MD_vec_dot(n1,c0);
  v.y = MD_vec_dot(n2,c0);
  v.z = MD_vec_dot(n0,c0);
  c0.x=v.x; c0.y=v.y; c0.z=v.z;
  v.x = MD_vec_dot(n1,a1);
  v.y = MD_vec_dot(n2,a1);
  v.z = MD_vec_dot(n0,a1);
  a1.x=v.x; a1.y=v.y; a1.z=v.z;
  v.x = MD_vec_dot(n1,b1);
  v.y = MD_vec_dot(n2,b1);
  v.z = MD_vec_dot(n0,b1);
  b1.x=v.x; b1.y=v.y; b1.z=v.z;
  v.x = MD_vec_dot(n1,c1);
  v.y = MD_vec_dot(n2,c1);
  v.z = MD_vec_dot(n0,c1);
  c1.x=v.x; c1.y=v.y; c1.z=v.z;

  /* now we can compute positions of canonical water.
     phi: from the oxygen's height above the reference plane.
     psi: from the height difference between the two hydrogens. */
  sinphi = a1.z/ra;
  tmp = 1.0-sinphi*sinphi;
  ASSERT(tmp >= 0.0);
  cosphi = sqrt(tmp);
  sinpsi = (b1.z - c1.z)/(2.0*rc*cosphi);
  tmp = 1.0-sinpsi*sinpsi;
  ASSERT(tmp >= 0.0);
  cospsi = sqrt(tmp);

  /* canonical water tilted by phi and psi (a2: O, b2/c2: H);
     with phi = psi = 0 this is O (0, ra, 0), H (-rc, -rb, 0), (rc, -rb, 0) */
  rbcphi = -rb*cosphi;
  rbsphi = -rb*sinphi;
  tmp  = rc*sinpsi;
  tmp1 = tmp*sinphi;
  tmp2 = tmp*cosphi;
  a2.x=0.0;        a2.y=ra*cosphi;   a2.z=ra*sinphi;
  b2.x=-rc*cospsi; b2.y=rbcphi-tmp1; b2.z=rbsphi+tmp2; 
  c2.x= rc*cospsi; c2.y=rbcphi+tmp1; c2.z=rbsphi-tmp2; 

  /* there are no a0 terms because we've already subtracted the term off 
     when we first defined b0 and c0.
     theta (in-plane rotation) solves
       alpha*sin(theta) + beta*cos(theta) = gamma;
     the root with "- beta*sqrt(...)" is taken below. */
  alpha = b2.x*(b0.x - c0.x) + b0.y*b2.y + c0.y*c2.y;
  beta  = b2.x*(c0.y - b0.y) + b0.x*b2.y + c0.x*c2.y;
  gamma = b0.x*b1.y - b1.x*b0.y + c0.x*c1.y - c1.x*c0.y;
 
  a2b2 = alpha*alpha + beta*beta;
  sintheta = (alpha*gamma - beta*sqrt(a2b2 - gamma*gamma))/a2b2;
  /* positive root: assumes |theta| < 90 degrees per step (presumably true
     for a small dt; NOTE(review): confirm) */
  costheta = sqrt(1.0 - sintheta*sintheta);
  
  /* rotate the x'/y' components by theta; z' components come from the
     unconstrained positions.  c3 uses -b2.x, which equals c2.x. */
  a3.x =-a2.y*sintheta;
  a3.y = a2.y*costheta;
  a3.z = a1.z;
  b3.x = b2.x*costheta - b2.y*sintheta;
  b3.y = b2.x*sintheta + b2.y*costheta;
  b3.z = b1.z;
  c3.x =-b2.x*costheta - c2.y*sintheta;
  c3.y =-b2.x*sintheta + c2.y*costheta;
  c3.z = c1.z;

  /* undo the transformation; generate new normal vectors from the transpose.
     n0, n1, n2 are orthonormal, so the transpose is the inverse rotation. */
  m1.x=n1.x; m1.y=n2.x; m1.z=n0.x;
  m2.x=n1.y; m2.y=n2.y; m2.z=n0.y;
  m0.x=n1.z; m0.y=n2.z; m0.z=n0.z;

  /* back to the original frame, re-adding the centre of mass */
  rnew[0].x = MD_vec_dot(a3, m1) + d0.x;
  rnew[0].y = MD_vec_dot(a3, m2) + d0.y;
  rnew[0].z = MD_vec_dot(a3, m0) + d0.z;
  rnew[1].x = MD_vec_dot(b3, m1) + d0.x;
  rnew[1].y = MD_vec_dot(b3, m2) + d0.y;
  rnew[1].z = MD_vec_dot(b3, m0) + d0.z;
  rnew[2].x = MD_vec_dot(c3, m1) + d0.x;
  rnew[2].y = MD_vec_dot(c3, m2) + d0.y;
  rnew[2].z = MD_vec_dot(c3, m0) + d0.z;

#ifdef DEBUG_SETTLE
  /* verify rigid geometry: |O-H|^2 = (ra+rb)^2 + rc^2 and |H-H|^2 = (2 rc)^2 */
  tmp = sqrt((ra+rb)*(ra+rb) + rc*rc);
  MD_vec_substract(rnew[0], rnew[1], v);
  ASSERT(fabs(MD_vec_dot(v,v) - tmp*tmp) < EPSILON);  
  MD_vec_substract(rnew[0], rnew[2], v);
  ASSERT(fabs(MD_vec_dot(v,v) - tmp*tmp) < EPSILON);
  MD_vec_substract(rnew[2], rnew[1], v);
  ASSERT(fabs(MD_vec_dot(v,v) - 4.0*rc*rc) < EPSILON);
#endif
  
  /* velocity = constrained displacement from the reference positions / dt.
     NOTE(review): there is no guard for dt == 0; confirm that callers never
     pass it. */
  tmp = 1.0/dt;
  vel[0].x = (rnew[0].x - ref[0].x) * tmp;
  vel[0].y = (rnew[0].y - ref[0].y) * tmp;
  vel[0].z = (rnew[0].z - ref[0].z) * tmp;
  vel[1].x = (rnew[1].x - ref[1].x) * tmp;
  vel[1].y = (rnew[1].y - ref[1].y) * tmp;
  vel[1].z = (rnew[1].z - ref[1].z) * tmp;
  vel[2].x = (rnew[2].x - ref[2].x) * tmp;
  vel[2].y = (rnew[2].y - ref[2].y) * tmp;
  vel[2].z = (rnew[2].z - ref[2].z) * tmp;

  return OK;
}


/*
 * settle2: velocity-constraint stage for one rigid water (adapted from
 * rattle2).
 *
 * Corrects v[0..2] in place so that, for each of the three rigid bonds
 * (O-H1, O-H2, H1-H2), the relative velocity is orthogonal to the bond:
 *   (v_j - v_i) . (r_j - r_i) == 0.
 * The corrections are mass-weighted impulses along the bond vectors, with
 * Lagrange multipliers u01, u02, u12.  The multipliers solve the symmetric
 * 3x3 system C u = b, which is inverted with Cramer's rule.
 *
 *   rw - supplies the oxygen (mO) and hydrogen (mH) masses.
 *   r  - current positions; r[0] is the oxygen, r[1] and r[2] the hydrogens.
 *   v  - velocities of the same three atoms; updated in place.
 *   dt - time step.  b is scaled by 1/dt and the update by dt, so dt must
 *        be nonzero.
 *
 * Returns OK.  Returns FAILURE, after printing the residual dot products to
 * stderr, if any |(v_j - v_i) . (r_j - r_i)| still exceeds errTol.
 */
MD_Errcode settle2(const struct Settle_Water *rw, const MD_Dvec *r, 
		   MD_Dvec *v, const MD_Double dt)
{
  const MD_Double errTol = 1e-6; /* tolerance on each bond's (v_j-v_i).(r_j-r_i) */
  const MD_Double inv_mo = 1.0 / rw->mO;
  const MD_Double inv_mh = 1.0 / rw->mH;
  MD_Dvec v10t, v20t, v21t;
  MD_Dvec r10, r20, r21;
  MD_Double c11, c12, c13, c21, c22, c23, c31, c32, c33;
  MD_Double inv_det;
  MD_Double b1, b2, b3;
  MD_Double inv_dt = 1.0 / dt;
  MD_Double u01, u02, u12;

  /* relative velocities and bond vectors for O-H1, O-H2, H1-H2 */
  MD_vec_substract(v[1], v[0], v10t);
  MD_vec_substract(v[2], v[0], v20t);
  MD_vec_substract(v[2], v[1], v21t);
  MD_vec_substract(r[1], r[0], r10);
  MD_vec_substract(r[2], r[0], r20);
  MD_vec_substract(r[2], r[1], r21);

  /* Must use the real values, not the desired theoretical values.  Each
     diagonal entry therefore uses its own bond length: the two O-H bonds
     are not assumed to be exactly equal. */
  c11 = (inv_mo + inv_mh) * MD_vec_dot(r10, r10);
  c22 = (inv_mo + inv_mh) * MD_vec_dot(r20, r20);
  c33 = (inv_mh + inv_mh) * MD_vec_dot(r21, r21);
  c12 =  inv_mo * MD_vec_dot(r10, r20);
  c13 = -inv_mh * MD_vec_dot(r10, r21);
  c23 =  inv_mh * MD_vec_dot(r20, r21);
  /* C is symmetric: mirror the off-diagonal entries only */
  c21 = c12;
  c31 = c13;
  c32 = c23;
  
  inv_det = 1.0 /  (c11*c22*c33 + c21*c32*c13 + c31*c12*c23
                  - c13*c22*c31 - c23*c32*c11 - c33*c12*c21);

  /* right-hand side: current bond-direction relative velocities, scaled by 1/dt */
  b1 = MD_vec_dot(v10t, r10) * inv_dt;
  b2 = MD_vec_dot(v20t, r20) * inv_dt;
  b3 = MD_vec_dot(v21t, r21) * inv_dt;

  /* Cramer's rule: replace column 1, 2, 3 of C by b, respectively */
  u01 = (b1*c22*c33 + b2*c32*c13 + b3*c23*c12
      -  c13*c22*b3 - c23*c32*b1 - c33*b2*c12) * inv_det;
  u02 = (c11*b2*c33 + c21*b3*c13 + c31*c23*b1
      -  c13*b2*c31 - c23*b3*c11 - c33*c21*b1) * inv_det;
  u12 = (c11*c22*b3 + c21*c32*b1 + c31*b2*c12
      -  b1*c22*c31 - b2*c32*c11 - b3*c21*c12) * inv_det;

  /* apply equal-and-opposite impulses along each bond, mass-weighted */
  v[0].x += dt*inv_mo * ( u01*r10.x + u02*r20.x);
  v[0].y += dt*inv_mo * ( u01*r10.y + u02*r20.y);
  v[0].z += dt*inv_mo * ( u01*r10.z + u02*r20.z);
  v[1].x += dt*inv_mh * (-u01*r10.x + u12*r21.x);
  v[1].y += dt*inv_mh * (-u01*r10.y + u12*r21.y);
  v[1].z += dt*inv_mh * (-u01*r10.z + u12*r21.z);
  v[2].x += dt*inv_mh * (-u02*r20.x - u12*r21.x);
  v[2].y += dt*inv_mh * (-u02*r20.y - u12*r21.y);
  v[2].z += dt*inv_mh * (-u02*r20.z - u12*r21.z);

  /* check orthogonality */
  MD_vec_substract(v[1], v[0], v10t);
  MD_vec_substract(v[2], v[0], v20t);
  MD_vec_substract(v[2], v[1], v21t);
  if (fabs(MD_vec_dot(v10t, r10)) > errTol ||
      fabs(MD_vec_dot(v20t, r20)) > errTol ||
      fabs(MD_vec_dot(v21t, r21)) > errTol) {
    fprintf(stderr, "<v1-v0, r1-r0> = %g\n", MD_vec_dot(v10t, r10));
    fprintf(stderr, "<v2-v0, r2-r0> = %g\n", MD_vec_dot(v20t, r20));
    fprintf(stderr, "<v2-v1, r2-r1> = %g\n", MD_vec_dot(v21t, r21));    
    fprintf(stderr, "rattle2 failed\n");
    return FAILURE; 
  }

  return OK;
}



/*
 * settle_prepare: snapshot the current positions of all sw->natoms atoms
 * into sw->old_pos.  settle1 later reads this snapshot as the reference
 * geometry and as the starting point for its velocity update.
 * old_pos must provide at least sw->natoms entries.
 */
void settle_prepare(struct Settle_Water* sw, const MD_Dvec *old_pos)
{
  const size_t nbytes = sizeof(MD_Dvec) * (size_t)sw->natoms;

  memcpy(sw->old_pos, old_pos, nbytes);
}
