/*
 * Copyright (C) 2004-2006 by Wei Wang.  All rights reserved.
 */

/***************************************************************
 *
 * compute the matrix elements of G0, G1, G2 and G3, then output it into file
 *
 ***************************************************************/
#include <stdlib.h>
#include <stdio.h>
#include <assert.h>
#include <errno.h>
#include <math.h>
#include <string.h>
#include "helper.h"
#include "utilities.h"
#include "explicitG.h"
#include "standEwald_dir.h"

/*
 * Module-owned result matrices (N = number of atoms).  Being static, they
 * start out NULL; explicitG()/explicitG2() allocate them and hand the
 * pointers to the caller, and explicitG_destroy() frees them.
 */
static MD_Double* G0tot;  /* N x N */
static MD_Double* G1tot;  /* 3N x N */
static MD_Double* G2tot;  /* 3N x 3N, direct + reciprocal (explicitG) */
static MD_Double* G3tot;  /* N(N-1)/2 packed 3x3x3 blocks (explicitG) */
static MD_Double* G2dir;  /* 3N x 3N, direct part (explicitG2) */
static MD_Double* G2rec;  /* 3N x 3N, reciprocal part + diagonal (explicitG2) */

/* array bounds spelled exactly as in the definition further down */
static void 
compute_g_rec(const MD_Dvec *reclatt, const MD_Double *recU,
	      const MD_Int nreclatt, const MD_Dvec r, 
	      MD_Double *g0, MD_Dvec *g1, MD_Double g2[NIND2], MD_Double g3[NIND3]);


/*
 * Offsets inside one flattened 3x3x3 block of G3, with A, B, C standing for
 * the x, y, z components: element (a,b,c) sits at 9*a + 3*b + c
 * (AAA = 0, ABA = 3, BAA = 9, ..., CCC = 26).
 */
enum {
AAA=0, AAB, AAC, ABA, ABB, ABC, ACA, ACB, ACC, 
  BAA, BAB, BAC, BBA, BBB, BBC, BCA, BCB, BCC,
  CAA, CAB, CAC, CBA, CBB, CBC, CCA, CCB, CCC
};


void explicitG(const struct standEwald_Tag* se, 
	       const char *filename,
	       MD_Double** g0tot, 
	       MD_Double** g1tot,
	       MD_Double** g2tot,
	       MD_Double** g3tot)
{
  const MD_Int natoms = se->natoms;
  const MD_Int nG0 = natoms;
  const MD_Int nG2 = 3 * natoms;
  const MD_Dvec *pos = se->ppos;
  const MD_Dvec systemsize = se->systemsize;
  MD_Double g0dir, g2dir[NIND2], g3dir[NIND3];
  MD_Double g0rec, g2rec[NIND2], g3rec[NIND3];
  MD_Double g0, *g3;
  MD_Dvec g1dir, g1rec, g1;
  MD_Dvec r_ij;
  MD_Int i, j;
  MD_Int ind_ij, ind_ji;
  FILE *outfile;

  assert(0==systemsize.x - systemsize.y && 0 == systemsize.x - systemsize.z);

  G0tot = my_calloc((size_t) (natoms*natoms), sizeof(MD_Double), "G0tot");
  G1tot = my_calloc((size_t) (natoms*natoms*3), sizeof(MD_Double), "G1tot");
  G2tot = my_calloc((size_t) (nG2*nG2), sizeof(MD_Double), "G2tot");
  G3tot = my_calloc((size_t) (natoms*(natoms-1)/2*27), sizeof(MD_Double), 
		    "G3tot");

  /* compute diagonal elements of G2 matrix */
  r_ij.x = r_ij.y = r_ij.z = 0.0;
  compute_g_rec(se->reclatt, se->recU, se->nreclatt,
		r_ij, &g0rec, &g1rec, g2rec, g3rec); /* diagonal elements */
  if (MD_vec_dot(g1rec, g1rec) > 1e-20) { /* 0 in theory */
    fprintf(stderr, "g1recdiag != 0, %g, %g, %g\n", g1rec.x, g1rec.y, g1rec.z);
  }
  if (DOT(g3rec, g3rec, NIND3) > 1e-20) { /* 0 in theory */
    MD_Int k;
    fprintf(stderr, "g3recdiag != 0\n");
    for (k = 0; k < NIND3; k++) {
      fprintf(stderr, "  %d %f\n", k, g3rec[k]);
    }
  }
  /* the direct, reciprocal sum of G1, G3 diagonal blocks are zero */
  for (i = 0; i < natoms; i++) { 
    G0tot[i*natoms+i] = g0rec + se->diagdirG0;
    ind_ij = (i*3)*nG2 + (i*3);
    G2tot[ind_ij+X] = se->diagdirG2xx + g2rec[XX];
    G2tot[ind_ij+Y] = g2rec[XY];
    G2tot[ind_ij+Z] = g2rec[XZ];
    ind_ij += nG2;
    G2tot[ind_ij+X] = g2rec[YX];
    G2tot[ind_ij+Y] = se->diagdirG2yy + g2rec[YY];
    G2tot[ind_ij+Z] = g2rec[YZ];
    ind_ij += nG2;
    G2tot[ind_ij+X] = g2rec[ZX];
    G2tot[ind_ij+Y] = g2rec[ZY];
    G2tot[ind_ij+Z] = se->diagdirG2zz + g2rec[ZZ];
  }

  /* nondiagonal elements, we WANT to compute the result in this direct way,
   * (although O(N^2) cost). and we WANT to avoid using any data structure
   * in standEwald module.  */
  g3 = G3tot;
  for (i = 0; i < natoms; i++) {
    for (j = 0; j < i; j++) {
      MD_vec_substract(pos[i], pos[j], r_ij);
      SIMPLE_BOUND_VEC(r_ij, systemsize);
      assert (r_ij.x < 0.5*systemsize.x && r_ij.x > -0.5*systemsize.x &&
              r_ij.y < 0.5*systemsize.y && r_ij.y > -0.5*systemsize.y &&
              r_ij.z < 0.5*systemsize.z && r_ij.z > -0.5*systemsize.z);
      dipole_compute_gbar_dir(se, r_ij, i, j, &g0dir, &g1dir, g2dir, 
			      g3dir);
      compute_g_rec(se->reclatt, se->recU, se->nreclatt,
		    r_ij, &g0rec, &g1rec, g2rec, g3rec);
      g0 = g0dir + g0rec;
      G0tot[i*natoms+j] = G0tot[j*natoms+i] = g0;
      ind_ij = i*3*natoms + j;
      ind_ji = j*3*natoms + i;
      MD_vec_add(g1dir, g1rec, g1);
      ind_ij = i*3*natoms+j;        ind_ji = j*3*natoms+i;
      G1tot[ind_ij      ] = g1.x; G1tot[ind_ji      ] = -g1.x; 
      G1tot[ind_ij+  nG0] = g1.y; G1tot[ind_ji+  nG0] = -g1.y;
      G1tot[ind_ij+2*nG0] = g1.z; G1tot[ind_ji+2*nG0] = -g1.z;
      ind_ij = i*3*nG2 + j*3;
      ind_ji = j*3*nG2 + i*3;
      G2tot[ind_ij+X] = G2tot[ind_ji+X] = g2dir[XX] + g2rec[XX]; 
      G2tot[ind_ij+Y] = G2tot[ind_ji+Y] = g2dir[XY] + g2rec[XY];
      G2tot[ind_ij+Z] = G2tot[ind_ji+Z] = g2dir[XZ] + g2rec[XZ];
      ind_ij += nG2;
      ind_ji += nG2;
      G2tot[ind_ij+X] = G2tot[ind_ji+X] = g2dir[YX] + g2rec[YX]; 
      G2tot[ind_ij+Y] = G2tot[ind_ji+Y] = g2dir[YY] + g2rec[YY]; 
      G2tot[ind_ij+Z] = G2tot[ind_ji+Z] = g2dir[YZ] + g2rec[YZ];
      ind_ij += nG2; 
      ind_ji += nG2;
      G2tot[ind_ij+X] = G2tot[ind_ji+X] = g2dir[ZX] + g2rec[ZX]; 
      G2tot[ind_ij+Y] = G2tot[ind_ji+Y] = g2dir[ZY] + g2rec[ZY]; 
      G2tot[ind_ij+Z] = G2tot[ind_ji+Z] = g2dir[ZZ] + g2rec[ZZ];
      g3[AAA] = g3dir[XXX] + g3rec[XXX];
      g3[AAB] = g3dir[XXY] + g3rec[XXY];
      g3[AAC] = g3dir[XXZ] + g3rec[XXZ];
      g3[ABA] = g3dir[XYX] + g3rec[XYX];
      g3[ABB] = g3dir[XYY] + g3rec[XYY];
      g3[ABC] = g3dir[XYZ] + g3rec[XYZ];
      g3[ACA] = g3dir[XZX] + g3rec[XZX];
      g3[ACB] = g3dir[XZY] + g3rec[XZY];
      g3[ACC] = g3dir[XZZ] + g3rec[XZZ];
      g3[BAA] = g3dir[YXX] + g3rec[YXX];
      g3[BAB] = g3dir[YXY] + g3rec[YXY];
      g3[BAC] = g3dir[YXZ] + g3rec[YXZ];
      g3[BBA] = g3dir[YYX] + g3rec[YYX];
      g3[BBB] = g3dir[YYY] + g3rec[YYY];
      g3[BBC] = g3dir[YYZ] + g3rec[YYZ];
      g3[BCA] = g3dir[YZX] + g3rec[YZX];
      g3[BCB] = g3dir[YZY] + g3rec[YZY];
      g3[BCC] = g3dir[YZZ] + g3rec[YZZ];
      g3[CAA] = g3dir[ZXX] + g3rec[ZXX];
      g3[CAB] = g3dir[ZXY] + g3rec[ZXY];
      g3[CAC] = g3dir[ZXZ] + g3rec[ZXZ];
      g3[CBA] = g3dir[ZYX] + g3rec[ZYX];
      g3[CBB] = g3dir[ZYY] + g3rec[ZYY];
      g3[CBC] = g3dir[ZYZ] + g3rec[ZYZ];
      g3[CCA] = g3dir[ZZX] + g3rec[ZZX];
      g3[CCB] = g3dir[ZZY] + g3rec[ZZY];
      g3[CCC] = g3dir[ZZZ] + g3rec[ZZZ];
      g3 += 27;
    }
  }

  if (NULL != filename) { /* sometimes we only need G2, not output */
    size_t noutput;
    MD_String fname;
    if (sprintf(fname, "%s.G0", filename) + 1 >
	(MD_Int)sizeof(MD_String)) {
      fprintf(stderr, "%s too long\n", filename);
      return;
    }
    outfile = fopen(fname, "wb");
    assert(NULL != outfile);
    noutput = fwrite(G0tot, sizeof(MD_Double), (size_t)(natoms*natoms),
		     outfile);
    if ((size_t) nG0 * nG0 != noutput || ferror(outfile)) {
      perror("cannot output G0 matrix");
      return;
    }
    if (fclose(outfile)) {
      perror("cannot close G0 matrix file\n");
    }

    if (sprintf(fname, "%s.G1", filename) + 1 > 
	(MD_Int)sizeof(MD_String)) {
      fprintf(stderr, "%s too long\n", filename);
      return;      
    }
    outfile = fopen(fname, "wb");
    assert(NULL != outfile);
    noutput = fwrite(G1tot, sizeof(MD_Double), (size_t)(natoms*natoms*3),
		     outfile);
    if ((size_t) natoms * natoms * 3 != noutput || ferror(outfile)) {
      perror("cannot output G0 matrix");
      return;
    }
    if (fclose(outfile)) {
      perror("cannot close G1 matrix");
    }

    if (sprintf(fname, "%s.G2", filename) + 1 > 
	(MD_Int)sizeof(MD_String)) {
      fprintf(stderr, "%s too long\n", filename);
      return;      
    }
    outfile = fopen(fname, "wb");
    assert(NULL != outfile);
    noutput = fwrite(G2tot, sizeof(MD_Double), (size_t)(nG2*nG2), 
		     outfile);
    if ((size_t)(nG2*nG2) != noutput || ferror(outfile)) {
      perror("cannot output G2 matrix");
      return;
    }
    if (fclose(outfile)) {
      perror("cannot close file storing matrix G2");
      return;
    }

    if (sprintf(fname, "%s.G3", filename) + 1 > 
	(MD_Int)sizeof(MD_String)) {
      fprintf(stderr, "%s too long\n", filename);
      return;      
    }
    outfile = fopen(fname, "wb");
    assert(NULL != outfile);
    noutput = fwrite(G3tot, sizeof(MD_Double), 
		     (size_t) (natoms*(natoms-1)/2*27), outfile);
    if ((size_t)(natoms*(natoms-1)/2*27) != noutput || ferror(outfile)) {
      perror("cannot output G3 matrix");
      return;
    }
    if (fclose(outfile)) {
      perror("cannot close file storing matrix G2");
      return;
    }

  }

  *g0tot = G0tot;
  *g1tot = G1tot;
  *g2tot = G2tot;
  *g3tot = G3tot;
}


void explicitG2(const struct standEwald_Tag* se, 
	       const char *filename,
	       MD_Double** ig0tot, 
	       MD_Double** ig1tot,
	       MD_Double** ig2dir,
	       MD_Double** ig2rec)
{
  const MD_Int natoms = se->natoms;
  const MD_Int nG0 = natoms;
  const MD_Int nG2 = 3 * natoms;
  const MD_Dvec *pos = se->ppos;
  const MD_Dvec systemsize = se->systemsize;
  MD_Double g0dir, g2dir[NIND2], g3dir[NIND3];
  MD_Double g0rec, g2rec[NIND2], g3rec[NIND3];
  MD_Double g0;
  MD_Dvec g1dir, g1rec, g1;
  MD_Dvec r_ij;
  MD_Int i, j;
  MD_Int ind_ij, ind_ji;
  FILE *outfile;

  G0tot = my_calloc((size_t) (natoms*natoms), sizeof(MD_Double), "G0tot");
  G1tot = my_calloc((size_t) (natoms*natoms*3), sizeof(MD_Double), "G1tot");
  G2dir = my_calloc((size_t) (nG2*nG2), sizeof(MD_Double), "G2dir");
  G2rec = my_calloc((size_t) (nG2*nG2), sizeof(MD_Double), "G2rec");


  /* compute diagonal elements of G2 matrix */
  r_ij.x = r_ij.y = r_ij.z = 0.0;
  compute_g_rec(se->reclatt, se->recU, se->nreclatt,
		r_ij, &g0rec, &g1rec, g2rec, g3rec); /* diagonal elements */
  if (MD_vec_dot(g1rec, g1rec) > 1e-20) { /* 0 in theory */
    fprintf(stderr, "g1recdiag != 0, %g, %g, %g\n", g1rec.x, g1rec.y, g1rec.z);
  }
  if (DOT(g3rec, g3rec, NIND3) > 1e-20) { /* 0 in theory */
    MD_Int k;
    fprintf(stderr, "g3recdiag != 0\n");
    for (k = 0; k < NIND3; k++) {
      fprintf(stderr, "  %d %f\n", k, g3rec[k]);
    }
  }
  /* the direct, reciprocal sum of G1, G3 diagonal blocks are zero */
  for (i = 0; i < natoms; i++) { 
    G0tot[i*natoms+i] = g0rec + se->diagdirG0;
    ind_ij = (i*3)*nG2 + (i*3);
    G2rec[ind_ij+X]=g2rec[XX] + se->diagdirG2xx; 
    G2rec[ind_ij+Y]=g2rec[XY]; 
    G2rec[ind_ij+Z]=g2rec[XZ];
    ind_ij += nG2;
    G2rec[ind_ij+X]=g2rec[YX]; 
    G2rec[ind_ij+Y]=g2rec[YY] + se->diagdirG2yy; 
    G2rec[ind_ij+Z]=g2rec[YZ]; 
    ind_ij += nG2;
    G2rec[ind_ij+X]=g2rec[ZX]; 
    G2rec[ind_ij+Y]=g2rec[ZY]; 
    G2rec[ind_ij+Z]=g2rec[ZZ] + se->diagdirG2zz; 
  }

  /* nondiagonal elements, we WANT to compute the result in this direct way,
   * (although O(N^2) cost). and we WANT to avoid using any data structure
   * in standard Ewald module.  */
  for (i = 0; i < natoms; i++) {
    for (j = 0; j < i; j++) {
      MD_vec_substract(pos[i], pos[j], r_ij);
      SIMPLE_BOUND_VEC(r_ij, systemsize);
      assert (r_ij.x < 0.5*systemsize.x && r_ij.x > -0.5*systemsize.x &&
	      r_ij.y < 0.5*systemsize.y && r_ij.y > -0.5*systemsize.y &&
	      r_ij.z < 0.5*systemsize.z && r_ij.z > -0.5*systemsize.z);
      dipole_compute_gbar_dir(se, r_ij, i, j, &g0dir, &g1dir, g2dir, 
			      g3dir);
      compute_g_rec(se->reclatt, se->recU, se->nreclatt,
		    r_ij, &g0rec, &g1rec, g2rec, g3rec);
      g0 = g0dir + g0rec;
      G0tot[i*natoms+j] = G0tot[j*natoms+i] = g0;
      ind_ij = i*3*natoms + j;
      ind_ji = j*3*natoms + i;
      MD_vec_add(g1dir, g1rec, g1);
      ind_ij = i*3*natoms+j;        ind_ji = j*3*natoms+i;
      G1tot[ind_ij      ] = g1.x; G1tot[ind_ji      ] = -g1.x; 
      G1tot[ind_ij+  nG0] = g1.y; G1tot[ind_ji+  nG0] = -g1.y;
      G1tot[ind_ij+2*nG0] = g1.z; G1tot[ind_ji+2*nG0] = -g1.z;
      ind_ij = i*3*nG2 + j*3;
      ind_ji = j*3*nG2 + i*3;
      G2dir[ind_ij+X] = G2dir[ind_ji+X] = g2dir[XX]; 
      G2dir[ind_ij+Y] = G2dir[ind_ji+Y] = g2dir[XY];
      G2dir[ind_ij+Z] = G2dir[ind_ji+Z] = g2dir[XZ];
      ind_ij += nG2;
      ind_ji += nG2;
      G2dir[ind_ij+X] = G2dir[ind_ji+X] = g2dir[YX]; 
      G2dir[ind_ij+Y] = G2dir[ind_ji+Y] = g2dir[YY]; 
      G2dir[ind_ij+Z] = G2dir[ind_ji+Z] = g2dir[YZ];
      ind_ij += nG2; 
      ind_ji += nG2;
      G2dir[ind_ij+X] = G2dir[ind_ji+X] = g2dir[ZX]; 
      G2dir[ind_ij+Y] = G2dir[ind_ji+Y] = g2dir[ZY]; 
      G2dir[ind_ij+Z] = G2dir[ind_ji+Z] = g2dir[ZZ];

      ind_ij = i*3*nG2 + j*3;
      ind_ji = j*3*nG2 + i*3;
      G2rec[ind_ij+X] = G2rec[ind_ji+X] = g2rec[XX]; 
      G2rec[ind_ij+Y] = G2rec[ind_ji+Y] = g2rec[XY];
      G2rec[ind_ij+Z] = G2rec[ind_ji+Z] = g2rec[XZ];
      ind_ij += nG2;
      ind_ji += nG2;
      G2rec[ind_ij+X] = G2rec[ind_ji+X] = g2rec[YX]; 
      G2rec[ind_ij+Y] = G2rec[ind_ji+Y] = g2rec[YY]; 
      G2rec[ind_ij+Z] = G2rec[ind_ji+Z] = g2rec[YZ];
      ind_ij += nG2; 
      ind_ji += nG2;
      G2rec[ind_ij+X] = G2rec[ind_ji+X] = g2rec[ZX]; 
      G2rec[ind_ij+Y] = G2rec[ind_ji+Y] = g2rec[ZY]; 
      G2rec[ind_ij+Z] = G2rec[ind_ji+Z] = g2rec[ZZ];
    }
  }

  if (NULL != filename) { /* sometimes we only need G2, not output */
    size_t noutput;
    MD_String fname;
    if (sprintf(fname, "%s.G2dir", filename) + 1 > 
	(MD_Int)sizeof(MD_String)) {
      fprintf(stderr, "%s too long\n", filename);
      return;      
    }
    outfile = fopen(fname, "wb");
    assert(NULL != outfile);
    noutput = fwrite(G2dir, sizeof(MD_Double), (size_t)(nG2*nG2), 
		     outfile);
    if ((size_t)(nG2*nG2) != noutput || ferror(outfile)) {
      perror("cannot output G2dir matrix");
      return;
    }
    if (fclose(outfile)) {
      perror("cannot close file storing matrix G2dir");
      return;
    }
    if (sprintf(fname, "%s.G2rec", filename) + 1 > 
	(MD_Int)sizeof(MD_String)) {
      fprintf(stderr, "%s too long\n", filename);
      return;      
    }
    outfile = fopen(fname, "wb");
    assert(NULL != outfile);
    noutput = fwrite(G2rec, sizeof(MD_Double), (size_t)(nG2*nG2), 
		     outfile);
    if ((size_t)(nG2*nG2) != noutput || ferror(outfile)) {
      perror("cannot output G2rec matrix");
      return;
    }
    if (fclose(outfile)) {
      perror("cannot close file storing matrix G2rec");
      return;
    }

  }

  *ig0tot = G0tot;
  *ig1tot = G1tot;
  *ig2dir = G2dir;
  *ig2rec = G2rec;
}


/*
 * Debug helper: print the product G2tot * d, one "<row> <value>" line per
 * row, to stdout.
 *
 * `matrixsize` is used both as the number of rows printed and as the row
 * stride of G2tot.  It is presumably 3N, the row length of the matrix
 * built by explicitG(); this is not checked here (TODO confirm at the
 * callers).  `d` must hold at least `matrixsize` values.  If G2tot has not
 * been allocated, an error is reported on stderr and nothing is printed.
 */
void explicitG_compute_G2d(const MD_Double* d, const int matrixsize)
{
  const MD_Double *row = G2tot;
  MD_Int i, j;

  if (NULL == row) {  /* e.g. only explicitG2() was called */
    fprintf(stderr, "explicitG_compute_G2d: G2 matrix not available\n");
    return;
  }

  printf("G2*d=\n");
  for (i = 0; i < matrixsize; i++) {
    MD_Double sum = 0.0;
    for (j = 0; j < matrixsize; j++) {
      sum += row[j] * d[j];
    }
    printf("%d %20.15g\n", i, sum);
    row += matrixsize;
  }
}


/*
 * Free every matrix allocated by explicitG()/explicitG2() and reset the
 * module pointers to NULL, so calling this twice cannot double free.
 * Pointers previously handed out by explicitG()/explicitG2() become
 * invalid.
 */
void explicitG_destroy(void) 
{
  /* free(NULL) is a no-op, so no guards are needed */
  free(G0tot);  G0tot = NULL;
  free(G1tot);  G1tot = NULL;
  free(G2tot);  G2tot = NULL;
  free(G3tot);  G3tot = NULL;
  free(G2dir);  G2dir = NULL;
  free(G2rec);  G2rec = NULL;
}


/*
 * Reciprocal-space part of the pair kernel and its first three derivatives
 * at separation r.  The sums run over the nreclatt lattice vectors
 * k = reclatt[n], with weights U(k) = recU[n]:
 *
 *   *g0 =  sum_k U(k) cos(k.r)
 *   *g1 = -sum_k U(k) sin(k.r) k          (=  grad g0)
 *    g2 =  sum_k U(k) cos(k.r) k k        (= -grad grad g0)
 *    g3 = -sum_k U(k) sin(k.r) k k k      (= -grad grad grad g0)
 *
 * Only the real part of sum_k U(k) exp(i k.r) is formed ("reversal symmetry
 * is already considered" -- presumably in how reclatt/recU were built by
 * the caller; not visible here).  U(k) is expected to be
 *
 *          4*pi    exp(- k^2 / (4*beta^2) )
 *         ------ -------------------------
 *           V              k^2
 *
 * but it is taken as given and not checked here.
 *
 * In g2, only the XX, XY, XZ, YY, YZ and ZZ slots are accumulated; in g3,
 * only the ten slots XXX ... ZZZ.  Every other slot of the NIND2/NIND3
 * arrays (if any) is written as zero.
 */
static void 
compute_g_rec(const MD_Dvec *reclatt, const MD_Double *recU,
	      const MD_Int nreclatt, const MD_Dvec r, 
	      MD_Double *g0, MD_Dvec *g1, MD_Double g2[NIND2], MD_Double g3[NIND3])
{
  MD_Double sum0 = 0.0;
  MD_Double sum1x = 0.0, sum1y = 0.0, sum1z = 0.0;
  MD_Double sum2[NIND2] = {0.0};
  MD_Double sum3[NIND3] = {0.0};
  MD_Int n, m;

  for (n = 0; n < nreclatt; n++) {
    const MD_Double kx = reclatt[n].x;
    const MD_Double ky = reclatt[n].y;
    const MD_Double kz = reclatt[n].z;
    const MD_Double phase = kx*r.x + ky*r.y + kz*r.z;
    const MD_Double even = recU[n] * cos(phase);  /* cos part: g0, g2 */
    const MD_Double odd  = recU[n] * sin(phase);  /* sin part: g1, g3 */

    sum0 += even;

    sum1x -= odd * kx;
    sum1y -= odd * ky;
    sum1z -= odd * kz;

    sum2[XX] += even * kx * kx;
    sum2[XY] += even * kx * ky;
    sum2[XZ] += even * kx * kz;
    sum2[YY] += even * ky * ky;
    sum2[YZ] += even * ky * kz;
    sum2[ZZ] += even * kz * kz;

    sum3[XXX] -= odd * kx * kx * kx;
    sum3[XXY] -= odd * kx * kx * ky;
    sum3[XXZ] -= odd * kx * kx * kz;
    sum3[XYY] -= odd * kx * ky * ky;
    sum3[XYZ] -= odd * kx * ky * kz;
    sum3[XZZ] -= odd * kx * kz * kz;
    sum3[YYY] -= odd * ky * ky * ky;
    sum3[YYZ] -= odd * ky * ky * kz;
    sum3[YZZ] -= odd * ky * kz * kz;
    sum3[ZZZ] -= odd * kz * kz * kz;
  }

  *g0 = sum0;
  g1->x = sum1x;
  g1->y = sum1y;
  g1->z = sum1z;
  for (m = 0; m < NIND2; m++) {
    g2[m] = sum2[m];
  }
  for (m = 0; m < NIND3; m++) {
    g3[m] = sum3[m];
  }
}
