/*
 * Copyright (C) 2004-2006 by Wei Wang.  All rights reserved.
 */


#include <string.h>
#include <math.h>
#include <stdlib.h>
#include <stdio.h>
#include "PmeKSpace.h"
#include "constant.h"
#include "helper.h"
#include "utilities.h"


/* Forward declaration: fills xp[0 .. K/2] with exp(i*i*fac).  The definition
 * appears further down in this file, after its first use. */
static void init_exp(MD_Double *xp, MD_Int K, MD_Double fac); 

/*
 * dftmod: squared magnitude of a discrete Fourier transform.
 *
 * For every k in [0, nfft) this accumulates
 *     re = sum_j bsp_arr[j] * cos(2*Pi*k*j/nfft)
 *     im = sum_j bsp_arr[j] * sin(2*Pi*k*j/nfft)
 * over j in [0, order) and stores re*re + im*im in bsp_mod[k].  The stored
 * value is the SQUARE of the modulus; no square root is taken.
 *
 *   bsp_mod  output, written at indices 0 .. nfft-1
 *   bsp_arr  input, read at indices 0 .. order-1 only
 *   nfft     transform length (also the divisor of the phase angle)
 *   order    number of leading input entries that enter the sums
 */
static void dftmod(MD_Double *bsp_mod, MD_Double *bsp_arr, MD_Int nfft,
		   MD_Int order) {
  const MD_Double two_pi = 2.0 * Pi;
  const MD_Double inv_n = 1.0/nfft;
  MD_Int freq;

  for (freq = 0; freq < nfft; ++freq) {
    MD_Double re = 0.;
    MD_Double im = 0.;
    MD_Int tap;

    for (tap = 0; tap < order; ++tap) {
      const MD_Double angle = two_pi * freq * tap * inv_n;
      re += bsp_arr[tap] * cos(angle);
      im += bsp_arr[tap] * sin(angle);
    }
    bsp_mod[freq] = re*re + im*im;
  }
}


/*
 * compute_b_moduli: fill bm[0 .. K-1] with the reciprocals of the squared
 * DFT magnitudes produced by dftmod().
 *
 * Steps:
 *   1. call compute_bspline1() with an all-zero fractional offset; its
 *      second argument (3*order entries) receives the values used below
 *      (presumably the B-spline weights, with the third argument receiving
 *      their derivatives -- confirm against compute_bspline1);
 *   2. copy the first `order` of those values into bm and zero bm[order..K-1];
 *   3. run dftmod() over bm with length K, into a temporary array;
 *   4. overwrite bm[i] with 1.0 / temporary[i] for every i in [0, K).
 *
 * NOTE(review): step 2 writes bm[0 .. order-1], so bm must hold at least
 * `order` entries even when order > K; callers appear to pass order <= K.
 * NOTE(review): step 4 has no guard against a zero (or tiny) DFT magnitude.
 *
 * All three temporaries are released before returning.
 */
static void compute_b_moduli(MD_Double *bm, MD_Int K, MD_Int order) 
{
  MD_Double frac[3] = {0.0};
  MD_Double *weights = my_calloc((size_t)3*order, sizeof(MD_Double), "M");
  MD_Double *derivs = my_calloc((size_t)3*order, sizeof(MD_Double), "dM");
  MD_Double *power = my_calloc((size_t)K, sizeof(MD_Double), "scratch");
  MD_Int n;

  compute_bspline1(frac, weights, derivs, order);

  /* zero-padded copy of the leading `order` values */
  for (n = 0; n < order; n++) {
    bm[n] = weights[n];
  }
  for (n = order; n < K; n++) {
    bm[n] = 0.0;
  }

  dftmod(power, bm, K, order);

  for (n = 0; n < K; n++) {
    bm[n] = 1.0/power[n];
  }

  free(weights);
  free(derivs);
  free(power);
}


/*
 * pmekspace_init: prepare a PmeKSpace object for the given grid.
 *
 * Stores a copy of *grid in pks->myGrid, allocates the six per-axis tables
 * through my_calloc() and fills the three bm tables via compute_b_moduli():
 *
 *   bm1/bm2/bm3     K1 / K2 / K3 entries
 *   exp1/exp2/exp3  K1/2+1 / K2/2+1 / K3/2+1 entries; only allocated here,
 *                   their contents are computed by other routines
 *
 * The tables are owned by pks and released by pmekspace_destroy().
 * Always returns OK.
 */
MD_Errcode pmekspace_init(struct PmeKSpace_Tag *pks, struct PmeGrid_Tag *grid)
{
  const MD_Int nx = grid->K1;
  const MD_Int ny = grid->K2;
  const MD_Int nz = grid->K3;
  const MD_Int spline_order = grid->order;

  pks->myGrid = *grid;

  /* one entry per grid point along each axis */
  pks->bm1 = my_calloc((size_t)nx, sizeof(MD_Double), "bm1");
  pks->bm2 = my_calloc((size_t)ny, sizeof(MD_Double), "bm2");
  pks->bm3 = my_calloc((size_t)nz, sizeof(MD_Double), "bm3");

  /* one entry per non-negative frequency index, 0 .. K/2 */
  pks->exp1 = my_calloc((size_t)(nx/2+1), sizeof(MD_Double), "exp1");
  pks->exp2 = my_calloc((size_t)(ny/2+1), sizeof(MD_Double), "exp2");
  pks->exp3 = my_calloc((size_t)(nz/2+1), sizeof(MD_Double), "exp3");

  compute_b_moduli(pks->bm1, nx, spline_order);
  compute_b_moduli(pks->bm2, ny, spline_order);
  compute_b_moduli(pks->bm3, nz, spline_order);

  return OK;
}


/*
 * pmekspace_destroy: release the six tables held by pks (exp1..exp3 and
 * bm1..bm3) and then clear the whole structure to zero bytes, so no stale
 * pointers remain in it.  Always returns OK.
 */
MD_Errcode pmekspace_destroy(struct PmeKSpace_Tag *pks)
{
  /* exponential tables */
  free(pks->exp3);
  free(pks->exp2);
  free(pks->exp1);

  /* b-moduli tables */
  free(pks->bm3);
  free(pks->bm2);
  free(pks->bm1);

  memset(pks, 0, sizeof *pks);
  return OK;
}



/*
 * pmekspace_charge_calc_enrgvir: one pass over the k-space grid that
 *   (a) scales q_arr in place by a per-point factor,
 *   (b) accumulates an energy into *Energy, and
 *   (c) accumulates six virial components into virial[0..5].
 *
 *   pks       supplies the grid sizes K1,K2,K3 and the bm1..bm3 tables; its
 *             exp1..exp3 tables are REBUILT here from lattice and ewaldcof
 *   q_arr     visited as consecutive pairs of entries, one pair per
 *             (k1, k2, k3) with k1 in [0,K1), k2 in [0,K2), k3 in [0,K3/2],
 *             i.e. 2*K1*K2*(K3/2+1) entries in total (presumably the real
 *             and imaginary parts of half-complex FFT output -- confirm
 *             against the caller); modified in place
 *   lattice   passed by value; must be orthogonal, otherwise a message is
 *             printed to stderr and the program exits with status 1
 *   ewaldcof  enters only through piob = Pi^2 / ewaldcof^2
 *   Energy    receives the accumulated energy on return
 *   virial    only entries 0..5 are zeroed and written, in the order
 *             xx, xy, xz, yy, yz, zz; entries 6..8 are left untouched
 */
void pmekspace_charge_calc_enrgvir(struct PmeKSpace_Tag *pks, 
				   MD_Double *q_arr, 
				   struct Lattice_Tag lattice,
				   const MD_Double ewaldcof, 
				   MD_Double *Energy, 
				   MD_Double virial[9])
{
  MD_Double energy = 0.0;
  MD_Double *exp1 = pks->exp1;
  MD_Double *exp2 = pks->exp2;
  MD_Double *exp3 = pks->exp3;
  const MD_Double *bm1 = pks->bm1;
  const MD_Double *bm2 = pks->bm2;
  const MD_Double *bm3 = pks->bm3;
  const MD_Int K1 = pks->myGrid.K1;
  const MD_Int K2 = pks->myGrid.K2;
  const MD_Int K3 = pks->myGrid.K3;
  MD_Int k1, k2, k3, ind;
  const MD_Double i_pi_volume = 1.0/(Pi * lattice_volume(&lattice));
  const MD_Double piob = (Pi * Pi) / (ewaldcof * ewaldcof);

  /* only the six independent components are used; see header comment */
  for (ind=0; ind<6; ind++) virial[ind] = 0.0;

  if ( lattice_is_orthogonal(&lattice) ) { /* the only case we consider */
    /* diagonal entries of the b1/b2/b3 lattice vectors, used as the
     * per-axis scale of the wave vector */
    MD_Double recipx = lattice.b1.x;
    MD_Double recipy = lattice.b2.y;
    MD_Double recipz = lattice.b3.z;
    /* exp?[i] = exp(-i*i * recip^2 * piob) for i = 0 .. K?/2 */
    init_exp(exp1, K1, -recipx*recipx*piob);
    init_exp(exp2, K2, -recipy*recipy*piob);
    init_exp(exp3, K3, -recipz*recipz*piob);
    ind = 0;   /* running index into q_arr, advanced by 2 per grid point */
    for ( k1=0; k1<K1; ++k1 ) {
      MD_Double m1, m11, b1, xp1;
      /* signed frequency index: upper half of the axis wraps to negative */
      MD_Int k1_s = k1<=K1/2 ? k1 : k1-K1;
      b1 = bm1[k1];
      m1 = k1_s*recipx;
      m11 = m1*m1;
      /* the 1/(Pi*volume) prefactor is folded into the first-axis factor */
      xp1 = i_pi_volume*exp1[abs(k1_s)];
      for ( k2=0; k2<K2; ++k2 ) {
        MD_Double m2, m22, b1b2, xp2;
        MD_Int k2_s = k2<=K2/2 ? k2 : k2-K2;
        b1b2 = b1*bm2[k2];
        m2 = k2_s*recipy;
        m22 = m2*m2;
        xp2 = exp2[abs(k2_s)]*xp1;
        if ( k1==0 && k2==0 ) {
          /* zero wave vector: msq would be 0, so the pair is cleared and
           * the k3 loop starts at 1 */
          q_arr[ind++] = 0.0;
          q_arr[ind++] = 0.0;
          k3 = 1;
        } else {
          k3 = 0;
        }
        /* the third axis only covers indices 0 .. K3/2 */
        for ( ; k3<=K3/2; ++k3 ) {
          MD_Double m3, m33, xp3, msq, imsq, vir, fac;
          MD_Double theta3, theta, q2, qr, qc, C;
          theta3 = bm3[k3] *b1b2;          /* product of the three bm factors */
          m3 = k3*recipz;
          m33 = m3*m3;
          xp3 = exp3[k3];
          /* read the pair BEFORE it is scaled below */
          qr = q_arr[ind]; qc=q_arr[ind+1];
          q2 = (qr*qr + qc*qc)*theta3;
          /* half weight on the k3 == 0 plane and, for even K3, on the
           * k3 == K3/2 plane (presumably because these planes have no
           * separate mirror image in the half-range k3 storage -- confirm) */
          if ( (k3 == 0) || ( k3 == K3/2 && ! (K3 & 1) ) ) q2 *= 0.5;
          msq = m11 + m22 + m33;           /* squared length of the wave vector */
          imsq = 1.0/msq;
          C = xp2*xp3*imsq;
          theta = theta3*C;
          /* in-place scaling of the stored pair */
          q_arr[ind++] *= theta;
          q_arr[ind++] *= theta;
          vir = -2.0*(piob+imsq);
          fac = q2*C;
          energy += fac;
          /* diagonal terms carry the extra 1.0; off-diagonal terms do not */
          virial[0] += fac*(1.0+vir*m11);
          virial[1] += fac*vir*m1*m2;
          virial[2] += fac*vir*m1*m3;
          virial[3] += fac*(1.0+vir*m22);
          virial[4] += fac*vir*m2*m3;
          virial[5] += fac*(1.0+vir*m33);
        }
      }
    }
  } else {
    fprintf(stderr, "lattice is not orthogonal, cannot handle it yet.\n");
    exit(1);
  }

  *Energy = energy;

  return;
}


/*
 * init_exp: tabulate xp[i] = exp(i*i*fac) for i = 0 .. K/2 inclusive,
 * i.e. K/2+1 entries are written.
 */
void init_exp(MD_Double *xp, MD_Int K, MD_Double fac) 
{
  const MD_Int last = K/2;
  MD_Int n = 0;

  while (n <= last) {
    xp[n] = exp(n*n*fac);
    ++n;
  }
}

/*
 *
 * dipole functions
 *
 */
/*
 * pmekspace_init_exp: rebuild the exp1/exp2/exp3 tables of pks.
 *
 * For an orthogonal lattice each table is filled by init_exp() with
 *     exp(-i*i * r*r * Pi*Pi / (ewaldcof*ewaldcof)),  i = 0 .. K/2,
 * where r is lattice.b1.x, lattice.b2.y or lattice.b3.z and K is the
 * matching grid size K1, K2 or K3 taken from pks->myGrid.
 *
 *   pks       tables exp1..exp3 are overwritten; nothing else is touched
 *   lattice   passed by value; only orthogonal lattices are supported
 *   ewaldcof  Ewald coefficient entering the exponent as shown above
 *
 * For a non-orthogonal lattice a message is written to stderr and the
 * tables are left unchanged.  Unlike the energy/potential routines in this
 * file, this function does NOT terminate the program in that case.
 */
void pmekspace_init_exp(struct PmeKSpace_Tag* pks, 
			struct Lattice_Tag lattice,
			const MD_Double ewaldcof)
{
  const MD_Double piob = (Pi * Pi) / (ewaldcof * ewaldcof);
  MD_Double recipx, recipy, recipz;

  if ( !lattice_is_orthogonal(&lattice) ) {
    /* The previous text ("cannot not handle unorthogonal lattice yet.")
     * was a double negative; use the wording of the sibling routines. */
    fprintf(stderr, "lattice is not orthogonal, cannot handle it yet.\n");
    return;
  }

  recipx = lattice.b1.x;
  recipy = lattice.b2.y;
  recipz = lattice.b3.z;

  init_exp(pks->exp1, pks->myGrid.K1, -recipx*recipx*piob);
  init_exp(pks->exp2, pks->myGrid.K2, -recipy*recipy*piob);
  init_exp(pks->exp3, pks->myGrid.K3, -recipz*recipz*piob);
}


/*
 * pmekspace_apply_potential: multiply every k-space entry of s_arr by
 *
 *     bm1[k1]*bm2[k2]*bm3[k3] * exp1[|k1'|]*exp2[|k2'|]*exp3[k3]
 *     ---------------------------------------------------------
 *                 Pi * volume * (m1^2 + m2^2 + m3^2)
 *
 * where k1', k2' are the signed frequency indices (indices above K/2 wrap
 * to negative values) and m = k' * recip along each axis.
 *
 *   pks      read-only; supplies grid sizes and the bm / exp tables.  The
 *            exp tables are used as they are and are not recomputed here.
 *   lattice  passed by value; must be orthogonal, otherwise a message goes
 *            to stderr and the program exits with status 1
 *   s_arr    visited as consecutive pairs, one pair per (k1, k2, k3) with
 *            k3 in [0, K3/2] (presumably real/imaginary parts of
 *            half-complex FFT data -- confirm against the caller);
 *            modified in place
 *
 * The pair at the zero wave vector (k1 = k2 = k3 = 0) is set to 0.0 rather
 * than scaled: the denominator vanishes there, and per the original author's
 * note (2004-09-08) it should be zero because s = Ih1*d or s = Ih0*q with
 * sum(q) = 0.
 */
void pmekspace_apply_potential(const struct PmeKSpace_Tag *pks, 
			       const struct Lattice_Tag lattice,
			       MD_Double *s_arr)
{
  const MD_Double *gauss_x = pks->exp1;
  const MD_Double *gauss_y = pks->exp2;
  const MD_Double *gauss_z = pks->exp3;
  const MD_Double *bmod_x = pks->bm1;
  const MD_Double *bmod_y = pks->bm2;
  const MD_Double *bmod_z = pks->bm3;
  const MD_Int nx = pks->myGrid.K1;
  const MD_Int ny = pks->myGrid.K2;
  const MD_Int nz = pks->myGrid.K3;
  const MD_Double i_pi_volume = 1.0/(Pi * lattice_volume(&lattice));
  MD_Double recipx, recipy, recipz;
  MD_Int ix, iy, iz;
  MD_Int pos = 0;   /* running index into s_arr, two entries per grid point */

  if ( !lattice_is_orthogonal(&lattice) ) {
    fprintf(stderr, "lattice is not orthogonal, cannot handle it yet.\n");
    exit(1);
  }

  recipx = lattice.b1.x;
  recipy = lattice.b2.y;
  recipz = lattice.b3.z;

  for (ix = 0; ix < nx; ix++) {
    const MD_Int sx = ix<=nx/2 ? ix : ix-nx;
    const MD_Double mx = sx*recipx;
    const MD_Double mx_sq = mx*mx;
    const MD_Double wx = bmod_x[ix]*i_pi_volume*gauss_x[abs(sx)];

    for (iy = 0; iy < ny; iy++) {
      const MD_Int sy = iy<=ny/2 ? iy : iy-ny;
      const MD_Double my = sy*recipy;
      const MD_Double mxy_sq = mx_sq + my*my;
      const MD_Double wxy = wx * bmod_y[iy] * gauss_y[abs(sy)];

      iz = 0;
      if (ix == 0 && iy == 0) {
        /* zero wave vector: clear the pair and skip it */
        s_arr[pos] = 0.0;
        s_arr[pos+1] = 0.0;
        pos += 2;
        iz = 1;
      }

      for ( ; iz <= nz/2; iz++) {
        const MD_Double mz = iz*recipz;
        const MD_Double msq = mxy_sq + mz*mz;
        const MD_Double scale = wxy * bmod_z[iz] * gauss_z[iz] / msq;

        s_arr[pos] *= scale;
        s_arr[pos+1] *= scale;
        pos += 2;
      }
    }
  }
}


/*
 * pmekspace_dipole_calc_enrgvir: dipole counterpart of the charge routine.
 *
 * Currently a plain forward to pmekspace_charge_calc_enrgvir() with the
 * same arguments (d_arr takes the place of q_arr), followed by a notice on
 * stdout that the additional dipole virial term is still missing.  The
 * notice is printed on every call; the function deliberately does not
 * abort the program.
 */
void pmekspace_dipole_calc_enrgvir(struct PmeKSpace_Tag *pks, 
				   MD_Double *d_arr, 
				   struct Lattice_Tag lattice,
				   MD_Double ewaldcof,
				   MD_Double *Energy, 
				   MD_Double virial[9])
{
  /* energy, virial[0..5] and the in-place scaling all come from here */
  pmekspace_charge_calc_enrgvir(pks, d_arr, lattice, ewaldcof, Energy, virial);

  printf("the extra term in virial has not implemented yet\n");
}
