/* volmap.C - This file contains the initialization and manipulation 
   routines for the VolMap class.
*/

#include <math.h>
#include <vector>

#include "volmap.h"
#include "vec.h"

using namespace std;



/* OPS */

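// Operation objects let the same code handle both regular maps and PMF maps:
// values are converted with ConvertValue() before they are averaged and the
// average is converted back with ConvertAverage() (identity for regular
// maps; exp(-x) and -log(x) respectively for PMFs).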

/// Abstract strategy describing how voxel values are transformed before they
/// are averaged (ConvertValue) and how an average is transformed back into a
/// voxel value (ConvertAverage).
class Operation {
public:
  // Instances are created and used through Operation pointers, so the
  // destructor must be virtual for delete-through-base to be well defined.
  virtual ~Operation() {}
  /// Short human-readable label used in log messages.
  virtual const char *name() = 0;
  /// Map a stored voxel value into the space in which averaging is done.
  virtual double ConvertValue(double) = 0;
  /// Map an average computed in that space back to a stored voxel value.
  virtual double ConvertAverage(double) = 0;
};

/// Identity conversions: values are averaged as they are stored.
class RegOpType : public Operation {
public:
  const char *name() {return "regular";}
  double ConvertValue(double val) {return val;}
  double ConvertAverage(double avg) {return avg;}
};

/// PMF conversions: values are averaged as exp(-val) and the average is
/// converted back with -log(avg).
class PMFOpType : public Operation {
public:
  const char *name() {return "PMF";}
  double ConvertValue(double val) {return exp(-val);}
  double ConvertAverage(double avg) {
    double val = -log(avg);
  //  if (val != val || val > 150.) val = 150.;
    return val;
  }
};


// Lazily created instances of the two Operation types, shared by all
// callers of GetOps (not freed anywhere in this file).
Operation *RegOps=NULL;
Operation *PMFOps=NULL;

/// Return the Operation for optype: the PMF conversions when optype is PMF,
/// the identity ("regular") conversions for any other value. The object is
/// created on first request and reused afterwards.
Operation *GetOps(int optype) {
  const bool want_pmf = (optype == PMF);
  Operation *&slot = want_pmf ? PMFOps : RegOps;
  if (slot == NULL) {
    if (want_pmf) slot = new PMFOpType;
    else          slot = new RegOpType;
  }
  return slot;
}


/* INITIALIZATION */

// Constructor: an empty map with no voxels, no data array and weight 1.
VolMap::VolMap() {
  xsize = 0;
  ysize = 0;
  zsize = 0;
  data = NULL;
  weight = 1.;
  // Give the remaining members defined values: refname is printed with "%s"
  // by most operations, and the geometry vectors were left uninitialized.
  refname = "(no name)";
  vset(origin, 0., 0., 0.);
  vset(xaxis, 0., 0., 0.);
  vset(yaxis, 0., 0., 0.);
  vset(zaxis, 0., 0., 0.);
  vset(xdelta, 0., 0., 0.);
  vset(ydelta, 0., 0., 0.);
  vset(zdelta, 0., 0., 0.);
}



// Clone Constructor: builds a deep copy of *src (see clone()).
VolMap::VolMap(const VolMap *src) {
  // clone() releases `data` before reallocating it, so it must start NULL.
  // (A bare `VolMap();` statement here would only construct and discard a
  // temporary object; it would not initialize *this.)
  xsize = 0;
  ysize = 0;
  zsize = 0;
  data = NULL;
  weight = 1.;
  clone(src);
}



/// Destructor: releases the voxel array (delete[] of NULL is a no-op).
/// NOTE(review): no copy constructor/assignment is visible here; copying a
/// VolMap by value would share `data` and double-free it.
VolMap::~VolMap() {
  delete[] data;
}



// Clone: make *this a deep copy of *src (geometry, grid size, weight and
// voxel data). The name is not copied; it is reset to "(no name)".
void VolMap::clone(const VolMap *src) {
  vcopy (origin, src->origin);
  vcopy (xaxis, src->xaxis);
  vcopy (yaxis, src->yaxis);
  vcopy (zaxis, src->zaxis);
  vcopy (xdelta, src->xdelta);
  vcopy (ydelta, src->ydelta);
  vcopy (zdelta, src->zdelta);
  xsize = src->xsize;
  ysize = src->ysize;
  zsize = src->zsize;

  weight = src->weight;

  // Copy into a fresh buffer before releasing the old one, so cloning a map
  // onto itself (src == this) never reads freed memory. Skip memcpy for an
  // empty grid, where src->data may be NULL.
  int gridsize = xsize*ysize*zsize;
  double *data_new = new double[gridsize];
  if (gridsize > 0) memcpy(data_new, src->data, gridsize*sizeof(double));
  delete[] data;
  data = data_new;

  refname = "(no name)";
}



/// Zero the data array: set all xsize*ysize*zsize voxels to 0.
void VolMap::zero() {
  double *const end = data + xsize*ysize*zsize;
  for (double *p = data; p != end; ++p) *p = 0.;
}




/* UNARY OPERATIONS */


// Convert every voxel from a PMF value to a density: data = exp(-data).
void VolMap::convert_pmf_to_density() {
  printf("%s: converting from PMF to density, i.e. exp(-X)\n", refname);
  const int num_voxels = xsize*ysize*zsize;
  for (int i = 0; i < num_voxels; ++i) {
    const double pmf = data[i];
    data[i] = exp(-pmf);
  }
}



// Convert from density to PMF (-log(dens)).
// Densities below 1e-6 (including zero, negative and NaN values) become the
// PMF ceiling of 150. Computed PMFs above 150 (or NaN) are capped at 150.
void VolMap::convert_density_to_pmf() {  
  printf("%s: converting from density to PMF, i.e. -log(X)\n", refname);
  const double min_density = 0.000001;
  const double max_pmf = 150.;
  int gridsize = xsize*ysize*zsize;
  for (int n=0; n<gridsize; n++) {
    // The previous test was inverted (`< min_density`) and its dangling
    // `else` bound to the inner `if`, so every voxel ended up as 150.
    if (data[n] >= min_density) {
      double pmf = -log(data[n]);
      data[n] = (pmf > max_pmf || pmf != pmf) ? max_pmf : pmf;
    } else {
      data[n] = max_pmf;
    }
  }
}



// Trim/reduce each side of the volmap's grid
void VolMap::trim(int trimxm, int trimxp, int trimym, int trimyp, int trimzm, int trimzp) {  
  printf("%s: trimming map by x:%d %d y:%d %d z:%d %d\n", refname, trimxm, trimxp, trimym, trimyp, trimzm, trimzp);
  
  int gx, gy, gz;
  int xsize_new = max(0, xsize - trimxp - trimxm);
  int ysize_new = max(0, ysize - trimyp - trimym); 
  int zsize_new = max(0, zsize - trimzp - trimzm); 
  
  double *data_new = new double[xsize_new*ysize_new*zsize_new];
  
  for (gx=0; gx<xsize_new; gx++)
  for (gy=0; gy<ysize_new; gy++)
  for (gz=0; gz<zsize_new; gz++) {
    int n_new = gx + gy*xsize_new + gz*xsize_new*ysize_new;
    data_new[n_new] = data[(gx+trimxm) + (gy+trimym)*xsize + (gz+trimzm)*xsize*ysize];
  }
  
  delete data;
  data = data_new;
  
  xsize = xsize_new;
  ysize = ysize_new;
  zsize = zsize_new;
  double scaling_factor = (double) xsize_new/xsize;
  vscale(xaxis, scaling_factor);
  scaling_factor = (double) ysize_new/ysize;
  vscale(yaxis, scaling_factor);
  scaling_factor = (double) zsize_new/zsize;
  vscale(zaxis, scaling_factor);
  
  vaddscaledto(origin, trimxm, xdelta);
  vaddscaledto(origin, trimym, ydelta);
  vaddscaledto(origin, trimzm, zdelta);
}



// Downsample grid size by factor of 2
void VolMap::downsample(int optype) {
  Operation *ops = GetOps(optype);
  printf("%s: downsampling by x2 (%s)\n", refname, ops->name());
  
  int gx, gy, gz, j;
  int xsize_new = xsize/2;
  int ysize_new = ysize/2;
  int zsize_new = zsize/2;
  double *data_new = new double[xsize_new*ysize_new*zsize_new];
  
  int index_shift[8] = {0, 1, xsize, xsize+1, xsize*ysize, xsize*ysize + 1, xsize*ysize + xsize, xsize*ysize + xsize + 1};
  
  for (gx=0; gx<xsize_new; gx++)
  for (gy=0; gy<ysize_new; gy++)
  for (gz=0; gz<zsize_new; gz++) {
    int n_new = gx + gy*xsize_new + gz*xsize_new*ysize_new;
    int n = 2*(gx + gy*xsize + gz*xsize*ysize);
    double Z=0.;
    for (j=0; j<8; j++) Z += ops->ConvertValue(data[n+index_shift[j]]);
    data_new[n_new] = ops->ConvertAverage(Z/8.);
  }
  
  xsize = xsize_new;
  ysize = ysize_new;
  zsize = zsize_new;
  vscale(xdelta, 2.);
  vscale(ydelta, 2.);
  vscale(zdelta, 2.);
  double scaling_factor = 0.5*(xsize)/(xsize/2);
  vscale(xaxis, scaling_factor);
  scaling_factor = 0.5*(ysize)/(ysize/2);
  vscale(yaxis, scaling_factor);
  scaling_factor = 0.5*(zsize)/(zsize/2);
  vscale(zaxis, scaling_factor);
  
  vaddscaledto(origin, 0.25, xdelta);
  vaddscaledto(origin, 0.25, ydelta);
  vaddscaledto(origin, 0.25, zdelta);
      
  delete[] data;
  data = data_new;
}




// Write the z-profile of the map to the file "collapse.dat". There is one
// "z value" line per z layer, where value is the average over that layer's
// x-y plane, computed in the space selected by optype (see GetOps). The z
// coordinate uses only the z components of origin and zdelta.
void VolMap::collapse_onto_z(int optype) {
  Operation *ops = GetOps(optype);
  printf("%s: writing z-projection (%s) into file collapse.dat\n", refname, ops->name());

  FILE *fout = fopen("collapse.dat", "w");
  if (!fout) {
    // Previously a failed fopen led to fprintf/fclose on a NULL stream.
    printf("Error: cannot open collapse.dat for writing!\n");
    return;
  }
    
  int gx, gy, gz;
  double projection;
      
  for (gz=0; gz<zsize; gz++) {
    projection = 0.;
    
    for (gx=0; gx<xsize; gx++)
    for (gy=0; gy<ysize; gy++) {
      projection += ops->ConvertValue(data[gx + gy*xsize + gz*xsize*ysize]);
    }
    
    projection = projection/(xsize*ysize);
    fprintf(fout, "%g %g\n", origin[2]+gz*zdelta[2], ops->ConvertAverage(projection));
  }
   
  fclose(fout);
}



// Average the map over N rotations of itself.
//
// Each voxel n is replaced by the average of the map sampled at the voxel's
// coordinates, rotated about the axis R_axis through R_center by angles
// k*2*pi/num_rot for k = 0..num_rot-1 (all three values are hard-coded
// below). Samples come from voxel_value_interpolate_from_coord(), and
// samples that come back NaN (e.g. points outside the map) are skipped.
// The average is taken in the space selected by optype: ConvertValue()
// before summing, ConvertAverage() after. For PMF maps the result is capped
// at 150, and NaN is also mapped to 150.
// NOTE(review): a voxel with no valid sample divides 0 by 0 below. The result
// is NaN for regular maps and 150 for PMF maps.
// NOTE(review): the log message says "PMF" for either operation type.
void VolMap::average_over_rotations(int optype) {
  Operation *ops = GetOps(optype);
  
  // Hard-code the rotation axis, the rotation center, and the number of rotations:
  const double R_center[3] = {0., 0., 0.};
  const double R_axis[3] = {0., 0., 1.};
  const int num_rot = 4;

  double rot_incr = 2.*M_PI/(double) num_rot;
  printf("%s: averaging the PMF over %d rotations (%s)\n", refname, num_rot, ops->name());

  // Per-voxel running sum (in ConvertValue space) and count of valid samples.
  int gridsize = xsize*ysize*zsize;
  double *data_new = new double[gridsize];
  memset(data_new, 0, gridsize*sizeof(double));
  int *data_count = new int[gridsize];
  memset(data_count, 0, gridsize*sizeof(int));
  
  double x, y, z;
  double xo, yo, zo;
  int rot;
  for (rot=0; rot<num_rot; rot++) {
    double angle = rot_incr*rot;  
    double cosA = cos(angle);
    double sinA = sin(angle);
    double t = 1.-cosA;
    for (int n=0; n<gridsize; n++) {
      // Coordinates of voxel n, relative to the rotation center.
      index_to_coord(n, xo, yo, zo);
      xo -= R_center[0];
      yo -= R_center[1];
      zo -= R_center[2];
      // Rotate about R_axis and shift back by R_center. These rows are those
      // of the transposed Rodrigues (axis-angle) matrix, i.e. a rotation by
      // -angle. Over the full set of evenly spaced angles this visits the
      // same orientations.
      x = xo*(t*R_axis[0]*R_axis[0] + cosA) + yo*(t*R_axis[0]*R_axis[1]+sinA*R_axis[2]) + zo*(t*R_axis[0]*R_axis[2]-sinA*R_axis[1]) + R_center[0];
      y = xo*(t*R_axis[0]*R_axis[1]-sinA*R_axis[2]) + yo*(t*R_axis[1]*R_axis[1] + cosA) + zo*(t*R_axis[1]*R_axis[2]+sinA*R_axis[0]) + R_center[1];
      z = xo*(t*R_axis[0]*R_axis[2]+sinA*R_axis[1]) + yo*(t*R_axis[1]*R_axis[2]-sinA*R_axis[0]) + zo*(t*R_axis[2]*R_axis[2] + cosA) + R_center[2];
      
      double val = voxel_value_interpolate_from_coord(x, y, z, optype);
      
      // val == val is false only for NaN: skip samples outside the map.
      if (val == val) {
        data_new[n] += ops->ConvertValue(val);
        data_count[n]++;
      }
    }
  }
  
  
  // Turn the sums into averages and convert back to stored values.
  for (int n=0; n<gridsize; n++) {
    double val = ops->ConvertAverage(data_new[n]/data_count[n]);
    
    if (optype == PMF) {
      if (val == val && val < 150.) data_new[n] = val;
      else data_new[n] = 150.;
    }
    else
      data_new[n] = val; 
  }
  
  delete[] data;
  delete[] data_count;
  data = data_new;
}



// Gaussian blurring (as a 3D convolution); the kernel can easily be changed
// to something else. Voxel values are mapped with ops->ConvertValue before
// convolving and back with ops->ConvertAverage afterwards (see GetOps).
// The kernel extends 3*radius (truncated to whole voxels) on each side and
// is normalized over its full extent. Kernel points that fall outside the
// map are skipped without renormalizing.
// Right now, only works if resolution is the same in all map dimensions:
// the voxel spacing is taken from xdelta[0].
void VolMap::smooth(double radius, int optype) {
  // A non-positive (or NaN) radius means no smoothing. A negative radius
  // would otherwise produce a negative kernel size.
  if (!(radius > 0.)) return;
  
  Operation *ops = GetOps(optype);
  printf("%s: Gaussion blur filter (%s); radius = %g \n", refname, ops->name(), radius);
 
  double delta = xdelta[0];
  int step = (int)(3.*radius/delta); // kernel half-width, in voxels
  if (step <= 0) return;
  
  int gridsize = xsize*ysize*zsize;
  double *data_new = new double[gridsize];
  memset(data_new, 0, gridsize*sizeof(double));
  
  // Build convolution kernel
  int convsize = 2*step+1;
  double *conv = new double[convsize*convsize*convsize];
  memset(conv, 0, convsize*convsize*convsize*sizeof(double));
  
  double r2, norm=0.;
  int cx, cy, cz; 
  for (cz=0; cz<convsize; cz++)
  for (cy=0; cy<convsize; cy++)
  for (cx=0; cx<convsize; cx++) {
    r2 = delta*delta*((cx-step)*(cx-step)+(cy-step)*(cy-step)+(cz-step)*(cz-step));
    conv[cx + cy*convsize + cz*convsize*convsize] = exp(-0.5*r2/(radius*radius)); 
    norm += conv[cx + cy*convsize + cz*convsize*convsize];
  }
  
  // Normalize so the full kernel sums to 1.
  int n;
  for (n=0; n<convsize*convsize*convsize; n++) {
    conv[n] = conv[n]/norm;
  }
 
  // Apply convolution (in ConvertValue space; data is overwritten in place
  // because it is discarded afterwards).
  for (n=0; n<gridsize; n++) data[n] = ops->ConvertValue(data[n]);  
  
  int gx, gy, gz, hx, hy, hz; 
  for (gz=0; gz<zsize; gz++)
  for (gy=0; gy<ysize; gy++) 
  for (gx=0; gx<xsize; gx++)
  for (cz=0; cz<convsize; cz++)
  for (cy=0; cy<convsize; cy++)
  for (cx=0; cx<convsize; cx++) {
    hx=gx+cx-step;
    hy=gy+cy-step;
    hz=gz+cz-step;
    if (hx < 0 || hx >= xsize || hy < 0 || hy >= ysize || hz < 0 || hz >= zsize) {
      continue;
    }

    data_new[gx + gy*xsize + gz*xsize*ysize] += data[hx + hy*xsize + hz*xsize*ysize]*conv[cx + cy*convsize + cz*convsize*convsize];  
  }
  
  for (n=0; n<gridsize; n++) data_new[n] = ops->ConvertAverage(data_new[n]);  
  
  // The kernel was previously leaked on every call.
  delete[] conv;
  delete[] data;
  data = data_new;
}





void VolMap::total_occupancy() {
    
  int gx, gy, gz;
  double val;
  double occup = 0.;
  int count = 0;
  
  for (gz=0; gz<zsize; gz++)
  for (gx=0; gx<xsize; gx++)
  for (gy=0; gy<ysize; gy++) {
    val = data[gx + gy*xsize + gz*xsize*ysize];
    occup += exp(-val);
    if (val) count++;
  }
  
  double factor = 6.0221e23/22.4e27;  // 1mol/22.4L in particles/A^3 at STP
  occup *= factor;
  
  printf("\nCOUNT: At 1atm, the occupancy of the map is: %g particles\n", occup);
  printf("\nCOUNT: In air, the occupancy of the map is: %g O2\n", occup*0.2);
  printf("\nCOUNT: Non-zero cell count is: %d\n", count);
  printf("\nCOUNT: Occupancy in equiv. vacuum would be %g\n", factor*count);
  printf("\nCOUNT: Occupancy in equiv. air would be %g\n", factor*count*0.2);
}
    


/// Print summary statistics of the voxel values: the map weight, and the
/// mean, standard deviation, minimum and maximum of the values. Also prints
/// the "PMF average" -log(<exp(-value)>), which averages exp(-value) over
/// all voxels.
void VolMap::print_stats() {  
  // data[0] seeds min/max below, so an empty map must be handled up front.
  if (xsize*ysize*zsize <= 0 || !data) {
    printf("OUTPUT STATS: map is empty\n");
    return;
  }

  int gx, gy, gz;
  double D, E;

  double sum_D = 0.;  
  double sum_E = 0.; 
  double sum_E2 = 0.;  
  
  double min_E = data[0]; 
  double max_E = data[0];  
  
  for (gx=0; gx<xsize; gx++)
  for (gy=0; gy<ysize; gy++) 
  for (gz=0; gz<zsize; gz++) {
    E = data[gx + gy*xsize + gz*xsize*ysize];
    D = exp(-E);
    sum_D  += D;
    sum_E  += E;
    sum_E2 += E*E;

    if (E<min_E) min_E = E;
    if (E>max_E) max_E = E;
  }
  
  double N = xsize*ysize*zsize;
  double pmf = -log(sum_D/N);
  
  // Round-off can make this slightly negative for a (nearly) constant map,
  // which would make sqrt() return NaN.
  double var_E = sum_E2/N - sum_E*sum_E/(N*N);
  if (var_E < 0.) var_E = 0.;
  double dev_E = sqrt(var_E);
  
  printf("OUTPUT STATS:\n");
  printf("  WEIGHT:    %g\n", weight);
  printf("  AVERAGE:   %g\n", sum_E/N);
  printf("  STDEV:     %g\n", dev_E);
  printf("  MIN:       %g\n", min_E);
  printf("  MAX:       %g\n", max_E);
  printf("\n");
  printf("  PMF_AVG:   %g\n", pmf);

}







/* BINARY OPERATIONS */

//////////////////////////////////////////////////////////////


/// creates axes, bounding box and allocates data based on 
/// geometrical intersection of A and B
///
/// For each coordinate d, the new origin is the larger of the two origins.
/// Each axis end point is the smaller of the two maps' end points, but never
/// below the new origin. The grid size along each axis is the larger of
/// (new_axis . input_axis) * input_size / (input_axis . input_axis), taken
/// over the two inputs (truncated to int). The voxel spacing is then
/// axis/(size-1). The previous data array is freed, and a new uninitialized
/// array of xsize*ysize*zsize doubles is allocated.
/// NOTE(review): a resulting size of 1 makes (size-1) zero in the spacing
/// computation below; confirm callers never produce a single-layer overlap.
void VolMap::init_from_intersection(VolMap *mapA, VolMap *mapB) {
  int d;
  
  // Find intersection of A and B
  // The following has been verified for orthog. cells
  // (Does not work for non-orthog cells)
  
  // Axes temporarily hold absolute end points; made relative to origin below.
  for (d=0; d<3; d++) {
    origin[d] = max(mapA->origin[d], mapB->origin[d]);
    xaxis[d] = max(min(mapA->origin[d]+mapA->xaxis[d], mapB->origin[d]+mapB->xaxis[d]), origin[d]);
    yaxis[d] = max(min(mapA->origin[d]+mapA->yaxis[d], mapB->origin[d]+mapB->yaxis[d]), origin[d]);
    zaxis[d] = max(min(mapA->origin[d]+mapA->zaxis[d], mapB->origin[d]+mapB->zaxis[d]), origin[d]);
  }
    
  vsub(xaxis, xaxis, origin);
  vsub(yaxis, yaxis, origin);
  vsub(zaxis, zaxis, origin);
  
  // Keep the finer of the two input resolutions along each axis.
  xsize = (int) max(vdot(xaxis,mapA->xaxis)*mapA->xsize/vdot(mapA->xaxis,mapA->xaxis), \
                    vdot(xaxis,mapB->xaxis)*mapB->xsize/vdot(mapB->xaxis,mapB->xaxis));
  ysize = (int) max(vdot(yaxis,mapA->yaxis)*mapA->ysize/vdot(mapA->yaxis,mapA->yaxis), \
                    vdot(yaxis,mapB->yaxis)*mapB->ysize/vdot(mapB->yaxis,mapB->yaxis));
  zsize = (int) max(vdot(zaxis,mapA->zaxis)*mapA->zsize/vdot(mapA->zaxis,mapA->zaxis), \
                    vdot(zaxis,mapB->zaxis)*mapB->zsize/vdot(mapB->zaxis,mapB->zaxis));
    
  for (d=0; d<3; d++) {
    xdelta[d] = xaxis[d]/(xsize-1);
    ydelta[d] = yaxis[d]/(ysize-1);
    zdelta[d] = zaxis[d]/(zsize-1);
  }
  
  // Create map...
  if (data) delete[] data;
  data = new double[xsize*ysize*zsize];
}




/// creates axes, bounding box and allocates data based on 
/// geometrical union of A and B
///
/// The new origin is the per-coordinate minimum of the two origins. Only the
/// diagonal axis components (xaxis[0], yaxis[1], zaxis[2]) are set: each is
/// the larger of the two maps' end points along that coordinate, measured
/// from the new origin. The grid size per axis is the larger of the two
/// input resolutions applied to the new axis, and the spacing is
/// axis/(size-1). Any previous data array is released and a new
/// uninitialized one of xsize*ysize*zsize doubles is allocated.
void VolMap::init_from_union(VolMap *mapA, VolMap *mapB) {
  int d;
  
  // Find union of A and B
  // The following has been verified for orthog. cells
  // (Does not work for non-orthog cells)
  
  vset(xaxis, 0., 0., 0.);
  vset(yaxis, 0., 0., 0.);
  vset(zaxis, 0., 0., 0.);
  
  for (d=0; d<3; d++) {
    origin[d] = min(mapA->origin[d], mapB->origin[d]);
  }
  d=0;
  xaxis[d] = max(max(mapA->origin[d]+mapA->xaxis[d], mapB->origin[d]+mapB->xaxis[d]), origin[d]);
  d=1;
  yaxis[d] = max(max(mapA->origin[d]+mapA->yaxis[d], mapB->origin[d]+mapB->yaxis[d]), origin[d]);
  d=2;
  zaxis[d] = max(max(mapA->origin[d]+mapA->zaxis[d], mapB->origin[d]+mapB->zaxis[d]), origin[d]);
  
  xaxis[0] -= origin[0];
  yaxis[1] -= origin[1];
  zaxis[2] -= origin[2];
  
  xsize = (int) max(vdot(xaxis,mapA->xaxis)*mapA->xsize/vdot(mapA->xaxis,mapA->xaxis), \
                    vdot(xaxis,mapB->xaxis)*mapB->xsize/vdot(mapB->xaxis,mapB->xaxis));
  ysize = (int) max(vdot(yaxis,mapA->yaxis)*mapA->ysize/vdot(mapA->yaxis,mapA->yaxis), \
                    vdot(yaxis,mapB->yaxis)*mapB->ysize/vdot(mapB->yaxis,mapB->yaxis));
  zsize = (int) max(vdot(zaxis,mapA->zaxis)*mapA->zsize/vdot(mapA->zaxis,mapA->zaxis), \
                    vdot(zaxis,mapB->zaxis)*mapB->zsize/vdot(mapB->zaxis,mapB->zaxis));
  
  for (d=0; d<3; d++) {
    xdelta[d] = xaxis[d]/(xsize-1);
    ydelta[d] = yaxis[d]/(ysize-1);
    zdelta[d] = zaxis[d]/(zsize-1);
  }
  
  // Create map (releasing any previous array, as init_from_intersection does;
  // it was leaked before).
  if (data) delete[] data;
  data = new double[xsize*ysize*zsize];
}



/// Give *this the same origin, axes and grid size as mapA, with the voxel
/// spacing recomputed as axis/(size-1). Any previous data array is
/// released and a new uninitialized one is allocated; voxel values are not
/// copied.
void VolMap::init_from_identity(VolMap *mapA) {
  int d;

  vcopy(origin, mapA->origin);
  vcopy(xaxis, mapA->xaxis);
  vcopy(yaxis, mapA->yaxis);
  vcopy(zaxis, mapA->zaxis); 
  
  xsize = mapA->xsize;
  ysize = mapA->ysize;
  zsize = mapA->zsize;
    
  for (d=0; d<3; d++) {
    xdelta[d] = xaxis[d]/(xsize-1);
    ydelta[d] = yaxis[d]/(ysize-1);
    zdelta[d] = zaxis[d]/(zsize-1);
  }
  
  // Create map (releasing any previous array, which was leaked before).
  if (data) delete[] data;
  data = new double[xsize*ysize*zsize];
}




/// Fill *this with mapA + mapB on the grid of the two maps' geometric
/// intersection. Both maps are sampled by coordinate with the three-argument
/// voxel_value_interpolate_from_coord. This is slower than adding the arrays
/// directly, but it works for unaligned maps and/or maps of different
/// resolutions. Grid points use only the diagonal delta components
/// (orthogonal cells).
void VolMap::create_sum(VolMap *mapA, VolMap *mapB) {
  printf("%s: result of sum operation %s + %s\n", refname, mapA->refname, mapB->refname);

  init_from_intersection(mapA, mapB);

  const int plane = xsize*ysize;
  for (int gz = 0; gz < zsize; ++gz) {
    const double z = origin[2] + (gz)*zdelta[2];
    for (int gy = 0; gy < ysize; ++gy) {
      const double y = origin[1] + (gy)*ydelta[1];
      for (int gx = 0; gx < xsize; ++gx) {
        const double x = origin[0] + (gx)*xdelta[0];
        const double a = mapA->voxel_value_interpolate_from_coord(x,y,z);
        const double b = mapB->voxel_value_interpolate_from_coord(x,y,z);
        data[gz*plane + gy*xsize + gx] = a + b;
      }
    }
  }
}



/// Fill *this with mapA * mapB (voxel by voxel) on the grid of the two maps'
/// geometric intersection. Both maps are sampled by coordinate with the
/// three-argument voxel_value_interpolate_from_coord. This is slower than
/// multiplying the arrays directly, but it works for unaligned maps and/or
/// maps of different resolutions.
void VolMap::create_multiply(VolMap *mapA, VolMap *mapB) {
  // The message previously printed "+" for this multiplication.
  printf("%s: result of multiply operation %s * %s\n", refname, mapA->refname, mapB->refname);
    
  init_from_intersection(mapA, mapB);
  
  for (int gx=0; gx<xsize; gx++)
  for (int gy=0; gy<ysize; gy++)
  for (int gz=0; gz<zsize; gz++) {
    double x = origin[0] + (gx)*xdelta[0]; 
    double y = origin[1] + (gy)*ydelta[1];
    double z = origin[2] + (gz)*zdelta[2];
    data[gz*xsize*ysize + gy*xsize + gx] = mapA->voxel_value_interpolate_from_coord(x,y,z) * \
         mapB->voxel_value_interpolate_from_coord(x,y,z);
  }
}




// Weighted average of N maps
void VolMap::create_combine(char **files, int numfiles, int optype) {
  Operation *ops = GetOps(optype);
  printf("%s: Weighted average (%s)\n", refname, ops->name());

  if (numfiles < 2) {
    printf("Error: need at least 2 files for combine operation!\n");
    exit(1);
  }
  
  VolMap *potsave = new VolMap();
  potsave->refname = "INPUT";
  
  VolMap *pot = new VolMap();
  pot->refname = "INPUT";
  
  int err = potsave->read_dx_file(files[0]);
  if (err) {
    printf("Error: cannot read file!\n");
    exit(1);
  }
  printf("INPUT: weight = %g\n", potsave->weight);
  
  for (int i=1; i < numfiles; i++) {
    err = pot->read_dx_file(files[i]);
    if (err) {
      printf("Error: cannot read file!\n");
      exit(1);
    }
    printf("INPUT: weight = %g\n", pot->weight);
    
    init_from_intersection(potsave, pot);
  
    for (int gx=0; gx<xsize; gx++)
    for (int gy=0; gy<ysize; gy++)
    for (int gz=0; gz<zsize; gz++) {
      double x = origin[0] + (gx)*xdelta[0]; 
      double y = origin[1] + (gy)*ydelta[1];
      double z = origin[2] + (gz)*zdelta[2]; 
      data[gz*xsize*ysize + gy*xsize + gx] = \
          ops->ConvertAverage((potsave->weight * ops->ConvertValue(potsave->voxel_value_from_coord(x,y,z)) + \
          pot->weight * ops->ConvertValue(pot->voxel_value_from_coord(x,y,z))) /(potsave->weight+pot->weight));
    }
  
    weight = potsave->weight + pot->weight;
    potsave->clone(this);
  }  

  delete potsave;
  delete pot;
}



/// Fill *this with mapA - mapB on the grid of the two maps' geometric
/// intersection, sampling both maps by coordinate (interpolated). Where
/// mapB yields NaN, mapA's value is stored unchanged. Sampling by
/// coordinate is slower than subtracting arrays directly, but it works for
/// unaligned maps and/or maps of different resolutions.
void VolMap::create_diff(VolMap *mapA, VolMap *mapB) {
  printf("%s: result of difference operation %s - %s\n", refname, mapA->refname, mapB->refname);

  init_from_intersection(mapA, mapB);

  const int plane = xsize*ysize;
  for (int gz = 0; gz < zsize; ++gz) {
    const double z = origin[2] + (gz)*zdelta[2];
    for (int gy = 0; gy < ysize; ++gy) {
      const double y = origin[1] + (gy)*ydelta[1];
      for (int gx = 0; gx < xsize; ++gx) {
        const double x = origin[0] + (gx)*xdelta[0]; //XXX check this...
        const double valA = mapA->voxel_value_interpolate_from_coord(x,y,z);
        const double valB = mapB->voxel_value_interpolate_from_coord(x,y,z);
        // valB != valB holds only for NaN.
        data[gz*plane + gy*xsize + gx] = (valB != valB) ? valA : valA - valB;
      }
    }
  }
}









/* VOXELS */  

/// return voxel, after safely clamping index to valid range
/// (indices <= 0 become 0, indices >= size become size-1 on each axis)
double VolMap::voxel_value_safe(int x, int y, int z) const {
  if (x <= 0) x = 0;
  else if (x >= xsize) x = xsize-1;

  if (y <= 0) y = 0;
  else if (y >= ysize) y = ysize-1;

  if (z <= 0) z = 0;
  else if (z >= zsize) z = zsize-1;

  return data[x + y*xsize + z*xsize*ysize];
}


/// return interpolated value from 8 nearest neighbor voxels
/// Trilinear interpolation at fractional grid position (xv,yv,zv): first
/// along x on the four cell edges, then along y, then along z. Neighbor
/// indices beyond the grid are clamped by voxel_value_safe.
double VolMap::voxel_value_interpolate(double xv, double yv, double zv) const {
  const int ix = (int) xv;
  const int iy = (int) yv;
  const int iz = (int) zv;
  const double fx = xv - ix;
  const double fy = yv - iy;
  const double fz = zv - iz;

  const double c00 = voxel_value_safe(ix, iy,   iz);
  const double e00 = c00 + fx*(voxel_value_safe(ix+1, iy,   iz)   - c00);
  const double c10 = voxel_value_safe(ix, iy+1, iz);
  const double e10 = c10 + fx*(voxel_value_safe(ix+1, iy+1, iz)   - c10);
  const double c01 = voxel_value_safe(ix, iy,   iz+1);
  const double e01 = c01 + fx*(voxel_value_safe(ix+1, iy,   iz+1) - c01);
  const double c11 = voxel_value_safe(ix, iy+1, iz+1);
  const double e11 = c11 + fx*(voxel_value_safe(ix+1, iy+1, iz+1) - c11);

  const double f0 = e00 + fy*(e10 - e00);
  const double f1 = e01 + fy*(e11 - e01);

  return f0 + fz*(f1 - f0);
}


/// return interpolated value from 8 nearest neighbor voxels, for PMF maps:
/// trilinear interpolation is done on exp(-value) of the neighbors, and the
/// result is converted back with -log(). Neighbor indices beyond the grid
/// are clamped by voxel_value_safe.
double VolMap::voxel_value_interpolate_pmf_exp(double xv, double yv, double zv) const {
  const int ix = (int) xv;
  const int iy = (int) yv;
  const int iz = (int) zv;
  const double fx = xv - ix;
  const double fy = yv - iy;
  const double fz = zv - iz;

  const double b00 = exp(-voxel_value_safe(ix, iy,   iz));
  const double e00 = b00 + fx*(exp(-voxel_value_safe(ix+1, iy,   iz))   - b00);
  const double b10 = exp(-voxel_value_safe(ix, iy+1, iz));
  const double e10 = b10 + fx*(exp(-voxel_value_safe(ix+1, iy+1, iz))   - b10);
  const double b01 = exp(-voxel_value_safe(ix, iy,   iz+1));
  const double e01 = b01 + fx*(exp(-voxel_value_safe(ix+1, iy,   iz+1)) - b01);
  const double b11 = exp(-voxel_value_safe(ix, iy+1, iz+1));
  const double e11 = b11 + fx*(exp(-voxel_value_safe(ix+1, iy+1, iz+1)) - b11);

  const double f0 = e00 + fy*(e10 - e00);
  const double f1 = e01 + fy*(e11 - e01);

  return -log(f0 + fz*(f1 - f0));
}


int VolMap::coord_to_index(double x, double y, double z) const {
  x -= origin[0];
  y -= origin[1];
  z -= origin[2];
  // XXX Needs to be fixed for non-orthog cells (subtract out projected component every step)
  int gx = int((x*xaxis[0] + y*xaxis[1] + z*xaxis[2])/vnorm(xaxis));
  int gy = int((x*yaxis[0] + y*yaxis[1] + z*yaxis[2])/vnorm(yaxis));
  int gz = int((x*zaxis[0] + y*zaxis[1] + z*zaxis[2])/vnorm(zaxis));
  return (gx + gy*xsize + gz*ysize*xsize);
}


/// Compute the coordinates of the voxel with linear index `index` (x varies
/// fastest, then y, then z), using only the diagonal delta components
/// (orthogonal cells).
void VolMap::index_to_coord(int index, double &x, double &y, double &z) const {
  const int ix = index % xsize;
  const int iy = (index / xsize) % ysize;
  const int iz = index / (xsize*ysize);
  x = origin[0] + xdelta[0]*ix;
  y = origin[1] + ydelta[1]*iy;
  z = origin[2] + zdelta[2]*iz;
}


  
/// return value of voxel, based on atomic coords.
/// Nearest-voxel lookup: voxel i covers the half-open slab within half a
/// spacing of origin + i*delta. Returns kNAN when the point lies outside
/// the grid (or any coordinate is NaN).
/// XXX need to account for non-orthog. cells
double VolMap::voxel_value_from_coord(double xpos, double ypos, double zpos) const {
  // Lower corner of the map: half a voxel below the first voxel center.
  double min_coord[3];
  for (int i=0; i<3; i++) min_coord[i] = origin[i] - 0.5*(xdelta[i] + ydelta[i] + zdelta[i]);
  // Position in voxels from that corner (XXX this is wrong for non-orthog cells).
  double fx = (xpos - min_coord[0])/xdelta[0];
  double fy = (ypos - min_coord[1])/ydelta[1];
  double fz = (zpos - min_coord[2])/zdelta[2];
  // Range-check before converting to int. The old (int) cast truncated values
  // in (-1, 0) to 0, so points up to a whole voxel below the map were reported
  // as voxel 0. The negated comparisons also reject NaN, and out-of-range
  // values never reach the (undefined) float-to-int conversion.
  if (!(fx >= 0. && fx < xsize)) return kNAN;
  if (!(fy >= 0. && fy < ysize)) return kNAN;
  if (!(fz >= 0. && fz < zsize)) return kNAN;
  int gx = (int) fx;
  int gy = (int) fy;
  int gz = (int) fz;
  return data[gz*xsize*ysize + gy*xsize + gx];
}


/// return interpolated value of voxel, based on atomic coords.
/// The point is converted to a fractional grid index. kNAN is returned
/// unless the truncated index is inside the grid on every axis. With
/// optype == PMF the interpolation is done on exp(-value)
/// (voxel_value_interpolate_pmf_exp), otherwise on the values directly
/// (voxel_value_interpolate).
/// XXX need to account for non-orthog. cells
double VolMap::voxel_value_interpolate_from_coord(double xpos, double ypos, double zpos, int optype) const {
  // Fractional grid indices (XXX this is wrong for non-orthog cells).
  xpos = (xpos-origin[0])/xdelta[0];
  ypos = (ypos-origin[1])/ydelta[1];
  zpos = (zpos-origin[2])/zdelta[2];
  // Same acceptance window as the previous "(int) pos in [0, size)" test,
  // i.e. pos in (-1, size), but checked in floating point. The old form
  // converted NaN or huge values to int, which is undefined behavior.
  // NOTE(review): positions in (-1, 0) are accepted and then extrapolated
  // by the interpolators (negative fractional weight); positions in
  // [size-1, size) are clamped. Confirm whether the lower bound should be 0.
  if (!(xpos > -1. && xpos < xsize)) return kNAN;
  if (!(ypos > -1. && ypos < ysize)) return kNAN;
  if (!(zpos > -1. && zpos < zsize)) return kNAN;
  
  if (optype==PMF)
    return voxel_value_interpolate_pmf_exp(xpos, ypos, zpos);
  else
    return voxel_value_interpolate(xpos, ypos, zpos);
}

