/***************************************************************************
 *cr
 *cr            (C) Copyright 2006 The Board of Trustees of the
 *cr                        University of Illinois
 *cr                         All Rights Reserved
 *cr
 ***************************************************************************/

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <unistd.h>
#include <string.h>
#include <sys/time.h>

#include "util.h"
#include "energythr.h"

/*
 * Forward declarations for the routines defined later in this file.
 * Each returns 0 on success; see the comment above each definition for
 * the meaning of the individual parameters.
 */

/* Command-line parsing and PDB input */
int get_opts(float*, long int*, float*, float*, float*, char*, float*, char*, char*, int*, int, char**);
int get_pdb_info(FILE*, float*, float*, float*, float*, float*, float*, const float, const float, long int*, long int*, long int*, long int*);
int read_pdb_file(FILE*, float*, long int*, const float, const float, const float, const float, const int);

/* Grid energy evaluation, exclusion, ion placement and output */
int calc_grid_energies_excl(const float*, float*, const long int, const long int, const long int, const long int, const float, const unsigned char*);
int exclude_solute(const float*, unsigned char*,const float, const long int, const long int, const long int, const long int, const float);
int place_n_ions(const int, const float, float*, int*, unsigned char*, const long int, const long int, const long int, const float, const float);
int output_ions(char*, const int*, const char*, const int, const float, const float, const float, const float, const float);
void print_usage();


/*
 * Program entry point.
 *
 * Parses the command line, reads the input pdb twice (once to size the
 * grid, once to load the atoms), marks grid points too close to the solute
 * as excluded, evaluates the coulombic energy on the remaining points,
 * places the requested ions on successive energy minima and writes them to
 * the output pdb.
 *
 * Returns 0 on success and 1 on any error.  All resources acquired here
 * (the input stream and the heap arrays) are released on every exit path.
 */
int main(int argc, char* argv[]) {
  /*** declare all the variables we need to pass around ***/
  /*
   * Note that all coordinates are laid out on a grid, with integral
   * points at spacings determined at runtime. The grid dimensions are
   * numplane in the z direction, numcol in the y direction, and numpt in
   * the x direction
   */
  long int natom = 0;       /* number of atoms */
  float *atoms = NULL;      /* atom coordinates and charges, stored x/y/z/q/x/y/z/q/... */
  long int *atomis = NULL;  /* atom grid cell indices, stored x/y/z/x/y/z... */
  float *grideners = NULL;  /* energies at grid cells */
  unsigned char *excludepos = NULL; /* excluded grid cells; nonzero == excluded */
  int *ionpos = NULL;       /* output ion positions, laid out x/y/z/x/y/z... */

  /*** Grid parameters ***/
  long int numplane = 0, numcol = 0, numpt = 0;
  size_t ngrid;             /* total number of grid points */
  float gridspacing;
  float bordersize;
  float minx, miny, minz;
  float maxx, maxy, maxz;

  /***Ion information***/
  long int nion;
  char ionname[4];
  float ioncharge;
  float r_ion_prot;
  float r_ion_ion;

  /*** File information ***/
  FILE* pdb_in = NULL;
  char pdbin[50];
  char pdbout[50]; /* Output file for ions */

  /*** Threading parameters ***/
  int maxnumprocs; /* Maximum available number of processors */

  int status = 1;  /* process exit status; set to 0 only after full success */

  printf("Ionize -- Multiprocessor Ion Placement Tool       Version 0.1\n");
  printf("(C) 2006 Theoretical and Computational Biophysics Group, UIUC\n");
  printf("-------------------------------------------------------------\n");


  /***Begin doing work***/
  /* Parse the command line and read the input pdb */
  if (get_opts(&gridspacing, &nion, &r_ion_prot, &r_ion_ion, &bordersize, ionname, &ioncharge, pdbin, pdbout, &maxnumprocs, argc, argv) != 0) return 1;

  /* Sanity check the input */
  if (nion <= 0) {
    fprintf(stderr, "ERROR: You need to place one or more ions!\n");
    return 1;
  }
  /* Every grid computation divides by the spacing, so it must be positive */
  if (!(gridspacing > 0.0f)) {
    fprintf(stderr, "ERROR: The grid spacing must be positive!\n");
    return 1;
  }

  printf("Running ionize on input file %s\n"
         "\tIons to be placed: %ld\n"
         "\tIon name: %s\n\tIon charge: %f\n"
         "\tIon-solute distance: %f\n"
         "\tIon-Ion distance: %f\n"
         "\tGrid spacing: %f\n"
         "\tBoundary size: %f\n"
         "\tMax. Processors: %i\n"
         "Placed ions will be printed to the file %s\n", pdbin, nion, ionname, ioncharge, r_ion_prot, r_ion_ion, gridspacing, bordersize, maxnumprocs, pdbout);

  /** Open and read the input pdb file ***/
  pdb_in = fopen(pdbin, "r");
  if (pdb_in == NULL) {
    fprintf(stderr, "Error: Couldn't open input pdb file %s. Exiting...\n", pdbin);
    return 1;
  }

  /* First, read through the pdb file once to get enough information to allocate all the arrays */
  if (get_pdb_info(pdb_in, &minx, &miny, &minz, &maxx, &maxy, &maxz, gridspacing, bordersize, &numplane, &numcol, &numpt, &natom) != 0) goto cleanup;

  /* Without atoms (or with a degenerate grid) there is nothing to compute */
  if (natom <= 0 || numpt <= 0 || numcol <= 0 || numplane <= 0) {
    fprintf(stderr, "ERROR: Input pdb file %s yielded no atoms or an empty grid\n", pdbin);
    goto cleanup;
  }

  printf("\nExtents of the considered system (including boundary) are:\n"
         "\tX: %8.3f to %8.3f\tY: %8.3f to %8.3f\tZ: %8.3f to %-8.3f\n"
         "Grid spacing is %f angstroms, dimensions %6ld x %6ld x %6ld\n",
          minx, maxx, miny, maxy, minz, maxz,
          gridspacing, numpt, numcol, numplane);

  /* Refuse grids whose byte size would not fit in a size_t */
  if ((size_t) numcol > ((size_t) -1) / sizeof(float) / (size_t) numpt / (size_t) numplane) {
    fprintf(stderr, "Error: The requested grid is too large to allocate\n");
    goto cleanup;
  }
  ngrid = (size_t) numpt * (size_t) numcol * (size_t) numplane;

  /* Now allocate the arrays we need */
  printf("\nAllocating memory for data arrays...\n");
  printf("\tAllocating %zu KB for atom arrays\n", 4 * (size_t) natom * sizeof(float) / 1024);
  atoms = malloc(4 * (size_t) natom * sizeof(float));
  printf("\tAllocating %zu KB for atom grid point array\n", 3 * (size_t) natom * sizeof(long int) / 1024);
  atomis = malloc(3 * (size_t) natom * sizeof(long int));
  printf("\tAllocating %zu MB for grid energy array\n", ngrid * sizeof(float) / (1024 * 1024));
  grideners = malloc(ngrid * sizeof(float));
  /* The ion array holds ints, so report its size in ints as well */
  printf("\tAllocating %zu KB for ion array\n", 3 * (size_t) nion * sizeof(int) / 1024);
  ionpos = malloc(3 * (size_t) nion * sizeof(int));
  printf("\tAllocating %zu MB for exclusion array\n", ngrid * sizeof(unsigned char) / (1024 * 1024));
  excludepos = calloc(ngrid, sizeof(unsigned char));

  if (atoms == NULL || atomis == NULL || grideners == NULL || ionpos == NULL || excludepos == NULL) {
    fprintf(stderr, "Error: Failed to allocate memory for data arrays\n");
    goto cleanup;
  }

  printf("Successfully allocated data arrays\n");

  /* Read the pdb file again, and this time fill the atom position and charge arrays */
  printf("\nReading atoms from pdb file\n");
  if (read_pdb_file(pdb_in, atoms, atomis, minx, miny, minz, gridspacing, natom) != 0) goto cleanup;
  fclose(pdb_in);
  pdb_in = NULL;
  printf("Done with input pdb\n");

  /* Exclude grid points too close to the protein */
  printf("\nExcluding grid points too close to protein\n");
  if (exclude_solute(atoms, excludepos, r_ion_prot, numplane, numcol, numpt, natom, gridspacing) != 0) goto cleanup;
  printf("Finished with exclusion\n");

  /* Now that all our input is done, calculate the initial grid energies */
  printf("\nCalculating grid energies...\n");
#if 1
  /* Multithreaded and vectorized implementation by John Stone */
  calc_grid_energies_excl_thr(atoms, grideners, numplane, numcol, numpt, natom, gridspacing, excludepos, maxnumprocs);
#else
  /* Sequential implementation */
  calc_grid_energies_excl(atoms, grideners, numplane, numcol, numpt, natom, gridspacing, excludepos);
#endif
  printf("Done calculating grid energies\n");


  /* Place the ions by finding the minimum energy points */
  printf("\nPlacing ions...\n");
  if (place_n_ions(nion, ioncharge, grideners, ionpos, excludepos, numplane, numcol, numpt, gridspacing, r_ion_ion) != 0) goto cleanup;
  printf("Finished placing ions\n");

  /* Output ion positions to pdb file */
  printf("\nPrinting final ion coordinates to PDB file %s\n",pdbout);
  if (output_ions(pdbout, ionpos, ionname, nion, ioncharge, gridspacing, minx, miny, minz) != 0) goto cleanup;
  printf("Finished writing output\n");

  status = 0;

cleanup:
  /* Release everything we acquired; free(NULL) is a no-op */
  if (pdb_in != NULL) fclose(pdb_in);
  free(atoms);
  free(atomis);
  free(grideners);
  free(ionpos);
  free(excludepos);

  if (status == 0) printf("\nionize: normal exit\n");
  return status;
}


/* Convert the whole of s to a long int.
 * Returns 0 on success, 1 if s is empty or contains trailing characters.
 */
static int ionize_parse_long(const char* s, long int* out) {
  char* end;
  long int v = strtol(s, &end, 10);
  if (end == s || *end != '\0') return 1;
  *out = v;
  return 0;
}

/* Convert the whole of s to a finite float.
 * Returns 0 on success, 1 if s is empty, has trailing characters, or does
 * not denote a finite number.
 */
static int ionize_parse_float(const char* s, float* out) {
  char* end;
  double v = strtod(s, &end);
  if (end == s || *end != '\0' || !isfinite(v)) return 1;
  *out = (float) v;
  return 0;
}

/*
 * Parse the command line: "ionize (options) input.pdb output.pdb".
 *
 * Every output parameter is first set to its default (see print_usage) and
 * then overridden by any "-x value" pair found before the two trailing
 * file names.  Only the character following the '-' selects the option.
 *
 *   ionname        receives at most 3 characters plus a terminator
 *   pdbin, pdbout  must point to buffers of at least 50 chars (as declared
 *                  in main); longer file names are rejected
 *
 * Returns 0 on success.  On any error a message and/or the usage text is
 * printed and 1 is returned.
 */
int get_opts(float* gridspacing, long int* nion, float* r_ion_prot, float* r_ion_ion, float* bordersize, char* ionname, float* ioncharge, char* pdbin, char* pdbout, int* maxnumprocs, int argc, char** argv) {
  enum { FILENAME_BUFLEN = 50 }; /* size of the pdbin/pdbout buffers */
  const char* iontype = "SOD";
  int i;

  /* Defaults */
  *nion=1;
  *r_ion_prot=6.0;
  *r_ion_ion=10.0;
  *bordersize=10.0;
  *gridspacing = 0.5;
  *ioncharge = 1.0;
  *maxnumprocs = 1;

  /* The program name plus the two file names are the bare minimum */
  if (argc < 3) {
    print_usage();
    return 1;
  }

  /* Options live in argv[1 .. argc-3]; the last two entries are the files */
  for (i=1; i<(argc-2); i++) {
    const char* opt = argv[i];
    const char* val;
    long int lval = 0;
    int bad = 0;
    char c;

    /* opt[1] is only inspected once opt[0] is known to be '-', so an empty
     * argument is never read past its terminator */
    if (opt[0] != '-' || opt[1] == '\0' || strchr("nrigbtqp", opt[1]) == NULL) {
      fprintf(stderr, "Unknown option `%s'.\n", opt);
      print_usage();
      return 1;
    }
    c = opt[1];

    /* The value must not be one of the two trailing file names */
    if (i + 1 >= argc - 2) {
      fprintf(stderr, "Error: No argument for option `-%c'.\n", c);
      print_usage();
      return 1;
    }
    val = argv[++i];

    switch (c) {
      case 'n': bad = ionize_parse_long(val, nion);         break;
      case 'r': bad = ionize_parse_float(val, r_ion_prot);  break;
      case 'i': bad = ionize_parse_float(val, r_ion_ion);   break;
      case 'g': bad = ionize_parse_float(val, gridspacing); break;
      case 'b': bad = ionize_parse_float(val, bordersize);  break;
      case 't': iontype = val;                              break;
      case 'q': bad = ionize_parse_float(val, ioncharge);   break;
      case 'p':
        bad = ionize_parse_long(val, &lval);
        if (!bad) {
          *maxnumprocs = (int) lval;
          /* reject values that do not survive the narrowing to int */
          if ((long int) *maxnumprocs != lval) bad = 1;
        }
        break;
      default:
        break; /* unreachable: the option letter was validated above */
    }

    if (bad) {
      fprintf(stderr, "Error: Invalid argument `%s' for option `-%c'.\n", val, c);
      print_usage();
      return 1;
    }
  }

  strncpy(ionname, iontype, 3);
  ionname[3]='\0';

  /* Refuse names that would not fit rather than silently truncating them */
  if (strlen(argv[argc-2]) >= FILENAME_BUFLEN || strlen(argv[argc-1]) >= FILENAME_BUFLEN) {
    fprintf(stderr, "Error: pdb file names must be shorter than %d characters.\n", (int) FILENAME_BUFLEN);
    return 1;
  }
  strcpy(pdbin, argv[argc-2]);
  strcpy(pdbout, argv[argc-1]);

  return 0;
}


/* Pass through the pdb input once, find the min and max values of all
 * coordinates, and give us the number of atoms and grid points
 */
int get_pdb_info(FILE* pdb_in, float* minx, float* miny, float* minz, float* maxx, float* maxy, float* maxz, const float gridspacing, const float boundary, long int* numplane, long int* numcol, long int* numpt, long int* natom) {

  char line[81];
  char buf[10];
  buf[8]='\0';
  *maxx = *maxy = *maxz = -HUGE_VAL;
  *minx = *miny = *minz = HUGE_VAL;
  *natom = 0;
  float x,y,z; /* Values for current line */
  

  while (fgets(line, 80, pdb_in)) {
    if (strncmp(line, "ATOM", 4)!=0 && strncmp(line, "HETATM", 6)!=0) continue;

    /* get the x,y, and z values for the current line */
    strncpy(buf, &line[30], 8);
    x = atof(buf);
    strncpy(buf, &line[38], 8);
    y = atof(buf);
    strncpy(buf, &line[46], 8);
    z = atof(buf);

    if (x>*maxx) *maxx = x;
    if (y>*maxy) *maxy = y;
    if (z>*maxz) *maxz = z;
    if (x<*minx) *minx = x;
    if (y<*miny) *miny = y;
    if (z<*minz) *minz = z;

    *natom += 1;

  }

  /* Find the number of grid points needed */
  *maxx += boundary;
  *maxy += boundary;
  *maxz += boundary;
  *minx -= boundary;
  *miny -= boundary;
  *minz -= boundary;
  x = (*maxx - *minx);
  y = (*maxy - *miny);
  z = (*maxz - *minz);

  *numpt = (long int) rint(x/gridspacing) + 1;
  *numcol = (long int) rint(y/gridspacing) + 1;
  *numplane = (long int) rint(z/gridspacing) + 1;

  return 0;
}


/* Read the pdb file, and fill the atom position and charge arrays with the
 * appropriate data
 */
int read_pdb_file(FILE* pdb_in, float* atoms, long int* atomis, const float minx, const float miny, const float minz, const float gridspacing, const int natom) {
  char line[81];
  char buf[10];
  buf[8]='\0';
  float x,y,z,q;
  long int xi, yi, zi;
  int n=0; /* counter for atoms processed */

  rewind(pdb_in);

  while (fgets(line, 80, pdb_in)) {
    if (strncmp(line, "ATOM", 4)!=0 && strncmp(line, "HETATM", 6)!=0) continue;

    /* get the x,y, and z values for the current line */
    strncpy(buf, &line[30], 8);
    x = atof(buf);
    strncpy(buf, &line[38], 8);
    y = atof(buf);
    strncpy(buf, &line[46], 8);
    z = atof(buf);
    strncpy(buf, &line[61], 5);
    buf[5]='\0';
    q = atof(buf);

    x -= minx;
    y -= miny;
    z -= minz;
    /* Convert to grid points */
    xi = (long int) rint(x/gridspacing);
    yi = (long int) rint(y/gridspacing);
    zi = (long int) rint(z/gridspacing);

    /* Update arrays */
    atoms[4*n    ] = x;
    atoms[4*n + 1] = y;
    atoms[4*n + 2] = z;
    atoms[4*n + 3] = q;
    atomis[3*n    ] = xi;
    atomis[3*n + 1] = yi;
    atomis[3*n + 2] = zi;

    n += 1;

#if 0 && defined(DEBUG)
    printf("DEBUG: Assigning atom with line\n%s\tto grid point %i %i %i and charge %f\n",line, xi, yi, zi, q);
#endif
  }

  if (n != natom) {
    fprintf(stderr, "ERROR: Failed to properly read pdb file\n");
    return 1;
  }

  return 0;
}



/* Calculate the coulombic energy at each grid point from each atom
 * This is by far the most time consuming part of the process
 * We iterate over z,y,x, and then atoms
 * This function is the same as the original calc_grid_energies, except
 * that it utilizes the exclusion grid
 */
int calc_grid_energies_excl(const float* atoms, float* grideners, const long int numplane, const long int numcol, const long int numpt, const long int natoms, const float gridspacing, const unsigned char* excludepos) {
  float energy; /* Energy of current grid point */
  float x,y,z; /* Coordinates of current grid point */
  int i,j,k,n; /* Loop counters */

  /* For each point in the cube... */
  for (k=0; k<numplane; k++) {
    printf("\tWorking on plane %i of %ld\n", k, numplane);
    z = gridspacing * (float) k;
    for (j=0; j<numcol; j++) {
      y = gridspacing * (float) j;
      for (i=0; i<numpt; i++) {

        /* Check if we're on an excluded point, and skip it if we are */
        if (excludepos[k*numcol*numpt + j*numpt + i] != 0) continue;

        
        x = gridspacing * (float) i;
        energy = 0;
        /* Calculate the interaction with each atom */
        for (n=0; n<natoms; n++) {
          int arrpos = n<<2;
          float dx = x - atoms[arrpos];
          float dy = y - atoms[arrpos+1];
          float dz = z - atoms[arrpos+2];
          float r2 = dx*dx + dy*dy + dz*dz;
          float r_1 = 1.0 / sqrtf(r2);
          energy += atoms[arrpos+3] * r_1;
        }

#if defined(DEBUG)
        printf("DEBUG: Energy at location %i %i %i is %f\n", i, j, k, energy);
#endif
        grideners[numcol*numpt*k + numpt*j + i] = energy;
      }
    }
  }

  return 0;
}


/* Loop through the points in the cube and exclude any within r_ion_prot of
 * the protein
 */
int exclude_solute(const float* atoms, unsigned char* excludepos, const float r_ion_prot, const long int numplane, const long int numcol, const long int numpt, const long int natoms, const float gridspacing) {

  float x,y,z; /* coordinates of current point */
  float xa, ya, za; /* coordinates of current atom */
  int xi, yi, zi; /* Integer coordinates of current atom */
  int i,j,k,n; /* Loop counters */
  int nexcl=0;
  int rip = r_ion_prot / gridspacing; /* Number of grid points for exclude cube */
  float rip2 = r_ion_prot * r_ion_prot;

  /* Note that  this time we loop over atoms first; this is because */
  /* we only need to examine the cube of side length r_prot_ion centered on */
  /* each atom */
  for (n=0; n<natoms; n++) {
    xa = atoms[4*n];
    ya = atoms[4*n+1];
    za = atoms[4*n+2];
    xi = (int) rint(xa / gridspacing);
    yi = (int) rint(ya / gridspacing);
    zi = (int) rint(za / gridspacing);
    for (k=zi - rip; k<=zi + rip; k++) {
      if (k<0 || k>=numplane) continue;
    z = gridspacing * (float) k;
      for (j=yi - rip; j<=yi + rip; j++) {
        if (j<0 || j>=numcol) continue;
      y = gridspacing * (float) j;
        for (i=xi - rip; i<=xi + rip; i++) {
          if (i<0 || i>=numpt) continue;
        x = gridspacing * (float) i;
          /* See if this point is too close to the current atom */
          float dx = xa - x;
          float dy = ya - y;
          float dz = za - z;
          float dist = dx*dx + dy*dy + dz*dz;
          if (dist <= rip2) {
            if (excludepos[numcol * numpt * k + numpt * j + i] == 0) nexcl += 1;
            excludepos[numcol * numpt * k + numpt * j + i] = 1;
          }
        }
      }
    }
  }
  printf("\tExcluded %i points out of %ld\n", nexcl, numplane*numcol*numpt);
  
  return 0;
}


/* Place the number of ions requested. Each ion placement cycle requires
 *  -Determination of the minimum energy point
 *  -Placement of the ion
 *  -updating the energy grid to take the new ion into account
 *  -exclusion of points around the placed ion
 */
int place_n_ions(const int nion, const float ioncharge, float* grideners, int* ionpos, unsigned char* excludepos, const long int numplane, const long int numcol, const long int numpt, const float gridspacing, const float r_ion_ion) {
  int n;      /* Current ion number */
  int i,j,k;  /* Counters for going through grid */
  int offset; /* Location of current point in the grid arrays */
  int mini=-1, minj=-1, mink=-1; /* grid locations of minimum energy point */
  int rii = (int) rint(r_ion_ion/gridspacing); /* Minimum ion-ion distance, in grid units */
  int rii2 = rii*rii; /* Square of minimum distance */
  float minener;

  for (n=0; n<nion; n++) {
    minener = HUGE_VAL;
    /* Place the ion on the energetic minimum */
    /* For each point in the cube... */
    for (k=0; k<numplane; k++) {
      for (j=0; j<numcol; j++) {
        for (i=0; i<numpt; i++) {
          /* Check if this point is excluded */
          offset = numcol*numpt*k + numpt*j + i;
#if defined(DEBUG)
          printf("DEBUG: Checking exclusion at point %i %i %i: %i\n", i,j,k,excludepos[offset]);
#endif
          if (excludepos[offset] != 0) continue;
          float myener = (ioncharge * grideners[offset]);
          if (myener < minener) {
#if defined(DEBUG)
            printf("DEBUG: Better energy %f found at %i %i %i\n", myener, i, j, k);
#endif
            minener = myener;
            mini = i;
            minj = j;
            mink = k;
          }
        }
      }
    }

    /* Found the current minimum energy point; place an ion there */
    ionpos[3*n] = mini;
    ionpos[3*n + 1] = minj;
    ionpos[3*n + 2] = mink;
    printf("\tPlaced ion %i at grid location %i %i %i with energy %f\n", n, mini, minj, mink, minener);

    /* Update the energy grid with the effects of the new ion */
    float minxf = (float) (mini * gridspacing);
    float minyf = (float) (minj * gridspacing);
    float minzf = (float) (mink * gridspacing);
    float x,y,z;

    for (k=0; k<numplane; k++) {
      z = gridspacing * (float) k;
      for (j=0; j<numcol; j++) {
        y = gridspacing * (float) j;
        for (i=0; i<numpt; i++) {
          x = gridspacing * (float) i;
          float dx = x - minxf;
          float dy = y - minyf;
          float dz = z - minzf;
          float r2 = dx*dx + dy*dy + dz*dz;
          float r_1 = 1.0 / sqrtf(r2);
          float energy =  ioncharge * r_1;
          grideners[numcol * numpt * k + numpt * j + i] += energy;
        }
      }
    }

    /* Exclude anything too close to it */
    /* Only work in the space within rii of the newly placed ion */
    for (k=mink - rii; k<=mink + rii; k++) {
      if (k<0 || k>=numplane) continue;
      for (j=minj - rii; j<=minj + rii; j++) {
        if (j<0 || j>=numcol) continue;
        for (i=mini - rii; i<=mini + rii; i++) {
          if (i<0 || i>=numpt) continue;
          /* Exclude if too close to current ion */
          int dx = i-mini;
          int dy = j-minj;
          int dz = k-mink;
          int dist = dx*dx + dy*dy + dz*dz;
          if (dist < rii2) {
            excludepos[numcol*numpt*k + numpt*j + i] = 1;
#if defined(DEBUG)
            printf("DEBUG: Excluding node at %i %i %i because distance %i is less than the minimum %i\n", i, j, k, dist, rii2);
#endif
          }
        }
      }
    }
  }

  printf("\tSuccessfully placed %i ions\n", nion);

  return 0;
}

/*
 * Output all of the newly placed ions to the output file.
 *
 * pdbout is created (or truncated) and one ATOM record is written per ion.
 * ionpos holds the grid indices x/y/z of each ion; they are converted back
 * to coordinates as index * gridspacing + (minx, miny, minz).  ionname is
 * written as atom name, residue name and trailing segment field, the
 * 1-based ion number as serial and residue number, and ioncharge in the
 * field before the constant 0.00.
 *
 * Returns 0 on success, 1 if the file cannot be opened or if writing or
 * closing it fails.
 */
int output_ions(char* pdbout, const int* ionpos, const char* ionname, const int nion, const float ioncharge, const float gridspacing, const float minx, const float miny, const float minz) {
  FILE* pdb_out;
  int n; /* Current ion number */
  int failed = 0; /* set when any write or the final close fails */
  float x,y,z; /* coordinates of current ion */

  pdb_out = fopen(pdbout, "w");
  if (pdb_out == NULL) {
    fprintf(stderr, "Error: Couldn't open output pdb file %s. Exiting...\n", pdbout);
    return 1;
  }

  for (n=0; n<nion; n++) {
    x = ((float) (ionpos[3*n] * gridspacing)) + minx;
    y = ((float) (ionpos[3*n+1] * gridspacing)) + miny;
    z = ((float) (ionpos[3*n+2] * gridspacing)) + minz;
    if (fprintf(pdb_out, "ATOM%7i %3s  %3s %5i    %8.3f%8.3f%8.3f %5.2f  0.00      %3s\n", n+1, ionname, ionname, n+1, x,y,z,ioncharge, ionname) < 0) {
      failed = 1;
      break;
    }
  }

  /* Buffered data is flushed by fclose, so its result matters as well */
  if (fclose(pdb_out) != 0) failed = 1;

  if (failed) {
    fprintf(stderr, "Error: Failed to write output pdb file %s\n", pdbout);
    return 1;
  }

  return 0;
}


/* Write the usage summary, including every option and its default value,
 * to standard output.
 */
void print_usage() {
  static const char* const usage_lines[] = {
    "\nionize: Place ions by finding minima in a coulombic potential\n",
    "Usage: ionize (options) input.pdb output.pdb\n",
    "\tinput.pdb must be a pdb file with atom charges in the beta field\n",
    "Optional arguments (with defaults):\n",
    "\t-n number of ions to place (1)\n",
    "\t-r Minimum distance from ions to solute (6.0)\n",
    "\t-i Minimum distance between placed ions (10.0)\n",
    "\t-g Grid spacing in angstroms (0.5)\n",
    "\t-b Additional perimeter to search for ion locations (10.0)\n",
    "\t-t ion name (SOD)\n",
    "\t-q ion charge (1.0)\n",
    "\t-p max_processors (1)\n"
  };
  size_t idx;

  /* The text contains no conversion specifiers, so it is emitted verbatim */
  for (idx = 0; idx < sizeof usage_lines / sizeof usage_lines[0]; idx++) {
    fputs(usage_lines[idx], stdout);
  }
}



