/*
 * Copyright (C) 2004-2005 by David J. Hardy.  All rights reserved.
 *
 * mgrid.c
 */

#include <string.h>
#include <math.h>
#include "mgrid/mgrid.h"
#include "debug/debug.h"

/*
 * Absolute tolerance for floating point equality tests, e.g. checking that
 * nspacings * spacing reproduces length to within EQTOL.
 */
#undef EQTOL
#define EQTOL  (1.0e-10)

/* internal routine: maximum number of grid levels for a parameter set */
static int32 maxlevels(const MgridParam *param);


/* number of elements in an actual array (not valid for pointers) */
#define NELEMS(x)  (sizeof(x)/sizeof((x)[0]))


/*
 * Name tables for the enums in mgrid.h.  Order in tables must agree with
 * enum from mgrid.h: entry k is the name of enum value k, and entry 0
 * ("unknown") is what the *_to_string routines report for any value
 * outside the table.
 */
static const char *const MgridBoundaryTable[] = {
  "unknown",
  "PERIODIC",
  "NONPERIODIC",
};

static const char *const MgridApproxTable[] = {
  "unknown",
  "CUBIC",
  "BSPLINE",
  "QUINTIC1",
  "QUINTIC2",
  "HEPTIC1",
  "HEPTIC3",
  "NONIC1",
  "NONIC4",
  "HERMITE",
};

/*
 * NOTE(review): only TAYLOR1..TAYLOR8 and ERRMIN3 are named here, although
 * mgrid_setup() also accepts the MGRID_EXSELF* and MGRID_ODDPR* splittings.
 * Those names cannot be parsed from strings, and (presumably having enum
 * values past ERRMIN3 -- confirm in mgrid.h) they convert to "unknown".
 * Any new entries must be appended in exact enum order.
 */
static const char *const MgridSplitTable[] = {
  "unknown",
  "TAYLOR1",
  "TAYLOR2",
  "TAYLOR3",
  "TAYLOR4",
  "TAYLOR5",
  "TAYLOR6",
  "TAYLOR7",
  "TAYLOR8",
  "ERRMIN3",
};


/*
 * Return the index k >= 1 of the entry of table[0..n-1] equal to s,
 * ignoring case, or 0 ("unknown") if none matches or s is NULL.
 * Entry 0 is skipped so the placeholder "unknown" maps to 0 as well.
 */
static int table_name_to_index(const char *const *table, size_t n,
    const char *s)
{
  size_t k;

  if (s == NULL) return 0;  /* strcasecmp(NULL, ...) is undefined */
  for (k = 1;  k < n;  k++) {
    if (strcasecmp(s, table[k]) == 0) return (int) k;
  }
  return 0;  /* unknown */
}

/* Return table[k], or table[0] ("unknown") if k is outside [0, n). */
static const char *table_index_to_name(const char *const *table, size_t n,
    int k)
{
  if (k < 0 || (size_t) k >= n) k = 0;
  return table[k];
}


/* convert between strings and type numbers */
int mgrid_string_to_boundary(const char *s)
{
  return table_name_to_index(MgridBoundaryTable,
      NELEMS(MgridBoundaryTable), s);
}

const char *mgrid_boundary_to_string(int k)
{
  return table_index_to_name(MgridBoundaryTable,
      NELEMS(MgridBoundaryTable), k);
}

int mgrid_string_to_approx(const char *s)
{
  return table_name_to_index(MgridApproxTable,
      NELEMS(MgridApproxTable), s);
}

const char *mgrid_approx_to_string(int k)
{
  return table_index_to_name(MgridApproxTable,
      NELEMS(MgridApproxTable), k);
}

int mgrid_string_to_split(const char *s)
{
  return table_name_to_index(MgridSplitTable,
      NELEMS(MgridSplitTable), s);
}

const char *mgrid_split_to_string(int k)
{
  return table_index_to_name(MgridSplitTable,
      NELEMS(MgridSplitTable), k);
}


/* help user choose suitable params */
int mgrid_param_config(MgridParam *mgp)
{
  MgridParam p = *mgp;
  int32 max;

  /* check correctness */
  if (p.length < 0 || p.cutoff < 0 || p.spacing < 0
      || p.nspacings < 0 || p.nlevels < 0) {
    return MGRID_FAIL;
  }

  /* cannot set defaults for natoms and boundary */
  if (p.natoms <= 0) return MGRID_FAIL;
  if (p.boundary != MGRID_PERIODIC && p.boundary != MGRID_NONPERIODIC) {
    return MGRID_FAIL;
  }

  /* set default approximation if none set */
  if (p.approx == 0) {
    p.approx = MGRID_CUBIC;
  }
  else if (p.approx <= MGRID_APPROX_BEGIN || p.approx >= MGRID_APPROX_END) {
    return MGRID_FAIL;
  }

  /* set default splitting if none set */
  if (p.split == 0) {
    p.split = MGRID_TAYLOR2;
  }
  else if (p.split <= MGRID_SPLIT_BEGIN || p.split >= MGRID_SPLIT_END) {
    return MGRID_FAIL;
  }

  /* default parameters differ depending on boundary */
  if (p.boundary == MGRID_PERIODIC) {

    /* must have positive domain length */
    if (p.length == 0) return MGRID_FAIL;

    if (p.nspacings == 0 && p.spacing == 0) {
      /* set based on # atoms - same as 2^k, for k=log(N)/(3*log(2)) */
      p.nspacings = 1 << ((int32) (log(p.natoms) / (3 * log(2))));
      p.spacing = p.length / p.nspacings;
    }
    else if (p.spacing > 0) {
      int32 n = 1;
      /* number of spacings must be power of 2 and not greater than given */
      p.nspacings = (int32) ceil(p.length / p.spacing);
      ASSERT(p.nspacings >= 1);
      while (p.nspacings > 1) { p.nspacings >>= 1;  n <<= 1; }
      p.nspacings = n;
      p.spacing = p.length / p.nspacings;
    }
    else {
      int32 n = 1;
      /* number of spacings must be power of 2 and not greater than given */
      ASSERT(p.nspacings >= 1);
      while (p.nspacings > 1) { p.nspacings >>= 1;  n <<= 1; }
      p.nspacings = n;
      p.spacing = p.length / p.nspacings;
    }
    ASSERT(fabs(p.nspacings * p.spacing - p.length) <= EQTOL);
  }
  else {
    ASSERT(p.boundary == MGRID_NONPERIODIC);

    if (p.length == 0 && p.spacing == 0) return MGRID_FAIL;

    if (p.length == 0 && p.nspacings == 0) {
      ASSERT(p.spacing > 0);
      p.nspacings = (int32) pow(p.natoms, 1./3);
      ASSERT(p.nspacings > 0);
      /*
       * warning: user must maintain atoms within cell defined
       *   by center and length
       */
      p.length = p.nspacings * p.spacing;
    }
    else if (p.length == 0) {
      ASSERT(p.spacing > 0);
      ASSERT(p.nspacings > 0);
      /*
       * warning: user must maintain atoms within cell defined
       *   by center and length
       */
      p.length = p.nspacings * p.spacing;
    }
    else if (p.nspacings == 0 && p.spacing == 0) {
      p.nspacings = (int32) pow(p.natoms, 1./3);
      ASSERT(p.nspacings > 0);
      p.spacing = p.length / p.nspacings;
    }
    else if (p.spacing > 0) {
      p.nspacings = (int32) ceil(p.length / p.spacing);
      ASSERT(p.nspacings > 0);
      ASSERT(p.length <= p.nspacings * p.spacing + EQTOL);
      /* might expand length a little */
      p.length = p.nspacings * p.spacing;
    }
    else {
      ASSERT(p.length > 0);
      ASSERT(p.nspacings > 0);
      p.spacing = p.length / p.nspacings;
    }
    ASSERT(fabs(p.nspacings * p.spacing - p.length) <= EQTOL);
  }

  /* set number of levels in grid hierarchy */
  max = maxlevels(&p);
  if (max == 0) return MGRID_FAIL;
  ASSERT(max > 0);
  if (p.nlevels == 0 || p.nlevels > max) {
    p.nlevels = max;
  }

  /* update user's params */
  *mgp = p;
  return 0;
}


/* methods for Mgrid */
/*
 * Reset *mg by setting every byte to zero, so its callback pointers read
 * as NULL until a setup routine fills them in.  Always returns 0.
 */
int mgrid_init(Mgrid *mg)
{
  const size_t nbytes = sizeof(*mg);

  memset(mg, 0, nbytes);
  return 0;
}


/*
 * Tear down *mg by invoking its cleanup callbacks, long range part first,
 * then short range part.  A NULL callback is skipped.  (Presumably the
 * callbacks are installed by the setup routines -- confirm there.)
 */
void mgrid_done(Mgrid *mg)
{
  if (mg->long_done) {
    mg->long_done(mg);
  }
  if (mg->short_done) {
    mg->short_done(mg);
  }
}


/*
 * Validate *p and configure mg for it.  On success the params are stored
 * in mg, then the short range part is set up, then the long range part
 * selected by boundary and approx.
 *
 * Every parameter check runs before mg is modified, so a rejected
 * configuration leaves mg unchanged.  This includes the requirement that a
 * periodic setup use the maximum number of levels, and that the splitting
 * is supported.  mgrid_param_config() can be used to produce consistent
 * params.
 *
 * Returns 0 on success.  Returns MGRID_FAIL for invalid or unsupported
 * params, or when a short/long range setup routine fails.  NOTE(review): in
 * the latter case mg may be partially set up; presumably the caller should
 * still call mgrid_done().
 */
int mgrid_setup(Mgrid *mg, const MgridSystem *sys, const MgridParam *p)
{
  int32 max;
  int even_powers;

  /* validate choices for boundary, approximation, and splitting */
  if ((p->boundary != MGRID_PERIODIC
        && p->boundary != MGRID_NONPERIODIC)
      || (p->approx <= MGRID_APPROX_BEGIN || p->approx >= MGRID_APPROX_END)
      || (p->split <= MGRID_SPLIT_BEGIN || p->split >= MGRID_SPLIT_END)) {
    return MGRID_FAIL;
  }

  /* determine max levels */
  max = maxlevels(p);

  /* validate numeric params */
  if (p->cutoff <= 0
      || p->spacing <= 0
      || p->length <= 0
      || (p->length <= p->cutoff && p->boundary == MGRID_PERIODIC)
      || p->nspacings <= 0
      || p->nlevels <= 0
      || p->nlevels > max  /* fails for periodic if nspacings not power of 2 */
      || p->natoms <= 0
      /* one of length or spacing should be set from other two quantities */
      || (fabs(p->nspacings * p->spacing - p->length) > EQTOL)) {
    return MGRID_FAIL;
  }

  /* periodic long range part needs the full hierarchy; check this before
   * any setup is done so a failure leaves nothing half-initialized */
  if (p->boundary == MGRID_PERIODIC && p->nlevels != max) {
    return MGRID_FAIL;
  }

  /* determine whether or not splitting has even powers of r/a only */
  switch (p->split) {
    case MGRID_TAYLOR1:
    case MGRID_TAYLOR2:
    case MGRID_TAYLOR3:
    case MGRID_TAYLOR4:
    case MGRID_TAYLOR5:
    case MGRID_TAYLOR6:
    case MGRID_TAYLOR7:
    case MGRID_TAYLOR8:
    case MGRID_ERRMIN3:
    case MGRID_EXSELF1:
    case MGRID_EXSELF2:
    case MGRID_EXSELF3:
    case MGRID_EXSELF7:  /* fall thru for all even powers cases */
      even_powers = 1;
        /* have even powers of r/a in splitting function */
      break;
    case MGRID_ODDPR1:
    case MGRID_ODDPR2:
    case MGRID_ODDPR3:
    case MGRID_ODDPR4:
    case MGRID_ODDPR5:
    case MGRID_ODDPR6:
    case MGRID_ODDPR7:
    case MGRID_ODDPR8:  /* fall thru for cases that have some odd power */
      even_powers = 0;
      break;
    default:
      return MGRID_FAIL;  /* others not yet supported */
  }

  /* params are valid, store them */
  mg->param = *p;
  mg->is_split_even_powers = even_powers;

  /* setup short range part */
  if (mgrid_short_setup(mg)) return MGRID_FAIL;

  /* setup long range part */
  if (p->boundary == MGRID_NONPERIODIC) {
    switch (p->approx) {
      case MGRID_CUBIC:
        if (mgrid_ncubic_setup(mg, sys)) return MGRID_FAIL;
        break;
      case MGRID_BSPLINE:
        if (mgrid_nbspline_setup(mg, sys)) return MGRID_FAIL;
        break;
      case MGRID_QUINTIC1:
        if (mgrid_nquintic1_setup(mg, sys)) return MGRID_FAIL;
        break;
      case MGRID_QUINTIC2:
        if (mgrid_nquintic2_setup(mg, sys)) return MGRID_FAIL;
        break;
      case MGRID_HEPTIC1:
        if (mgrid_nheptic1_setup(mg, sys)) return MGRID_FAIL;
        break;
      case MGRID_HEPTIC3:
        if (mgrid_nheptic3_setup(mg, sys)) return MGRID_FAIL;
        break;
      case MGRID_NONIC1:
        if (mgrid_nnonic1_setup(mg, sys)) return MGRID_FAIL;
        break;
      case MGRID_NONIC4:
        if (mgrid_nnonic4_setup(mg, sys)) return MGRID_FAIL;
        break;
      case MGRID_HERMITE:
        if (mgrid_nhermite_setup(mg, sys)) return MGRID_FAIL;
        break;
      default:
        return MGRID_FAIL;  /* others not yet supported */
    }
  }
  else {
    switch (p->approx) {
      case MGRID_CUBIC:
        if (mgrid_pcubic_setup(mg, sys)) return MGRID_FAIL;
        break;
      case MGRID_BSPLINE:
        if (mgrid_pbspline_setup(mg, sys)) return MGRID_FAIL;
        break;
      case MGRID_QUINTIC1:
        if (mgrid_pquintic1_setup(mg, sys)) return MGRID_FAIL;
        break;
      case MGRID_QUINTIC2:
        if (mgrid_pquintic2_setup(mg, sys)) return MGRID_FAIL;
        break;
      case MGRID_HEPTIC1:
        if (mgrid_pheptic1_setup(mg, sys)) return MGRID_FAIL;
        break;
      case MGRID_HEPTIC3:
        if (mgrid_pheptic3_setup(mg, sys)) return MGRID_FAIL;
        break;
      case MGRID_NONIC1:
        if (mgrid_pnonic1_setup(mg, sys)) return MGRID_FAIL;
        break;
      case MGRID_NONIC4:
        if (mgrid_pnonic4_setup(mg, sys)) return MGRID_FAIL;
        break;
      case MGRID_HERMITE:
        if (mgrid_phermite_setup(mg, sys)) return MGRID_FAIL;
        break;
      default:
        return MGRID_FAIL;  /* others not yet supported */
    }
  }

  return 0;
}


/*
 * Compute electrostatic energies and forces for the current positions.
 *
 * Zeroes sys->u_elec, u_short, u_long and the f_elec array.  It also zeroes
 * f_short and f_long when those are non-NULL.  It then runs the short range
 * callback followed by the long range callback; the short range part must
 * be computed first.
 *
 * Returns 0 on success.  Returns MGRID_FAIL if mg has no force callbacks
 * (e.g. only mgrid_init() was called, which zeroes them), if sys->f_elec is
 * NULL, or if either callback reports failure.
 */
int mgrid_force(Mgrid *mg, MgridSystem *sys)
{
  const int32 natoms = mg->param.natoms;

  /* refuse to run without a successful setup instead of calling NULL */
  if (mg->short_force == NULL || mg->long_force == NULL) return MGRID_FAIL;
  /* total force array is mandatory (it is always cleared below) */
  if (sys->f_elec == NULL) return MGRID_FAIL;

  /* clear system output data storage */
  sys->u_elec = 0;
  sys->u_short = 0;
  sys->u_long = 0;
  memset(sys->f_elec, 0, natoms * sizeof(MD_Dvec));
  if (sys->f_short) memset(sys->f_short, 0, natoms * sizeof(MD_Dvec));
  if (sys->f_long) memset(sys->f_long, 0, natoms * sizeof(MD_Dvec));

  /* compute short range and long range forces */
  /* (must compute short range part FIRST) */
  if (mg->short_force(mg, sys)) return MGRID_FAIL;
  if (mg->long_force(mg, sys)) return MGRID_FAIL;
  return 0;
}


/*
 * Return nonzero if coordinate x lies between lo and hi: lo <= x < hi when
 * closed_lo is nonzero, otherwise lo < x < hi.  Written as positive tests
 * so that a NaN coordinate is never reported as inside.
 */
static int coord_inside(double x, double lo, double hi, int closed_lo)
{
  if (!(x < hi)) return 0;
  return closed_lo ? (x >= lo) : (x > lo);
}

/*
 * Check that every atom position lies in the cubic cell of side
 * mg->param.length centered at mg->param.center.
 *   nonperiodic: each coordinate must be strictly inside (lo, hi)
 *   otherwise:   each coordinate must be in the half-open [lo, hi)
 * A NaN coordinate counts as outside.
 *
 * return natoms if all atom positions are within bounding cell
 * otherwise return k < natoms index of first atom found outside of cell
 */
int mgrid_system_validate(const Mgrid *mg, const MgridSystem *sys)
{
  const MD_Dvec center = mg->param.center;
  const double length = mg->param.length;
  const MD_Dvec *pos = sys->pos;
  const int32 natoms = mg->param.natoms;
  const int closed_lo = (mg->param.boundary != MGRID_NONPERIODIC);
  MD_Dvec lo, hi;
  int32 k;

  lo.x = center.x - 0.5 * length;
  lo.y = center.y - 0.5 * length;
  lo.z = center.z - 0.5 * length;
  hi.x = center.x + 0.5 * length;
  hi.y = center.y + 0.5 * length;
  hi.z = center.z + 0.5 * length;

  for (k = 0;  k < natoms;  k++) {
    if (!coord_inside(pos[k].x, lo.x, hi.x, closed_lo)
        || !coord_inside(pos[k].y, lo.y, hi.y, closed_lo)
        || !coord_inside(pos[k].z, lo.z, hi.z, closed_lo)) {
      return k;
    }
  }
  return natoms;
}


/*
 * Maximum number of grid levels for the given parameters, or 0 for an
 * error (nspacings <= 0, or periodic with nspacings not a power of 2).
 *
 * Periodic: one level per halving of nspacings down to 1, plus one.
 * Nonperiodic: the index range [-s, nspacings + s] is repeatedly reduced,
 * one level per pass, until it spans at most 4*s + 2 points.  Here s is
 * the number of extra points on each side for the interpolant:
 * 1 cubic/bspline, 2 quintic, 3 heptic, 4 nonic, 0 hermite.
 */
static int32 maxlevels(const MgridParam *p)
{
  int32 n = p->nspacings;
  int32 levels = 1;
  int32 s;        /* extra index points on each side of the domain */
  int32 lo, hi;   /* current index range */

  /* nspacings must be positive */
  if (n <= 0) return 0;

  if (p->boundary == MGRID_PERIODIC) {
    /* any odd factor means nspacings is not a power of 2 */
    for ( ;  n > 1;  n >>= 1, levels++) {
      if (n & 1) return 0;
    }
    return levels;
  }

  switch (p->approx) {
    case MGRID_CUBIC:
    case MGRID_BSPLINE:
      s = 1;
      break;
    case MGRID_QUINTIC1:
    case MGRID_QUINTIC2:
      s = 2;
      break;
    case MGRID_HEPTIC1:
    case MGRID_HEPTIC3:
      s = 3;
      break;
    case MGRID_NONIC1:
    case MGRID_NONIC4:
      s = 4;
      break;
    case MGRID_HERMITE:
      s = 0;
      break;
    default:
      BUG("invalid parameters");
      return levels;
  }

  lo = -s;
  hi = n + s;
  while (hi - lo + 1 > 4*s + 2) {
    lo = -((-lo + 1) / 2) - s;
    hi = (hi + 1) / 2 + s;
    levels++;
  }
  return levels;
}
