/*
 * Copyright (C) 2004-2006 by Wei Wang.  All rights reserved.
 */

/* modified from NAMD */

#include <assert.h>
#include <string.h>
#include <stdio.h>

#ifdef HAVE_FFTW
#include "fftw.h"
#include "rfftw.h"
#endif

#include "pme.h"
#include "PmeBase.h"
#include "pme_recip.h"
#include "PmeRealSpace.h"
#include "PmeKSpace.h"
#include "utilities.h"
#include "helper.h"
#include "utilities.h"
#include "constant.h"
#include "unit.h"

#define DEBUG_EPSILON 1e-14


/*
 * Forward declaration: converts the N positions in p[], in place, into
 * PME grid coordinates using the lattice corner and the grid sizes
 * K1/K2/K3.  Defined further down in this file.
 */
static void scale_coordinates(MD_Dvec p[], const MD_Int N, 
			      const struct Lattice_Tag lattice, 
			      const struct PmeGrid_Tag grid);

/*
 * Forward declaration: subtracts the average of the natoms force vectors
 * from each of them, in place.  Defined further down in this file.
 */
static void remove_net_force(MD_Dvec f[], MD_Int natoms);



MD_Errcode pmerec_init(struct PmeRecip_Tag* rec, 
		       const struct PmeParams_Tag* params)
{
  struct PmeGrid_Tag* myGrid = &(rec->myGrid);
  struct PmeParams_Tag *simParams = &(rec->storePmeParams);
  MD_Int n[3], natoms; 

  fprintf(stderr, "pmerec_init\n");
  rec->storePmeParams = *params;

  myGrid->K1    = simParams->PMEGridSizeX;
  myGrid->K2    = simParams->PMEGridSizeY;
  myGrid->K3    = simParams->PMEGridSizeZ;
  myGrid->order = simParams->PMEInterpOrder;
  myGrid->dim2  = myGrid->K2;
  myGrid->dim3  = 2 * (myGrid->K3/2 + 1);   /* for rfftw */

  lattice_set(&(rec->lattice),
              simParams->cellBasisVector1, 
	      simParams->cellBasisVector2,
              simParams->cellBasisVector3, 
	      simParams->cellOrigin);

  natoms = simParams->natoms;
  rec->force = my_calloc((size_t)natoms, sizeof(MD_Dvec), "rec force");
  rec->scaled_pos = my_calloc((size_t)natoms, sizeof(MD_Dvec), "spos");
  rec->q_arr = my_calloc((size_t)myGrid->K1 * myGrid->dim2  * myGrid->dim3,
			 sizeof(MD_Double), "q_arr");

  if (simParams->has_induced_dipole) {
    rec->work_arr=my_calloc((size_t)natoms*3,sizeof(MD_Double),"work");
    rec->d_arr = my_calloc((size_t)myGrid->K1 * myGrid->dim2  * myGrid->dim3,
			   sizeof(MD_Double), "d_arr");
  } else {
    rec->work_arr=my_calloc((size_t)myGrid->K1 * myGrid->dim2 * myGrid->dim3,
			    sizeof(*rec->work_arr), "work");
  }

  rec->myRealSpace = my_calloc((size_t)1, sizeof *(rec->myRealSpace), "mRS");
  if (pmerealspace_init(rec->myRealSpace, &(rec->myGrid), natoms)) {
    fprintf(stderr, "failed to initialize myRealSpace\n");
    return MD_FAIL;
  }

  rec->myKSpace = my_calloc((size_t)1, sizeof *(rec->myKSpace), "myKSpace");
  if (pmekspace_init(rec->myKSpace, myGrid)) {
    fprintf(stderr, "failed to initilize myKSpace\n");
    return MD_FAIL;
  }

#ifdef HAVE_FFTW
  assert(sizeof(MD_Double) == sizeof (fftw_real));
#else
  NEED_FFTW(1);
#endif
  n[0] = myGrid->K1; 
  n[1] = myGrid->K2; 
  n[2] = myGrid->K3; 
  /*
  rec->forward_plan =
    rfftwnd_create_plan_specific (3, n, FFTW_REAL_TO_COMPLEX,
			FFTW_MEASURE | FFTW_IN_PLACE | FFTW_USE_WISDOM,
			rec->q_arr, 1, NULL, 0);
  rec->backward_plan = 
    rfftwnd_create_plan_specific(3, n, FFTW_COMPLEX_TO_REAL,
			FFTW_MEASURE | FFTW_IN_PLACE | FFTW_USE_WISDOM,
			rec->q_arr, 1, NULL, 0);
  */

#if 0
  /* note that we need 2 FFT arrays, so no specific */
  rec->forward_plan =
    rfftwnd_create_plan(3, n, FFTW_REAL_TO_COMPLEX,
			FFTW_MEASURE | FFTW_IN_PLACE | FFTW_USE_WISDOM);
  rec->backward_plan = 
    rfftwnd_create_plan(3, n, FFTW_COMPLEX_TO_REAL,
			FFTW_MEASURE | FFTW_IN_PLACE | FFTW_USE_WISDOM);
#endif

  /* note that we need 2 FFT arrays, so no specific */
#ifdef HAVE_FFTW
  rec->forward_plan =
    rfftwnd_create_plan(3, n, FFTW_REAL_TO_COMPLEX,
			FFTW_ESTIMATE | FFTW_IN_PLACE | FFTW_USE_WISDOM);
  rec->backward_plan = 
    rfftwnd_create_plan(3, n, FFTW_COMPLEX_TO_REAL,
			FFTW_ESTIMATE | FFTW_IN_PLACE | FFTW_USE_WISDOM);

  assert(NULL != rec->forward_plan && NULL != rec->backward_plan);
#else
  NEED_FFTW(1);
#endif

  return OK;
}


/*
 * Release everything acquired by pmerec_init().
 *
 * The whole teardown is always attempted: a failure reported by one of the
 * helper destructors is logged and remembered, but the remaining helper
 * object and the FFTW plans are still released so that nothing is leaked.
 * Freed pointers are reset to NULL.
 *
 * Returns OK when every step succeeded, MD_FAIL if pmerealspace_destroy or
 * pmekspace_destroy reported an error.
 */
MD_Errcode pmerec_destroy(struct PmeRecip_Tag* rec)
{
  MD_Errcode status = OK;

  free(rec->force);
  rec->force = NULL;
  free(rec->q_arr);
  rec->q_arr = NULL;
  free(rec->d_arr);       /* may be NULL; free(NULL) is a no-op */
  rec->d_arr = NULL;
  free(rec->work_arr);
  rec->work_arr = NULL;
  free(rec->scaled_pos);
  rec->scaled_pos = NULL;

  if (pmerealspace_destroy(rec->myRealSpace)) {
    fprintf(stderr, "failed to destroy myRealSpace\n");
    status = MD_FAIL;
  }
  free(rec->myRealSpace);
  rec->myRealSpace = NULL;

  if (pmekspace_destroy(rec->myKSpace)) {
    fprintf(stderr, "failed to destroy myKSpace\n");
    status = MD_FAIL;
  }
  free(rec->myKSpace);
  rec->myKSpace = NULL;

#ifdef HAVE_FFTW
  rfftwnd_destroy_plan(rec->backward_plan);
  rfftwnd_destroy_plan(rec->forward_plan);
#else
  NEED_FFTW(1);
#endif

  return status;
}


#if 0  /* this is the old way, it calculates virial  */
/*
 * Old variant of pmerec_charge_compute, disabled by the enclosing "#if 0".
 * Unlike the active variant it obtains the energy and the virial from
 * pmekspace_charge_calc_enrgvir and prints the energy unconditionally.
 * Kept for reference only.
 */
void pmerec_charge_compute(struct PmeRecip_Tag* rec, const MD_Dvec *pos,
			   const MD_Double *charge)
{
  MD_Int natoms = rec->storePmeParams.natoms;

  /* work on a private copy of the positions, converted to grid units */
  memcpy(rec->scaled_pos, pos, (size_t)natoms*sizeof(*pos));
  scale_coordinates(rec->scaled_pos, natoms, rec->lattice, rec->myGrid);
  pmerealspace_fill_bspline1(rec->myRealSpace, rec->scaled_pos);
  pmerealspace_fill_charges(rec->myRealSpace, rec->scaled_pos, charge, 
			    rec->q_arr);

  /* plan, howmany, in, istride, idist, out, ostride, odist */
  rfftwnd_real_to_complex(rec->forward_plan, 1, rec->q_arr, 1, 1, NULL, 0, 0);
  /* last 3 parameters are ignored for in-place fftw */

  /* energy goes to rec->potential, virial to rec->virial */
  pmekspace_charge_calc_enrgvir(rec->myKSpace, rec->q_arr, rec->lattice, 
				rec->storePmeParams.ewaldcof,
				&(rec->potential), rec->virial);
  printf("rec energy = %f\n", rec->potential);
  /* plan, howmany, in, istride, idist, out, ostride, odist */
  rfftwnd_complex_to_real(rec->backward_plan, 1, 
			  (fftw_complex *) rec->q_arr, 1, 1, NULL, 0, 0);

  pmerealspace_charge_calc_force(rec->myRealSpace, rec->q_arr, 
				 rec->scaled_pos, charge, rec->lattice,
				 rec->force);

  /* with PME_CONSERVE_ENERGY the net force is left in place (announced once) */
#ifdef PME_CONSERVE_ENERGY 
    { static int firstime = 1;
      if (firstime) {
	printf("do not substract out the extra force in PME !\n");
	firstime = 0;
      }
    }
#else
  remove_net_force(rec->force, natoms);
#endif

  return;
}
#else  /* this is the new way, it doesnot calculate virial */
/*
 * Reciprocal-space PME energy and forces for point charges.
 *
 *   rec    - PME object set up by pmerec_init(); receives rec->potential
 *            and rec->force.
 *   pos    - atom positions (natoms entries); copied, not modified.
 *   charge - per-atom charges.
 *
 * The charges are spread onto q_arr, a copy of that grid is saved in
 * work_arr, q_arr is transformed / multiplied by the k-space kernel /
 * transformed back, and the energy is taken as half the dot product of
 * the saved charge grid with the resulting grid.  No virial is computed.
 *
 * NOTE(review): a full grid is copied into work_arr; pmerec_init only
 * sizes work_arr as a grid when has_induced_dipole is not set -- confirm
 * this routine is never used in dipole mode.
 */
void pmerec_charge_compute(struct PmeRecip_Tag* rec, const MD_Dvec* pos,
			   const MD_Double* charge)
{
  MD_Int atom_count = rec->storePmeParams.natoms;
  MD_Int grid_cells = rec->myGrid.K1 * rec->myGrid.dim2 * rec->myGrid.dim3;
  MD_Dvec *grid_pos = rec->scaled_pos;
  struct PmeRealSpace_Tag *rs = rec->myRealSpace;
#ifdef HAVE_FFTW
  struct PmeKSpace_Tag *ks = rec->myKSpace;
#endif

  /* Positions are converted on a private copy so `pos` stays untouched. */
  memcpy(grid_pos, pos, (size_t)atom_count * sizeof(*pos));
  scale_coordinates(grid_pos, atom_count, rec->lattice, rec->myGrid);

  /* Spread the charges onto the grid. */
  pmerealspace_fill_bspline1(rs, grid_pos);
  pmerealspace_fill_charges(rs, grid_pos, charge, rec->q_arr);

  /* Keep the untransformed charge grid; the FFTs below run in place. */
  memcpy(rec->work_arr, rec->q_arr, grid_cells * sizeof(*rec->work_arr));

#ifdef HAVE_FFTW
  rfftwnd_one_real_to_complex(rec->forward_plan, rec->q_arr, NULL);
  pmekspace_init_exp(ks, rec->lattice, rec->storePmeParams.ewaldcof);
  pmekspace_apply_potential(ks, rec->lattice, rec->q_arr);
  rfftwnd_one_complex_to_real(rec->backward_plan,
			      (fftw_complex *) rec->q_arr, NULL);
#else
  NEED_FFTW(1);
#endif

  /*
   * Relies on the grid being zero at the padding positions.
   * NOTE(review): presumably this refers to the saved charge grid in
   * work_arr -- confirm.
   */
  rec->potential = 0.5 * DOT(rec->work_arr, rec->q_arr, grid_cells);
#ifdef DEBUG_PME
  printf("recE = %20.15f\n",rec->potential);
#endif

  pmerealspace_charge_calc_force(rs, rec->q_arr, grid_pos, charge,
				 rec->lattice, rec->force);

#ifdef PME_CONSERVE_ENERGY
  {
    /* Announce once that the net force is deliberately left in place. */
    static int announce = 1;
    if (announce) {
      printf("do not substract out the extra force in PME !\n");
      announce = 0;
    }
  }
#else
  remove_net_force(rec->force, atom_count);
#endif
}
#endif


/*
 * Convert N positions, in place, into PME grid coordinates.
 *
 * Each position has lattice.corner subtracted and is then projected onto
 * lattice.b1/b2/b3 scaled by the grid sizes K1/K2/K3, so every resulting
 * component is asserted to lie in [0, K) for its axis.
 *
 * NOTE(review): b1..b3 are presumably the reciprocal lattice vectors --
 * confirm against the Lattice_Tag definition.
 */
void scale_coordinates(MD_Dvec p[], const MD_Int N, 
		       const struct Lattice_Tag lattice, 
		       const struct PmeGrid_Tag grid)
{
  MD_Double nx = (MD_Double) grid.K1;
  MD_Double ny = (MD_Double) grid.K2;
  MD_Double nz = (MD_Double) grid.K3;
  MD_Dvec origin = lattice.corner;
  MD_Dvec row1, row2, row3;   /* lattice.b* pre-multiplied by the grid size */
  MD_Dvec rel;                /* position relative to the cell corner */
  MD_Int atom;

  MD_vec_mul(lattice.b1, nx, row1);
  MD_vec_mul(lattice.b2, ny, row2);
  MD_vec_mul(lattice.b3, nz, row3);

  for (atom = 0; atom < N; atom++) {
    MD_vec_substract(p[atom], origin, rel);

    p[atom].x = MD_vec_dot(row1, rel);
    p[atom].y = MD_vec_dot(row2, rel);
    p[atom].z = MD_vec_dot(row3, rel);

    /* No wrapping is done here: positions must already be inside the cell. */
    ASSERT(0 <= p[atom].x && p[atom].x < nx);
    ASSERT(0 <= p[atom].y && p[atom].y < ny);
    ASSERT(0 <= p[atom].z && p[atom].z < nz);
  }

#ifdef DEBUG_PME
  /* For debugging: output_vec_array(p, N, "scaled position"); */
#endif
}


/*
 * Remove the net force from a set of force vectors, in place.
 *
 * The forces are summed, the sum is multiplied by -1/natoms, and that
 * correction is added to every entry of f.
 */
void remove_net_force(MD_Dvec *f, MD_Int natoms)
{
  MD_Dvec net = {0.0, 0.0, 0.0};
  MD_Dvec shift;
  MD_Double scale = -1.0/(MD_Double) natoms;
  MD_Int k;

  /* accumulate the total force */
  for (k = 0; k < natoms; k++) {
    MD_vec_add(net, f[k], net);
  }
#ifdef DEBUG_PME
  printf("ftot=%f, %f, %f\n", net.x, net.y, net.z);
#endif

  /* per-atom correction: minus the average force */
  MD_vec_mul(net, scale, shift);
  for (k = 0; k < natoms; k++) {
    MD_vec_add(f[k], shift, f[k]);
  }
}


/***************************************************************************
 *
 * dipole computation functions
 *
 ***************************************************************************/

/*
 * Prepare `rec` for a dipole calculation at the given configuration.
 *
 *   rec    - PME object set up by pmerec_init().
 *   pos    - atom positions (natoms entries); copied, not modified.
 *   charge - per-atom charges.
 *
 * Stores the grid-scaled positions in rec->scaled_pos, fills the B-spline
 * data through pmerealspace_fill_bspline2, spreads the charges onto q_arr
 * and initialises the k-space exponential tables.
 */
void pmerec_dipole_setup(struct PmeRecip_Tag* rec, const MD_Dvec *pos,
			 const MD_Double *charge)
{
  MD_Int atom_count = rec->storePmeParams.natoms;
  MD_Dvec *grid_pos = rec->scaled_pos;
  struct PmeRealSpace_Tag *rs = rec->myRealSpace;

  memcpy(grid_pos, pos, atom_count * sizeof(*pos));
  scale_coordinates(grid_pos, atom_count, rec->lattice, rec->myGrid);

  pmerealspace_fill_bspline2(rs, grid_pos);
  pmerealspace_fill_charges(rs, grid_pos, charge, rec->q_arr);

#ifdef DEBUG_PME
  {
    /* Report the total gridded charge when it is not (numerically) zero. */
    MD_Double qsum = array_sum(rec->q_arr, rec->myGrid.K1 * rec->myGrid.dim2 *
			       rec->myGrid.dim3);
    if (fabs(qsum) > DEBUG_EPSILON) {
      printf("qsum=%g\n", qsum);
    }
  }
#endif

  pmekspace_init_exp(rec->myKSpace, rec->lattice,
		     rec->storePmeParams.ewaldcof);
}


/*
 * Apply the reciprocal-space contribution to the pseudo-residual.
 *
 *   rec       - PME object prepared by pmerec_dipole_setup(); its d_arr and
 *               work_arr buffers are overwritten.
 *   dipole    - dipole components handed to pmerealspace_fill_dipoles.
 *   flag      - when equal to 1 the charge grid q_arr is added to the
 *               dipole grid before the transform.
 *   pseudores - array of 3*natoms values; the ungridded result held in
 *               rec->work_arr is subtracted from it element by element.
 */
void pmerec_add_pseudores(const struct PmeRecip_Tag* rec,
			  const MD_Double *dipole, 
			  const MD_Int flag, MD_Double *pseudores)
{
  MD_Int ncells = rec->myGrid.K1 * rec->myGrid.dim2 * rec->myGrid.dim3;
  MD_Int ncomp = 3 * rec->storePmeParams.natoms;
  MD_Int k;

  pmerealspace_fill_dipoles(rec->myRealSpace, rec->scaled_pos, rec->lattice,
			    dipole, rec->d_arr);
#ifdef DEBUG_PME
  {
    MD_Double dsum = array_sum(rec->d_arr, ncells);
    if (fabs(dsum) > DEBUG_EPSILON) {
      printf("dsum=%g\n", dsum);
    }
  }
#endif

  if (flag == 1) {
    for (k = 0; k < ncells; k++) {
      rec->d_arr[k] += rec->q_arr[k];
    }
  }

#ifdef HAVE_FFTW
  /* forward FFT, k-space kernel, backward FFT -- all in place on d_arr */
  rfftwnd_one_real_to_complex(rec->forward_plan, rec->d_arr, NULL);
  pmekspace_apply_potential(rec->myKSpace, rec->lattice, rec->d_arr);
  rfftwnd_one_complex_to_real(rec->backward_plan,
			      (fftw_complex *)rec->d_arr, NULL);

  /* interpolate the grid back to the atoms; the result lands in work_arr */
  pmerealspace_ungrid_dipoles(rec->myRealSpace, rec->scaled_pos, rec->lattice,
			      rec->d_arr, rec->work_arr);
#else
  NEED_FFTW(1);
#endif

  for (k = 0; k < ncomp; k++) {
    pseudores[k] -= rec->work_arr[k];
  }
}


/*
 * Reciprocal-space energy and forces for charges plus dipoles.
 *
 *   rec        - PME object prepared by pmerec_dipole_setup(); receives
 *                rec->potential and rec->force.
 *   charge     - per-atom charges.
 *   dipole     - dipole components.
 *   calc_d_arr - non-zero: rebuild d_arr here (dipole grid plus charge
 *                grid, transformed through k-space and back);
 *                zero: d_arr is used as it currently stands, i.e. as
 *                already computed while solving for the dipoles.
 *
 * The lattice is asserted to be orthogonal.
 */
void pmerec_dipole_calc(struct PmeRecip_Tag* rec, const MD_Double *charge, 
			const MD_Double *dipole, const MD_Int calc_d_arr)
{
  MD_Int cell;
  MD_Int ncells = rec->myGrid.K1 * rec->myGrid.dim2 * rec->myGrid.dim3;

  ASSERT(lattice_is_orthogonal(&(rec->lattice)));

  if (calc_d_arr) {
    pmerealspace_fill_dipoles(rec->myRealSpace, rec->scaled_pos, rec->lattice,
			      dipole, rec->d_arr);
#ifdef DEBUG_PME
    {
      MD_Double dsum = array_sum(rec->d_arr, ncells);
      if (fabs(dsum) > DEBUG_EPSILON) {
        printf("dsum=%g\n", dsum);
      }
    }
#endif

    /* combine the dipole grid with the charge grid */
    for (cell = 0; cell < ncells; cell++) {
      rec->d_arr[cell] += rec->q_arr[cell];
    }

#ifdef HAVE_FFTW
    /*
     * forward FFT, k-space kernel, backward FFT, in place on d_arr
     * (the energy/virial variant pmekspace_dipole_calc_enrgvir is
     * deliberately not used here)
     */
    rfftwnd_one_real_to_complex(rec->forward_plan, rec->d_arr, NULL);
    pmekspace_apply_potential(rec->myKSpace, rec->lattice, rec->d_arr);
    rfftwnd_one_complex_to_real(rec->backward_plan,
				(fftw_complex *)rec->d_arr, NULL);
#else
    NEED_FFTW(1);
#endif
  }

  /* energy: half the dot product of the charge grid with d_arr */
  rec->potential = 0.5 * DOT(rec->q_arr, rec->d_arr, ncells);
#ifdef DEBUG_PME
  printf("recE = %20.15f\n",rec->potential);
#endif

  pmerealspace_dipole_calc_force(rec->myRealSpace, rec->d_arr,
				 rec->scaled_pos, charge, dipole,
				 rec->lattice, rec->force);

#ifdef PME_CONSERVE_ENERGY
  {
    /* Announce once that the net force is deliberately left in place. */
    static int announce = 1;
    if (announce) {
      printf("do not substract out the extra force in PME !\n");
      announce = 0;
    }
  }
#else
  remove_net_force(rec->force, rec->storePmeParams.natoms);
#endif
}


