/*
 * Copyright (C) 2006 by David J. Hardy.  All rights reserved.
 *
 * cgmin.c  - Compute energy minimization using conjugate gradient method
 *            with golden section search.
 */

#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <math.h>
#include "step/step_defn.h"
#undef DEBUG_WATCH
#include "debug/debug.h"


/*
 * Allocate the CGMIN work arrays (one MD_Dvec per atom for the starting
 * position and for the search direction) and report the line search
 * parameters through step_output().
 *
 * Returns STEP_SUCCESS, or STEP_FAILURE if an allocation or an output call
 * fails.  If the second allocation fails, the first array is released and
 * both pointers are left NULL, so nothing leaks and step_done_cgmin() is
 * still safe to call.
 */
int step_init_cgmin(Step *s)
{
  const StepParam *param = s->param;

  s->starting_position = malloc(param->natoms * sizeof *s->starting_position);
  if (NULL == s->starting_position) return STEP_FAILURE;

  s->search_direction = calloc(param->natoms, sizeof *s->search_direction);
  if (NULL == s->search_direction) {
    /* release the first array instead of leaking it; NULL keeps a later
     * step_done_cgmin() from double-freeing */
    free(s->starting_position);
    s->starting_position = NULL;
    return STEP_FAILURE;
  }

  if (step_output(s, "using bracket interval distance %g Angstroms\n",
        param->cgmin_dis)
      || step_output(s, "using line search convergence tolerance "
        "%g Angstroms\n", param->cgmin_tol)
      || step_output(s, "setting maximum %d force evaluations "
        "for each line search\n", param->cgmin_eval)) {
    return STEP_FAILURE;
  }
  return STEP_SUCCESS;
}


/*
 * Release the CGMIN work arrays allocated by step_init_cgmin().
 * Each pointer is reset to NULL after being freed, so calling this more
 * than once (or after a partially failed init) does not double-free.
 */
void step_done_cgmin(Step *s)
{
  free(s->starting_position);
  s->starting_position = NULL;
  free(s->search_direction);
  s->search_direction = NULL;
}


#define GOLDEN_RATIO  (0.618033988749895)
/* tau = (sqrt(5)-1)/2, the reciprocal of the golden ratio and the positive
 * solution of tau^2 = 1 - tau.  That identity is what lets each shrink or
 * shift step below reuse one of the two previously evaluated interior
 * points instead of recomputing it. */


/*
 * Evaluate the system at displacement factor alpha along direction s:
 * sets pos = initpos + alpha*s, calls step_force(), and stores the
 * resulting step->system->potential_energy in *u_alpha.
 *
 * *num_f_eval is incremented first; if it then exceeds max_f_eval the
 * positions are not touched and an error is reported.
 * Returns 0 on success, otherwise the value returned by step_error().
 */
static int cgmin_eval_at(Step *step, const MD_Dvec *initpos,
    const MD_Dvec *s, MD_Dvec *pos, int32 natoms, double alpha,
    int32 *num_f_eval, int32 max_f_eval, double *u_alpha)
{
  int32 i;

  if (++(*num_f_eval) > max_f_eval) {
    return step_error(step, "search exceeded %d force evaluations\n",
        max_f_eval);
  }
  for (i = 0;  i < natoms;  i++) {
    pos[i].x = initpos[i].x + alpha*s[i].x;
    pos[i].y = initpos[i].y + alpha*s[i].y;
    pos[i].z = initpos[i].z + alpha*s[i].z;
  }
  if (step_force(step)) {
    return step_error(step, "force evaluation number %d failed\n",
        *num_f_eval);
  }
  *u_alpha = step->system->potential_energy;
  return 0;
}


/*
 * Line search along direction s starting from initpos.
 *
 * Phase 1 (bracketing): start from the interval [0, amax], where amax is
 * chosen so that no atom moves farther than `dis`.  The interval is
 * shrunk toward 0 while the interior energies are not below the energy
 * at 0.  It is shifted forward while the energy at the right end is the
 * lowest of the interior points and the right end.  Otherwise it is
 * accepted.
 *
 * Phase 2 (golden section): shrink the accepted interval until its
 * length, measured as the largest atom displacement, is <= tol.
 *
 * On return, pos, the force array and *u describe the most recently
 * evaluated point (to avoid an extra force evaluation), and *a is that
 * point's search factor.  This is a point inside the final interval, but
 * not necessarily the lowest-energy point evaluated.
 *
 * Returns 0 on success; nonzero on output failure, force failure, or when
 * more than max_f_eval force evaluations would be needed.
 */
static int bracket_and_golden_section_search(
    Step *step,
    MD_Dvec *initpos,  /* initial position */
    MD_Dvec *s,        /* search direction */
    MD_Dvec *pos,      /* should be initialized to initpos,
                          returns min energy pos in search direction */
    MD_Dvec *f,        /* should be initialized to be force for pos,
                          returns force for min energy pos */
    double dis,        /* largest displacement along search direction */
    double tol,        /* tolerance for convergence of search interval */
    int32 max_f_eval,  /* maximum number of force evaluations before give up */
    int32 natoms,      /* number of atoms */
    double *a,         /* returns the search factor */
    double *u          /* should be initialized to be potential for pos,
                          returns potential for min energy pos */
    )
{
  const double tau = GOLDEN_RATIO;

  double amin, amax;
  double delta;
  double a1, a2;
  double u_amin = *u;
  double u_amax, u_a1, u_a2;
  int32 num_f_eval = 0;
  int32 is_bracket;
  double smax2, s2, smax;
  int32 i;

  /* f is never accessed directly: forces are produced by step_force(), and
   * the caller passes step->system->force here (presumably the array
   * step_force() fills -- TODO confirm) */
  (void) f;

  /* use s and dis to determine the amax search factor:
   * smax is the largest per-atom length of s */
  smax2 = 0.0;
  for (i = 0;  i < natoms;  i++) {
    s2 = s[i].x * s[i].x + s[i].y * s[i].y + s[i].z * s[i].z;
    if (smax2 < s2) smax2 = s2;
  }
  smax = sqrt(smax2);

  if (smax == 0.0) {
    /* A zero direction would make amax = dis/0.  No movement is possible.
     * pos, f and *u already describe initpos per the contract above, so
     * report a zero step instead of producing inf/NaN positions. */
    *a = 0.0;
    if (step_output(step, "search direction is zero, "
          "no line search performed\n")) {
      return STEP_FAILURE;
    }
    return 0;
  }

  amax = dis / smax;
  amin = 0.0;
  delta = amax - amin;
  a1 = amin + (1-tau)*delta;
  a2 = amin + tau*delta;

  /* interval is considered trivially bracketed if small enough */
  is_bracket = (delta*smax <= tol);

  if (step_output(step, "scaled bracket search interval [%g,%g]\n",
       amin*smax, amax*smax)) {
    return STEP_FAILURE;
  }

  /* evaluate amax, a1, a2 in that order (short-circuit keeps the order) */
  if (cgmin_eval_at(step, initpos, s, pos, natoms, amax,
        &num_f_eval, max_f_eval, &u_amax)
      || cgmin_eval_at(step, initpos, s, pos, natoms, a1,
        &num_f_eval, max_f_eval, &u_a1)
      || cgmin_eval_at(step, initpos, s, pos, natoms, a2,
        &num_f_eval, max_f_eval, &u_a2)) {
    return STEP_FAILURE;
  }

  /* save most recent computation */
  *a = a2;
  *u = u_a2;

  while ( ! is_bracket) {

    if (u_a1 >= u_amin) {
      /* no decrease at a1: shrink bracketing interval to [amin,a1],
       * both interior points must be recomputed */
      if (step_output(step, "shrinking bracket interval to [%g,%g]\n",
           amin*smax, a1*smax)) {
        return STEP_FAILURE;
      }

      amax = a1;
      u_amax = u_a1;

      delta = amax - amin;
      a1 = amin + (1-tau)*delta;
      a2 = amin + tau*delta;

      if (cgmin_eval_at(step, initpos, s, pos, natoms, a1,
            &num_f_eval, max_f_eval, &u_a1)
          || cgmin_eval_at(step, initpos, s, pos, natoms, a2,
            &num_f_eval, max_f_eval, &u_a2)) {
        return STEP_FAILURE;
      }

      /* update is_bracket since interval has shrunk */
      is_bracket = (delta*smax <= tol);

      /* save most recent computation */
      *a = a2;
      *u = u_a2;
    }

    else if (u_a2 >= u_amin) {
      /* decrease at a1 but not at a2: shrink to [amin,a2]; the old a1
       * becomes the new a2, so only a new a1 is computed */
      if (step_output(step, "shrinking bracket interval to [%g,%g]\n",
           amin*smax, a2*smax)) {
        return STEP_FAILURE;
      }

      amax = a2;
      u_amax = u_a2;
      a2 = a1;
      u_a2 = u_a1;

      delta = amax - amin;
      a1 = amin + (1-tau)*delta;

      if (cgmin_eval_at(step, initpos, s, pos, natoms, a1,
            &num_f_eval, max_f_eval, &u_a1)) {
        return STEP_FAILURE;
      }

      /* update is_bracket since interval has shrunk */
      is_bracket = (delta*smax <= tol);

      /* save most recent computation */
      *a = a1;
      *u = u_a1;
    }

    else if (u_amax < u_a1 && u_amax < u_a2) {
      /* energy still falling at the right end: shift the interval to
       * [a2,a2+delta]; the old amax becomes the new a1 */
      if (step_output(step, "shifting bracket interval to [%g,%g]\n",
            a2*smax, (a2+delta)*smax)) {
        return STEP_FAILURE;
      }

      amin = a2;
      u_amin = u_a2;
      a1 = amax;
      u_a1 = u_amax;

      amax = amin + delta;
      a2 = amin + tau*delta;

      if (cgmin_eval_at(step, initpos, s, pos, natoms, amax,
            &num_f_eval, max_f_eval, &u_amax)
          || cgmin_eval_at(step, initpos, s, pos, natoms, a2,
            &num_f_eval, max_f_eval, &u_a2)) {
        return STEP_FAILURE;
      }

      /* save most recent computation (pos is now at a2), keeping *a/*u
       * consistent with pos and forces like the other branches do */
      *a = a2;
      *u = u_a2;
    }

    else {
      /* now we consider bracketed interval unimodal */
      /* continue with golden section search */
      if (step_output(step, "found acceptable bracket interval [%g,%g]\n",
            amin*smax, amax*smax)) {
        return STEP_FAILURE;
      }

      is_bracket = STEP_TRUE;
    }

  } /* end while not bracketed */

  /* golden section search */
  if (step_output(step, "starting golden section search on interval [%g,%g]\n",
        amin*smax, amax*smax)) {
    return STEP_FAILURE;
  }

  while (delta*smax > tol) {

    if (u_a1 > u_a2) {
      /* minimum lies in [a1,amax]; old a2 becomes new a1 */
      if (step_output(step, "shrinking search interval to [%g,%g]\n",
            a1*smax, amax*smax)) {
        return STEP_FAILURE;
      }

      amin = a1;
      u_amin = u_a1;
      delta = amax - amin;

      a1 = a2;
      u_a1 = u_a2;

      a2 = amin + tau*delta;

      if (cgmin_eval_at(step, initpos, s, pos, natoms, a2,
            &num_f_eval, max_f_eval, &u_a2)) {
        return STEP_FAILURE;
      }

      /* save most recent computation */
      *a = a2;
      *u = u_a2;
    }

    else {
      /* minimum lies in [amin,a2]; old a1 becomes new a2 */
      if (step_output(step, "shrinking search interval to [%g,%g]\n",
            amin*smax, a2*smax)) {
        return STEP_FAILURE;
      }

      amax = a2;
      u_amax = u_a2;
      delta = amax - amin;

      a2 = a1;
      u_a2 = u_a1;

      a1 = amin + (1-tau)*delta;

      if (cgmin_eval_at(step, initpos, s, pos, natoms, a1,
            &num_f_eval, max_f_eval, &u_a1)) {
        return STEP_FAILURE;
      }

      /* save most recent computation */
      *a = a1;
      *u = u_a1;
    }

  } /* end while golden section search */

  if (step_output(step, "successfully found interval [%g,%g] "
        "within tolerance\n", amin*smax, amax*smax)) {
    return STEP_FAILURE;
  }

  /* 0 = success; the caller treats any nonzero return as failure */
  return 0;
}


int step_compute_cgmin(Step *step, int32 numsteps)
{
  MD_Dvec *f = step->system->force;
  MD_Dvec *pos = step->system->pos;
  MD_Dvec *s = step->search_direction;
  MD_Dvec *initpos = step->starting_position;
  double accum_fdf, old_fdf, a, beta;
  const int32 natoms = step->param->natoms;
  const int32 resultsFreq = step->param->resultsFreq;
  int32 n, i, resultsCounter;
  const int32 cgmin_eval = step->param->cgmin_eval;
  const double cgmin_dis = step->param->cgmin_dis;
  const double cgmin_tol = step->param->cgmin_tol;

  if (step_output(step, "Running conjugate gradient minimization (CGMIN) "
        "for %d steps...\n", numsteps)) {
    return STEP_FAILURE;
  }

  /* compute initial force */
  if (step_force(step)) return STEP_FAILURE;
  step->u = step->system->potential_energy;

  /* use force to set initial search direction */
  memcpy(s, f, natoms * sizeof(MD_Dvec));

  /* find f dot f */
  accum_fdf = 0.0;
  for (i = 0;  i < natoms;  i++) {
    accum_fdf += f[i].x * f[i].x + f[i].y * f[i].y + f[i].z * f[i].z;
  }
  step->fdf = accum_fdf;

  /* check for saddle point */
  if (step->fdf < cgmin_tol) {
    return step_error(step, "found approximate saddle point, "
        "terminating early\n");
  }

  /* conjugate gradient loop */
  resultsCounter = 0;
  for (n = 0;  n < numsteps;  n++) {

    /* retain initial position */
    memcpy(initpos, pos, natoms * sizeof(MD_Dvec));

    /* find minimum along search direction */
    if (bracket_and_golden_section_search(step, initpos, s, pos, f,
          cgmin_dis, cgmin_tol, cgmin_eval, natoms, &a, &(step->u))) {
      return STEP_FAILURE;
    }

    old_fdf = step->fdf;

    /* find f dot f */
    accum_fdf = 0.0;
    for (i = 0;  i < natoms;  i++) {
      accum_fdf += f[i].x * f[i].x + f[i].y * f[i].y + f[i].z * f[i].z;
    }
    step->fdf = accum_fdf;

    /* check for saddle point */
    if (step->fdf < cgmin_tol) {
      return step_error(step, "found approximate saddle point, "
          "terminating early\n");
    }

    /* determine new search direction */
    beta = accum_fdf / old_fdf;
    for (i = 0;  i < natoms;  i++) {
      s[i].x = f[i].x + beta * s[i].x;
      s[i].y = f[i].y + beta * s[i].y;
      s[i].z = f[i].z + beta * s[i].z;
    }

    /* submit results? */
    resultsCounter++;
    if (resultsFreq == resultsCounter) {
      resultsCounter = 0;
      if (step_results(step, resultsFreq)) return STEP_FAILURE;
    }
  } /* end conjugate gradient loop */

  return STEP_SUCCESS;
}
