/*
 * Copyright (C) 2007 by David J. Hardy.  All rights reserved.
 */

/*
 * bucksafe.c
 *
 * Determine params for smooth extension of Buckingham potential.
 *
 * At the inflection point r_0 on the repulsive wall of
 * U(r) = a exp(-r/b) - c/r^6  (where U''(r_0) = 0 and U(r_0) > 0),
 * attach a piece of the form  A/r^6 + B  (A>0, B>0)  so that the join
 * is C^1 (value and slope match).  This removes the non-physical
 * short-range well, so the extended potential can be used safely for
 * dynamics; it is especially beneficial for minimization.
 */

#include <stdio.h>
#include <math.h>
#include "force/force.h"
#undef DEBUG_WATCH
#include "debug/debug.h"


/*
 * Buckingham energy  U(r) = a exp(-r/b) - c/r^6  followed by its first
 * and second derivatives with respect to r, each evaluated at separation
 * r for fixed constants a, b, c.
 */
static double buck(double r, double a, double b, double c);
static double d_buck(double r, double a, double b, double c);
static double dd_buck(double r, double a, double b, double c);

/*
 * Bisection root finder taking functions with the same signature as the
 * Buckingham routines above: narrows [x0, x1] around a sign change of f
 * until its width is at most tol; returns 0 on success with the root
 * stored in *root.
 */
static int bisection(double *root,
    double (*f)(double x, double a, double b, double c),
    double x0, double x1, double a, double b, double c, double tol);


/* separation at which the search for the innermost zero of U starts */
static const double R_SMALL = 1./128;
/* cap on the number of bracketing steps and bisection halvings */
enum { MAX_ITER = 100 };
/* bisection width tolerances: coarse for zeros of U, fine for U' and U'' */
static const double TOL_LO = 1e-6;
static const double TOL_HI = 1e-12;

/*
 * Starting at r0, multiply r by fac (at most MAX_ITER times) until f(r)
 * has strictly opposite sign to f(r0), i.e. until f(r0)*f(r) < 0.
 * A step that lands exactly on a zero of f does not count as a sign change.
 * On success stores that r in *r1 and returns 0; otherwise returns
 * FORCE_FAIL and leaves *r1 unchanged.
 */
static int bracket_sign_change(
    double *r1,
    double (*f)(double x, double a, double b, double c),
    double r0,
    double fac,
    double a,
    double b,
    double c)
{
  double fr0 = f(r0, a, b, c);
  double r = r0;
  double fr = fr0;
  int cnt;

  /* test the condition itself afterwards, not the step count, so that a
   * sign change found on the last allowed step is still a success */
  for (cnt = 0; !(fr0 * fr < 0) && cnt < MAX_ITER; cnt++) {
    r *= fac;
    fr = f(r, a, b, c);
  }
  if (!(fr0 * fr < 0)) {
    return FORCE_FAIL;
  }
  *r1 = r;
  return 0;
}

/*
 * To be called externally to determine extra parameters to Buckingham
 * potential,  a exp(-r/b) - c/r^6,  for piecewise-defined smooth extension
 * having form  A/r^6 + B.
 *
 *   A  -- return parameter A
 *   B  -- return parameter B
 *   Rswitch -- return switch distance within which extension is to be active
 *   uRswitch -- return energy at switch distance (kcal/mol)
 *   Rtop -- return distance for top of potential barrier
 *   uRtop -- return energy at top of potential barrier
 *   a  -- constant for Buckingham
 *   b  -- constant for Buckingham
 *   c  -- constant for Buckingham
 *
 * The join point Rswitch is a zero of U'' found by bisection between the
 * two zeros of U (root0 < root1); value and slope of U are matched there.
 * Rtop is the zero of U' found by bisection on the same interval.
 *
 * If a == 0 and c == 0, all outputs are set to zero and 0 is returned.
 * Otherwise returns 0 on success or FORCE_FAIL if something goes wrong
 * (a message is printed to stdout); on failure no output is modified.
 * NOTE(review): parameter sets where U never changes sign (presumably e.g.
 * c == 0 with a > 0, a purely repulsive wall) end up as FORCE_FAIL.
 */
int force_safe_buckingham_params(
    double *A,
    double *B,
    double *Rswitch,
    double *uRswitch,
    double *Rtop,
    double *uRtop,
    double a,
    double b,
    double c
    )
{
  double (*U)(double, double, double, double) = buck;
  double (*dU)(double, double, double, double) = d_buck;
  double (*ddU)(double, double, double, double) = dd_buck;
  double r0;
  double ur0;
  double dur0;
  double r1;
  double r6;
  double coefA;
  double root0;
  double root1;
  double droot0;
  double ddroot0;

  if (0.0==a && 0.0==c) {
    *A = 0;
    *B = 0;
    *Rswitch = 0;
    *uRswitch = 0;
    *Rtop = 0;
    *uRtop = 0;
    return 0;
  }

  /*
   * Bracket the innermost zero of U: from R_SMALL step outward (x1.5) if
   * U is negative there, otherwise step inward (x0.75).
   */
  r0 = R_SMALL;
  ur0 = U(r0, a, b, c);
  if (bracket_sign_change(&r1, U, r0, (ur0 < 0 ? 1.5 : 0.75), a, b, c)) {
    printf("# ERROR (bucksafe.c, line %d): unable to bracket root\n",
        __LINE__);
    return FORCE_FAIL;
  }
  /* this should be smallest root of U (no need for high tolerance) */
  if (bisection(&root0, U, r0, r1, a, b, c, TOL_LO)) {
    printf("# ERROR (bucksafe.c, line %d): bisection failed to find root\n",
        __LINE__);
    return FORCE_FAIL;
  }

  /*
   * Restart from whichever endpoint of that bracket lies beyond root0,
   * where U should be non-negative, and step outward to the next zero.
   */
  r0 = (r1 > root0 ? r1 : r0);
  ur0 = U(r0, a, b, c);
  /* (sanity check) */
  if (ur0 < 0) {
    printf("# ERROR (bucksafe.c, line %d): sign error\n", __LINE__);
    return FORCE_FAIL;
  }
  if (bracket_sign_change(&r1, U, r0, 1.5, a, b, c)) {
    printf("# ERROR (bucksafe.c, line %d): unable to bracket root\n",
        __LINE__);
    return FORCE_FAIL;
  }
  /* this should be largest root of U (no need for high tolerance) */
  if (bisection(&root1, U, r0, r1, a, b, c, TOL_LO)) {
    printf("# ERROR (bucksafe.c, line %d): bisection failed to find root\n",
        __LINE__);
    return FORCE_FAIL;
  }

  /* (sanity check) */
  if (root0 >= root1) {
    printf("# ERROR (bucksafe.c, line %d): failed to find expected roots\n",
        __LINE__);
    return FORCE_FAIL;
  }

  /* this should be smallest root of U'' (want high tolerance) */
  if (bisection(&ddroot0, ddU, root0, root1, a, b, c, TOL_HI)) {
    printf("# ERROR (bucksafe.c, line %d): bisection failed to find root\n",
        __LINE__);
    return FORCE_FAIL;
  }

  /* find barrier top (zero of U') - for diagnostic purposes; done before
   * any output is written so a failure leaves the outputs untouched */
  if (bisection(&droot0, dU, root0, root1, a, b, c, TOL_HI)) {
    printf("# ERROR (bucksafe.c, line %d): bisection failed to find root\n",
        __LINE__);
    return FORCE_FAIL;
  }

  /*
   * C^1 join of V(r) = A/r^6 + B with U at r0:
   *   V'(r0) = -6A/r0^7 = U'(r0)    =>  A = -U'(r0) r0^7 / 6
   *   V(r0)  = A/r0^6 + B = U(r0)   =>  B = U(r0) - A/r0^6
   */
  r0 = ddroot0;
  ur0 = U(r0, a, b, c);
  dur0 = dU(r0, a, b, c);
  r6 = (r0*r0*r0)*(r0*r0*r0);
  coefA = (-1./6) * dur0 * (r6*r0);

  *A = coefA;
  *B = ur0 - coefA / r6;
  *Rswitch = r0;
  *uRswitch = ur0;
  *Rtop = droot0;
  *uRtop = U(droot0, a, b, c);

  return 0;
}


double buck(double r, double a, double b, double c)
{
  double r6 = r*r*r * r*r*r;
  return (a * exp(-r/b) - c/r6);
}


double d_buck(double r, double a, double b, double c)
{
  double r7 = (r*r*r) * (r*r*r) * r;
  return (-a/b * exp(-r/b) + 6.*c/r7);
}


double dd_buck(double r, double a, double b, double c)
{
  double r8 = ((r*r)*(r*r)) * ((r*r)*(r*r));
  return (a/(b*b) * exp(-r/b) - 42.*c/r8);
}


/*
 * Find a root of f(x; a, b, c) in the interval bracketed by x0 and x1
 * (given in either order) by bisection.
 *
 *   root   -- on success, receives the midpoint of the final bracket
 *   f      -- function whose sign change is sought
 *   x0, x1 -- bracket endpoints; f must have opposite signs there,
 *             where a value of exactly zero counts as negative
 *   a, b, c -- constants passed through to f
 *   tol    -- stop once the bracket width is at most tol
 *
 * Returns 0 on success.  Returns FORCE_FAIL, leaving *root unchanged, if
 * the endpoint signs agree or if the bracket is still wider than tol
 * after MAX_ITER halvings.  If the bracket can no longer be split in
 * double precision before reaching tol, it is accepted as converged.
 */
static int bisection(
    double *root,
    double (*f)(double x, double a, double b, double c),
    double x0,
    double x1,
    double a,
    double b,
    double c,
    double tol)
{
  double fx0;
  double fx1;
  int sfx0;
  int sfx1;
  int cnt;

  /* swap values if in wrong order */
  if (x1 < x0) {
    double tmp = x1;
    x1 = x0;
    x0 = tmp;
#ifdef DEBUG_WATCH
    printf("bisection (DIAGNOSTIC) swapped x0 and x1\n");
    printf("x0=%.12g  x1=%.12g\n", x0, x1);
#endif
  }

  /* make sure these endpoints give f(x) of opposite signs */
  fx0 = f(x0, a, b, c);
  fx1 = f(x1, a, b, c);
  sfx0 = (fx0 > 0 ? 1 : -1);
  sfx1 = (fx1 > 0 ? 1 : -1);
  if (sfx0 == sfx1) {
#ifdef DEBUG_WATCH
    printf("bisection (FAILURE) function endpoints must have opposite sign\n");
    printf("f(x0)=%.12g  f(x1)=%.12g\n", fx0, fx1);
#endif
    return FORCE_FAIL;
  }

  /*
   * Each pass halves [x0, x1], keeping the half whose endpoints still
   * differ in sign.  At most MAX_ITER halvings are done; convergence is
   * judged afterwards from the remaining width, not from the counter.
   */
  for (cnt = 0; (x1 - x0) > tol && cnt < MAX_ITER; cnt++) {
    double m = 0.5 * (x0 + x1);  /* midpoint */
    int sfm;
    if (!(x0 < m && m < x1)) {
      /* no double lies strictly inside: bracket cannot shrink further */
      break;
    }
    sfm = (f(m, a, b, c) > 0 ? 1 : -1);
    if (sfm == sfx0) {
      x0 = m;
    }
    else {
      x1 = m;
    }
  }
  if ((x1 - x0) > tol && cnt == MAX_ITER) {
#ifdef DEBUG_WATCH
    printf("bisection (FAILURE) exceeded max iteration count %d\n", cnt);
    printf("remaining interval:  x0=%.12g  x1=%.12g\n", x0, x1);
#endif
    return FORCE_FAIL;
  }
  *root = 0.5 * (x0 + x1);
#ifdef DEBUG_WATCH
  printf("bisection (SUCCESS) found root %.12g\n", *root);
#endif
  return 0;
}
