
#include "rmsdTools.h"


// Constructor
//   Keeps a pointer to the given alignment (ownership is not taken here —
//   presumably the caller keeps it alive for this object's lifetime; confirm).
//   Allocates:
//     - rmsdScores: a structCount x structCount matrix, all entries 0,
//       where structCount = alignment->getStructureCount();
//     - rmsdPerRes: one float per alignment column (alignment->getLength()),
//       all entries 0, so printing before rmsdPerResidue() is well defined.
RmsdTools::RmsdTools(StructureAlignment* sa)
  : alignment(sa) {

  int structCount = alignment->getStructureCount();

  rmsdScores = new float*[structCount];
  for (int i=0; i<structCount; i++) {
    // The trailing () value-initializes every element to 0.
    rmsdScores[i] = new float[structCount]();
  }

  // Value-initialize so the buffer never holds indeterminate values.
  rmsdPerRes = new float[alignment->getLength()]();
}


// Destructor
//   Releases the buffers allocated by the constructor. Every buffer was
//   allocated with new[], so it must be released with delete[]; the rows of
//   rmsdScores are separate allocations and are freed individually.
//   NOTE(review): the row count is re-read from the alignment, which assumes
//   the StructureAlignment outlives this object and its structure count is
//   unchanged since construction — confirm with callers. Copying a RmsdTools
//   would also double-free these buffers unless copying is disabled in the
//   header (not visible here).
RmsdTools::~RmsdTools() {

  if (rmsdScores != 0) {
    int structCount = alignment->getStructureCount();
    for (int i=0; i<structCount; i++) {
      delete[] rmsdScores[i];
    }
    delete[] rmsdScores;
  }
  delete[] rmsdPerRes;
}


// rmsd
//   Fills rmsdScores with the pairwise RMSD between every pair of structures
//   in the alignment. For a pair (i,j), only columns where neither structure
//   has a gap contribute:
//     rmsdScores[i][j] = sqrt( sum(distance^2) / number_of_gap_free_columns )
//   The matrix is symmetric and its diagonal is 0. A pair with no gap-free
//   column gets 0 and an error message on stdout.
//   The matrix is reset first, so calling this repeatedly gives the same result.
//   Returns 1.
int RmsdTools::rmsd() {

  int structCount = alignment->getStructureCount();
  int length = alignment->getLength();

  // Start from a clean matrix; results are written, not accumulated across calls.
  for (int i=0; i<structCount; i++) {
    for (int j=0; j<structCount; j++) {
      rmsdScores[i][j] = 0;
    }
  }

  for (int row1=0; row1<structCount; row1++) {
    for (int row2=row1+1; row2<structCount; row2++) {
      // Per-pair accumulators replace the old heap-allocated norms matrix.
      double sumSquares = 0.0;
      int alignedCount = 0;

      for (int col=0; col<length; col++) {
        if ( !(alignment->getAlphabet()->isGap(alignment->getSymbol(row1,col))) &&
             !(alignment->getAlphabet()->isGap(alignment->getSymbol(row2,col))) ) {
          Coordinate3D* coord1 = alignment->getCoordinate(row1,col);
          Coordinate3D* coord2 = alignment->getCoordinate(row2,col);
          double distance = coord1->getDistanceTo(coord2);
          sumSquares += distance * distance;
          alignedCount++;
        }
      }

      if (alignedCount > 0) {
        float score = static_cast<float>(sqrt(sumSquares / alignedCount));
        rmsdScores[row1][row2] = score;
        rmsdScores[row2][row1] = score;
      }
      else {
        // Both cells stay 0 from the reset above; report each direction.
        printf("Error - RmsdTools::rmsd, divide by zero\n");
        printf("   norms[%d][%d]\n",row1,row2);
        printf("Error - RmsdTools::rmsd, divide by zero\n");
        printf("   norms[%d][%d]\n",row2,row1);
      }
    }
  }

  return 1;
}


// rmsdPerResidue
//   Calculate RMSD between aligned residues of two structures;
//   -1 if there is at least one gap in the column
int RmsdTools::rmsdPerResidue(int struct1, int struct2) {

  if ( struct1 < 0 || struct2 < 0 ||
       struct1 >= alignment->getStructureCount() ||
       struct2 >= alignment->getStructureCount() ) {
    return 0;
  }

  Coordinate3D* coord1 = 0;
  Coordinate3D* coord2 = 0;

  for (int col=0; col<alignment->getLength(); col++) {
    if ( !(alignment->getAlphabet()->isGap(alignment->getSymbol(struct1,col))) &&
	 !(alignment->getAlphabet()->isGap(alignment->getSymbol(struct2,col))) ) {
      coord1 = alignment->getCoordinate(struct1,col);
      coord2 = alignment->getCoordinate(struct2,col);
      //printf("%d coord1: %s (%f,%f,%f)\n",col,alignment->getSymbol(struct1,col)->getThree(),coord1->getX(),coord1->getY(),coord1->getZ());
      //printf("%d coord2: %s (%f,%f,%f)\n",col,alignment->getSymbol(struct2,col)->getThree(),coord2->getX(),coord2->getY(),coord2->getZ());

      rmsdPerRes[col] = coord1->getDistanceTo(coord2);
    }
    else {
      //coord1 = alignment->getCoordinate(struct1,col);
      //coord2 = alignment->getCoordinate(struct2,col);
      //printf("%d coord2: %s (%f,%f,%f)\n",col,alignment->getSymbol(struct2,col)->getThree(),coord2->getX(),coord2->getY(),coord2->getZ());
      rmsdPerRes[col] = -1;
    }
  }

  return 1;
}


// printRmsd
//   Writes the rmsdScores matrix to outfile: one line per structure, values
//   formatted "%6.4f" and separated by single spaces.
//   Returns 0 (after printing an error to stdout) if the matrix was never
//   allocated or outfile is null; otherwise returns 1.
int RmsdTools::printRmsd(FILE* outfile) {

  if (rmsdScores == 0 || outfile == 0) {
    printf("Error: RmsdTools::printRmsd\n");
    return 0;
  }

  int structCount = alignment->getStructureCount();

  for (int i=0; i<structCount; i++) {
    for (int j=0; j<structCount; j++) {
      // Separator goes before every value except the first in the row.
      if (j > 0) {
        fputc(' ', outfile);
      }
      fprintf(outfile,"%6.4f", rmsdScores[i][j]);
    }
    fprintf(outfile,"\n");
  }

  return 1;
}


// printRmsdPerResidue
//   Writes rmsdPerRes to outfile, one line per alignment column in the form
//   "<column index> <value>", with the value formatted "%6.4f". A value of -1
//   marks a column where rmsdPerResidue() found a gap.
//   Returns 0 (after printing an error to stdout) if the buffer was never
//   allocated or outfile is null; otherwise returns 1.
int RmsdTools::printRmsdPerResidue(FILE* outfile) {

  if (rmsdPerRes == 0 || outfile == 0) {
    printf("Error: RmsdTools::printRmsdPerResidue\n");
    return 0;
  }

  for (int i=0; i<alignment->getLength(); i++) {
    fprintf(outfile,"%d %6.4f\n", i, rmsdPerRes[i]);
  }

  return 1;
}
