/*
 * Copyright (C) 2004-2005 by David J. Hardy.  All rights reserved.
 *
 * setup.c
 *
 * setup for long-range part:  allocate grids and compute constants
 */

#include <stdlib.h>
#include <math.h>
#include "mgrid/split.h"
#include "debug/debug.h"


static int setup_nonperiodic(Mgrid *mg, const MgridSystem *sys);
static int setup_periodic(Mgrid *mg, const MgridSystem *sys);


/*
 * Set up the long-range part of the computation for the boundary type
 * recorded in mg->param.boundary.
 *
 * MGRID_NONPERIODIC selects the non-periodic setup; every other boundary
 * value selects the periodic setup.  The result of the selected routine
 * (0 on success, MGRID_FAIL on failure) is returned unchanged.
 */
int mgrid_setup_longrange(Mgrid *mg, const MgridSystem *sys)
{
  const int is_nonperiodic = (mg->param.boundary == MGRID_NONPERIODIC);

  return is_nonperiodic ? setup_nonperiodic(mg, sys)
                        : setup_periodic(mg, sys);
}


/*
 * Set up the long-range part for a non-periodic boundary.
 *
 * On success mg holds:
 *   - inv_spacing = 1/h and origin = param.center - param.length/2 in
 *     each coordinate;
 *   - qgrid[k], egrid[k] for k = 0..nlevels-1, each a cube of doubles
 *     indexed [na,nb]^3.  Level 0 spans [-nu, nspacings + nu]; each coarser
 *     level halves the previous range (rounding away from zero) and pads it
 *     by nu again;
 *   - scaling[k] = 2^-k;
 *   - gdsum:  weights g_a(r) - g_{2a}(r) on the cube of radius
 *     gdsum_radius = (int)(2a/h), left zero where r >= 2a;
 *   - glast:  untruncated weights g_a(r) on the cube of radius
 *     glast_radius = (width of the last level's lattice) - 1;
 *   - g_zero = g_a(0) and u_self = 0.5 * g_zero * sum_n q[n]^2.
 * Here g_a(r) = gamma(r/a)/a for r <= a and 1/r beyond; gamma is fed
 * (r/a)^2 instead of r/a when mg->is_split_even_powers is set.
 *
 * Returns 0 on success, MGRID_FAIL on bad parameters or on allocation or
 * lattice setup failure.  Storage attached to mg before a failure is left
 * in place for the caller to release.
 */
static int setup_nonperiodic(Mgrid *mg, const MgridSystem *sys)
{
  const double *q = sys->charge;
  const int32 nlevels = mg->param.nlevels;
  const int32 natoms = mg->param.natoms;
  const int32 split = mg->param.split;
  double *gd;
  double r2, s, t, gs, gt;
  int32 i, j, k, n, na, nb;
  int32 nu = 0;  /* actually set to (nu/2 - 1) from thesis */
  int32 omega_dim = 0;
  int32 gdim, index;
  const double a = mg->param.cutoff;
  const double h = mg->param.spacing;
  const double h_1 = 1./h;

  /* at least one level is required:  scaling[0] is always written and
   * glast is sized from qgrid[nlevels-1], both out of bounds otherwise */
  if (nlevels < 1) return MGRID_FAIL;

  /* set inverse of grid spacing */
  mg->inv_spacing = h_1;

  /* set location of multi-grids origin with lattice label (0,0,0) */
  mg->origin.x = mg->param.center.x - 0.5 * mg->param.length;
  mg->origin.y = mg->param.center.y - 0.5 * mg->param.length;
  mg->origin.z = mg->param.center.z - 0.5 * mg->param.length;

#ifdef DEBUG_WATCH
  printf("origin = (%g, %g, %g)\n", mg->origin.x, mg->origin.y, mg->origin.z);
#endif

  /* allocate arrays of charge and potential lattices */
  mg->qgrid = (MgridLattice *) malloc(nlevels * sizeof(MgridLattice));
  mg->egrid = (MgridLattice *) malloc(nlevels * sizeof(MgridLattice));
  if (mg->qgrid == NULL || mg->egrid == NULL) return MGRID_FAIL;

  /* setup charge and potential lattices for each level */
  for (k = 0;  k < nlevels;  k++) {
    if (mgrid_lattice_init(&(mg->qgrid[k]))
        || mgrid_lattice_init(&(mg->egrid[k]))) {
      return MGRID_FAIL;
    }
  }

  /* nu:  points of padding added on each side of every level's range;
   * omega_dim:  every level below the top must be wider than this */
  switch (mg->param.approx) {
    case MGRID_CUBIC:
    case MGRID_BSPLINE:
      nu = 1;
      omega_dim = 6;
      break;
    case MGRID_QUINTIC1:
    case MGRID_QUINTIC2:
      nu = 2;
      omega_dim = 10;
      break;
    case MGRID_HEPTIC1:
    case MGRID_HEPTIC3:
      nu = 3;
      omega_dim = 14;
      break;
    case MGRID_NONIC1:
    case MGRID_NONIC4:
      nu = 4;
      omega_dim = 18;
      break;
    default:
      BUG("cannot setup this approximation");
      /* do not go on with meaningless sizes should BUG() return */
      return MGRID_FAIL;
  }

  /* finest level:  [0, nspacings] padded by nu on each side */
  na = -nu;
  nb = mg->param.nspacings + nu;
  for (k = 0;  k < nlevels;  k++) {
    if (nb - na + 1 <= omega_dim && k < nlevels - 1) {
      return MGRID_FAIL;
    }
    if (mgrid_lattice_setup(&(mg->qgrid[k]), sizeof(double),
          na, nb, na, nb, na, nb)
        || mgrid_lattice_setup(&(mg->egrid[k]), sizeof(double),
          na, nb, na, nb, na, nb)) {
      return MGRID_FAIL;
    }
#ifdef DEBUG_WATCH
    printf("level:  k = %d   na = %d   nb = %d\n", k, na, nb);
#endif
    /* next coarser level:  halve the range rounding away from zero,
     * then pad by nu on each side again */
    na = -((-na+1)/2) - nu;
    nb = (nb+1)/2 + nu;
  }

  /* allocate and compute scalings for each direct sum on each level */
  mg->scaling = (double *) malloc(nlevels * sizeof(double));
  if (mg->scaling == NULL) return MGRID_FAIL;
  mg->scaling[0] = 1.0;
  for (k = 1;  k < nlevels;  k++) {
    mg->scaling[k] = 0.5 * mg->scaling[k-1];
  }

  /* allocate space for direct sum weights */
  n = (int32) (2.0 * a * h_1);
  mg->gdsum_radius = n;
  if (mgrid_lattice_setup(&(mg->gdsum), sizeof(double), -n, n, -n, n, -n, n)) {
    return MGRID_FAIL;
  }

  /* compute direct sum weights (entries with r >= 2a stay zero) */
  if (mgrid_lattice_zero(&(mg->gdsum))) {
    return MGRID_FAIL;
  }
  gdim = mg->gdsum.ni;
  ASSERT(gdim == 2 * n + 1);
  gd = (double *)(mg->gdsum.data);
  for (k = -n;  k <= n;  k++) {
    for (j = -n;  j <= n;  j++) {
      for (i = -n;  i <= n;  i++) {
        r2 = (i*i +  j*j + k*k) * h*h;
        if (r2 >= 4*a*a) continue;
        if (mg->is_split_even_powers) {
          s = r2 / (a*a);
          t = r2 / (4*a*a);
          ASSERT(t <= s);
          index = (k * gdim + j) * gdim + i;
          ASSERT(&gd[index] == mgrid_lattice_elem(&(mg->gdsum), i, j, k));
          if (s <= 1) gamma(&gs, s, split);
          else gs = 1/sqrt(s);
          if (t < 1) {
            gamma(&gt, t, split);
            gd[index] = gs/a - gt/(2*a);
          }
          else {
            /* then gt=1/sqrt(t), which implies that gs/a - gt/(2*a) == 0 */
            gd[index] = 0;
          }
        }
        else {
          s = sqrt(r2 / (a*a));
          t = 0.5 * s;
          index = (k * gdim + j) * gdim + i;
          ASSERT(&gd[index] == mgrid_lattice_elem(&(mg->gdsum), i, j, k));
          if (s <= 1) gamma_odd(&gs, s, split);
          else gs = 1./s;
          if (t < 1) {
            gamma_odd(&gt, t, split);
            gd[index] = gs/a - gt/(2*a);
          }
          else {
            gd[index] = 0;
          }
        }
      }
    }
  }

  /* allocate space for direct sum weights on last level:  the radius
   * covers every separation within the last level's lattice */
  n = mg->qgrid[nlevels-1].ni - 1;
  mg->glast_radius = n;
  if (mgrid_lattice_setup((&mg->glast), sizeof(double), -n, n, -n, n, -n, n)) {
    return MGRID_FAIL;
  }

  /* compute direct sum weights on last level */
  if (mgrid_lattice_zero(&(mg->glast))) {
    return MGRID_FAIL;
  }
  gdim = mg->glast.ni;
  ASSERT(gdim == 2 * n + 1);
  gd = (double *)(mg->glast.data);
  for (k = -n;  k <= n;  k++) {
    for (j = -n;  j <= n;  j++) {
      for (i = -n;  i <= n;  i++) {
        r2 = (i*i + j*j + k*k) * h*h;
        if (mg->is_split_even_powers) {
          s = r2 / (a*a);
          index = (k * gdim + j) * gdim + i;
          ASSERT(&gd[index] == mgrid_lattice_elem(&(mg->glast), i, j, k));
          /* unlike previous, this sum is not truncated */
          if (s <= 1) gamma(&gs, s, split);
          else gs = 1/sqrt(s);
          gd[index] = gs/a;
        }
        else {
          s = sqrt(r2 / (a*a));
          index = (k * gdim + j) * gdim + i;
          ASSERT(&gd[index] == mgrid_lattice_elem(&(mg->glast), i, j, k));
          if (s <= 1) gamma_odd(&gs, s, split);
          else gs = 1./s;
          gd[index] = gs/a;
        }
      }
    }
  }

  /* compute g_a(0) */
  s = 0;
  if (mg->is_split_even_powers) {
    gamma(&gs, s, split);
  }
  else {
    gamma_odd(&gs, s, split);
  }
  mg->g_zero = gs/a;

  /* compute self potential:  0.5 * g_a(0) * sum of squared charges */
  s = 0;
  for (n = 0;  n < natoms;  n++) {
    s += q[n] * q[n];
  }
  mg->u_self = 0.5 * mg->g_zero * s;

  return 0;
}


/*
 * Set up the long-range part for a periodic boundary.
 *
 * Level k gets a cube lattice of doubles indexed [0, n_k - 1]^3, with
 * n_0 = param.nspacings and n_{k+1} = n_k / 2.  Every level needs a
 * non-zero n_k that is even or exactly 1.  Without DEBUG_SUPPORT the top
 * level (a single point) is omitted:  only param.nlevels - 1 levels are set
 * up, and halving past them must leave exactly 1.  With DEBUG_SUPPORT all
 * param.nlevels levels are set up and the top one is asserted to be 1 point.
 *
 * Also sets inv_spacing = 1/h, origin = param.center - param.length/2,
 * scaling[k] = 2^-k, and gdsum.  gdsum holds the weights
 * g_a(r) - g_{2a}(r) on the cube of radius gdsum_radius = (int)(2a/h),
 * left zero where r >= 2a.  It also sets g_zero = g_a(0) and
 * u_self = 0.5 * g_zero * sum_n q[n]^2.  Here g_a(r) = gamma(r/a)/a for
 * r <= a and 1/r beyond.  gamma is fed (r/a)^2 instead of r/a when
 * mg->is_split_even_powers is set.
 *
 * Returns 0 on success, MGRID_FAIL on bad parameters or on allocation or
 * lattice setup failure.  Storage attached to mg before a failure is left
 * in place for the caller to release.
 */
static int setup_periodic(Mgrid *mg, const MgridSystem *sys)
{
  const double *q = sys->charge;

  /* omit top level when not debugging */
#ifdef DEBUG_SUPPORT
  const int32 nlevels = mg->param.nlevels;
#else
  const int32 nlevels = mg->param.nlevels - 1;
#endif

  const int32 natoms = mg->param.natoms;
  const int32 split = mg->param.split;
  double *gd;
  double r2, s, t, gs, gt;
  int32 i, j, k, n;
  int32 gdim, index;
  const double a = mg->param.cutoff;
  const double h = mg->param.spacing;
  const double h_1 = 1./h;

  /* a negative level count would wrap the malloc sizes below */
  if (nlevels < 0) return MGRID_FAIL;

  /* set inverse of grid spacing */
  mg->inv_spacing = h_1;

  /* set location of multi-grids origin with lattice label (0,0,0) */
  mg->origin.x = mg->param.center.x - 0.5 * mg->param.length;
  mg->origin.y = mg->param.center.y - 0.5 * mg->param.length;
  mg->origin.z = mg->param.center.z - 0.5 * mg->param.length;

#ifdef DEBUG_WATCH
  printf("origin = (%g, %g, %g)\n", mg->origin.x, mg->origin.y, mg->origin.z);
#endif

  /* allocate arrays of charge and potential lattices */
  mg->qgrid = (MgridLattice *) malloc(nlevels * sizeof(MgridLattice));
  mg->egrid = (MgridLattice *) malloc(nlevels * sizeof(MgridLattice));
  if (mg->qgrid == NULL || mg->egrid == NULL) return MGRID_FAIL;

  /* setup charge and potential lattices for each level, halving the
   * number of points per dimension from one level to the next */
  n = mg->param.nspacings;
  for (k = 0;  k < nlevels;  k++, n >>= 1) {
    if (n == 0 || ((n & 1) && n != 1)) return MGRID_FAIL;
    if (mgrid_lattice_init(&(mg->qgrid[k]))
        || mgrid_lattice_init(&(mg->egrid[k]))) {
      return MGRID_FAIL;
    }
    if (mgrid_lattice_setup(&(mg->qgrid[k]), sizeof(double),
          0, n-1, 0, n-1, 0, n-1)
        || mgrid_lattice_setup(&(mg->egrid[k]), sizeof(double),
          0, n-1, 0, n-1, 0, n-1)) {
      return MGRID_FAIL;
    }
  }
#ifdef DEBUG_SUPPORT
  ASSERT(n == 0);  /* i.e. top level has 1 grid point */
#else
  if (n != 1) return MGRID_FAIL;  /* i.e. top level omitted */
#endif

  /* allocate and compute scalings for each direct sum on each level;
   * nlevels may be 0 here (e.g. param.nlevels == 1 with the top level
   * omitted), so allocate at least one entry:  scaling[0] is always
   * written and malloc(0) may return a pointer to zero bytes */
  mg->scaling = (double *) malloc((nlevels > 0 ? nlevels : 1) * sizeof(double));
  if (mg->scaling == NULL) return MGRID_FAIL;
  mg->scaling[0] = 1.0;
  for (k = 1;  k < nlevels;  k++) {
    mg->scaling[k] = 0.5 * mg->scaling[k-1];
  }

  /* allocate space for direct sum weights */
  n = (int32) (2.0 * a * h_1);
  mg->gdsum_radius = n;
  if (mgrid_lattice_setup(&(mg->gdsum), sizeof(double), -n, n, -n, n, -n, n)) {
    return MGRID_FAIL;
  }

  /* compute direct sum weights (entries with r >= 2a stay zero) */
  if (mgrid_lattice_zero(&(mg->gdsum))) {
    return MGRID_FAIL;
  }
  gdim = mg->gdsum.ni;
  ASSERT(gdim == 2 * n + 1);
  gd = (double *)(mg->gdsum.data);
  for (k = -n;  k <= n;  k++) {
    for (j = -n;  j <= n;  j++) {
      for (i = -n;  i <= n;  i++) {
        r2 = (i*i +  j*j + k*k) * h*h;
        if (r2 >= 4*a*a) continue;
        if (mg->is_split_even_powers) {
          s = r2 / (a*a);
          t = r2 / (4*a*a);
          ASSERT(t <= s);
          index = (k * gdim + j) * gdim + i;
          ASSERT(&gd[index] == mgrid_lattice_elem(&(mg->gdsum), i, j, k));
          if (s <= 1) gamma(&gs, s, split);
          else gs = 1/sqrt(s);
          if (t < 1) {
            gamma(&gt, t, split);
            gd[index] = gs/a - gt/(2*a);
          }
          else {
            /* then gt=1/sqrt(t), which implies that gs/a - gt/(2*a) == 0 */
            gd[index] = 0;
          }
        }
        else {
          s = sqrt(r2 / (a*a));
          t = 0.5 * s;
          index = (k * gdim + j) * gdim + i;
          ASSERT(&gd[index] == mgrid_lattice_elem(&(mg->gdsum), i, j, k));
          if (s <= 1) gamma_odd(&gs, s, split);
          else gs = 1./s;
          if (t < 1) {
            gamma_odd(&gt, t, split);
            gd[index] = gs/a - gt/(2*a);
          }
          else {
            gd[index] = 0;
          }
        }
      }
    }
  }

  /* compute g_a(0) */
  s = 0;
  if (mg->is_split_even_powers) {
    gamma(&gs, s, split);
  }
  else {
    gamma_odd(&gs, s, split);
  }
  mg->g_zero = gs/a;

  /* compute self potential:  0.5 * g_a(0) * sum of squared charges */
  s = 0;
  for (n = 0;  n < natoms;  n++) {
    s += q[n] * q[n];
  }
  mg->u_self = 0.5 * mg->g_zero * s;

  return 0;
}
