/*****************************************************************************
*
*            (C) Copyright 2005 The Board of Trustees of the
*                        University of Illinois
*                         All Rights Reserved
*
******************************************************************************/


#include <stdio.h>
#include <math.h>
#include <vector>
#include "symbol.h"
#include "sequence.h"
#include "alignedSequence.h"
#include "sequenceAlignment.h"
#include "sequenceQR.h"

// Constructor
//   Builds the full (non-binary) data matrix for the alignment:
//   matrix[i][j][k] is 1 when sequence k has, at alignment position i, the
//   one-letter code residueCodes[j] (j = 0..22) or a gap (j = 23); otherwise 0.
//
// Parameters: alignment:         the alignment to analyse (not freed by this object)
//             identityCutoff:    percent-identity threshold qr() uses to stop early
//             preserveCount:     number of leading columns qr() does not permute
//             performGapScaling: nonzero to rescale the gap coordinate (scaleGapData)
//             gapScaleParameter: extra multiplier applied to the gap scaling factor
SequenceQR::SequenceQR(SequenceAlignment *alignment, float identityCutoff, int preserveCount, int performGapScaling, float gapScaleParameter) {

    // One-letter codes for coordinates 0..22, in order; coordinate 23 holds gaps.
    static const char residueCodes[] = "ABCDEFGHIKLMNPQRSTVWXYZ";
    const int gapCoordinate = 23;

    this->alignment = alignment;
    this->identityCutoff = identityCutoff;
    this->preserveCount = preserveCount;
    this->performGapScaling = performGapScaling;
    this->gapScaleParameter = gapScaleParameter;
    cMi = alignment->getLength();
    cMj = gapCoordinate + 1;
    cMk = alignment->getSequenceCount();
    binary = 0;

    //Create a matrix containing the data representing this alignment.
    matrix = new float**[cMi];
    for (int i=0; i<cMi; i++) {
        matrix[i] = new float*[cMj];
        for (int j=0; j<cMj; j++) {
            matrix[i][j] = new float[cMk];
            for (int k=0; k<cMk; k++) {
                AlignedSequence* sequence = alignment->getSequence(k);
                if (j == gapCoordinate) {
                    matrix[i][j][k] = sequence->getAlphabet()->isGap(sequence->getSymbol(i))?1:0;
                } else {
                    matrix[i][j][k] = (sequence->getSymbol(i)->getOne() == residueCodes[j])?1:0;
                }
            }
        }
    }

    //Fill in the initial sequence ordering list (identity permutation).
    columnList = new int[cMk];
    for (int k=0; k<cMk; k++) {
        columnList[k] = k;
    }

    //Scale the gap data so gaps are neither over- nor under-weighted.
    if (performGapScaling) {
        scaleGapData();
    }
}


// Destructor
//   Frees the data matrix and the column permutation. Every one of these
//   arrays was allocated with new[], so each must be released with delete[]
//   (scalar delete on a new[] allocation is undefined behavior).
//   The alignment passed to the constructor is not freed here.
//   NOTE(review): this class owns raw arrays; unless the header disables
//   copying, a copied SequenceQR would double-free them — confirm.
SequenceQR::~SequenceQR() {

    for (int i=0; i<cMi; i++) {
        for (int j=0; j<cMj; j++) {
            delete[] matrix[i][j];
        }
        delete[] matrix[i];
    }
    delete[] matrix;
    delete[] columnList;
}


// qr
//   Runs the pivoted QR factorization one column at a time. At each step the
//   most linearly independent remaining sequence is moved to the front of the
//   current submatrix (steps below preserveCount keep their input order).
//   Householder reflections then remove its contribution from the rest.
//   When identityCutoff < 1.0, the loop stops at the first selected sequence
//   whose percent identity with an earlier selection reaches the cutoff; that
//   sequence is not included.
//
// Returns: a new SequenceAlignment holding the selected sequences in selection
//          order. It is allocated with new and not retained here, so the
//          caller presumably owns it.
SequenceAlignment* SequenceQR::qr() {

    // Number of sequences accepted into the profile so far.
    int selected = 0;
    while (selected < cMk) {
        if (selected >= preserveCount)
            permuteColumns(selected);
        householder(selected);

        // Stop before accepting a sequence that is too similar to a prior pick.
        if (identityCutoff < 1.0 && isSequenceAboveIdentityCutoff(selected))
            break;
        ++selected;
    }

    SequenceAlignment* profile = new SequenceAlignment(alignment->getLength(), alignment->getSequenceCount());
    for (int n = 0; n < selected; ++n)
        profile->addSequence(alignment->getSequence(columnList[n]));
    return profile;
}

// qr (fixed fraction)
//   Same pivoted QR procedure as qr(), but it runs a fixed number of steps
//   instead of using the identity cutoff. The step count is
//   (int)((percent / 100) * sequenceCount).
//
// Parameters: percent: percentage of sequences to select; clamped to [0, 100]
//
// Returns: a new SequenceAlignment holding the selected sequences in selection
//          order. It is allocated with new and not retained here, so the
//          caller presumably owns it.
SequenceAlignment* SequenceQR::qr(int percent) {

    // Clamp the requested percentage into [0, 100].
    if (percent < 0)
        percent = 0;
    else if (percent > 100)
        percent = 100;

    const int limit = (int)((percent/100.0f)*cMk);

    int step = 0;
    while (step < limit) {
        if (step >= preserveCount)
            permuteColumns(step);
        householder(step);
        ++step;
    }

    SequenceAlignment* profile = new SequenceAlignment(alignment->getLength(), alignment->getSequenceCount());
    for (int n = 0; n < step; ++n)
        profile->addSequence(alignment->getSequence(columnList[n]));
    return profile;
}


// householder
//   For each coordinate layer j, build the Householder vector that zeroes the
//   entries below row `currentColumn` in column `currentColumn` (in columnList
//   order). Then apply that reflection to rows currentColumn..cMi-1 of columns
//   currentColumn..cMk-1.
//
//   Fixes relative to the previous version:
//     * The Householder vector is one buffer reused across layers. It was
//       previously allocated with new[] per layer and never freed.
//     * If currentColumn >= cMi, there are no rows at or below the pivot. The
//       reflection is then the identity, so we return. Previously this indexed
//       matrix[currentColumn] and hhVector[currentColumn] out of bounds.
//
// Parameters: currentColumn: the pivot column (also the pivot row)
void SequenceQR::householder(int currentColumn) {

  if (currentColumn >= cMi) {
    return;
  }

  // Entries [0, currentColumn) are never written below, so they stay zero
  // for every coordinate layer.
  std::vector<float> hhVector(cMi, 0.0f);

  for (int j=0; j<cMj; j++) {

    // Compute the Householder vector for the pivot column in this layer.
    int k = currentColumn;
    float alpha = 0;
    for (int i=k; i<cMi; i++) {
      alpha += matrix[i][j][columnList[k]] * matrix[i][j][columnList[k]];
    }
    // Choose the sign that avoids cancellation in hhVector[k].
    float sign = (matrix[k][j][columnList[k]] >= 0) ? 1.0 : -1.0;
    alpha = -sign * sqrt(alpha);

    hhVector[k] = matrix[k][j][columnList[k]] - alpha;
    for (int i=k+1; i<cMi; i++) {
      hhVector[i] = matrix[i][j][columnList[k]];
    }

    // Squared length of the Householder vector.
    float beta = 0;
    for (int i=k; i<cMi; i++) {
      beta += hhVector[i] * hhVector[i];
    }

    // Apply the reflection H = I - 2 v v^T / (v^T v) to the remaining columns.
    if (beta != 0) {
      for (; k<cMk; k++) {
        float gamma = 0;
        for (int i=0; i<cMi; i++) {
          gamma += hhVector[i] * matrix[i][j][columnList[k]];
        }
        for (int i=currentColumn; i<cMi; i++) {
          matrix[i][j][columnList[k]] -= ((2*gamma)/beta) * hhVector[i];
        }
      }
    }
  }
}


/**
 * This method moves the column with the max frobenius norm to the front of the current submatrix.
 *
 * @param   currentColumn   The column in the submatrix that should be filled with the max norm.
 */
int SequenceQR::permuteColumns(int currentColumn) {

    int maxCol = 0;
    
    //printf("	Starting permuteColumns(%i):", currentColumn);
    //If this is the first column, switch it with the column with least average percent identity.
    // Skip this step for the binary version
    if ((currentColumn == 0)) {

	    if(binary) {
		    return 0;
	    }
        float min = -1.0;
        
        for (int k1=0; k1<cMk; k1++) {
            float value = 0.0;
            for (int k2=0; k2 <cMk; k2++) {
                value += (alignment->getSequence(k1)->getPercentIdentity(alignment->getSequence(k2)));
            }
            
            //If this is the least percent identity we have yet encountered, save it.
            if (min < 0.0 || value < min) {
                min = value;
                maxCol = k1;
            }
        }
    }

    //Otherwise, use the frobenius norm to figure out which column to switch.
    else {
        float *norms = new float[cMk];
        float maxNorm = 0.0;

        int k=0;
        for (k=0; k<cMk; k++) {
            norms[k] = 0.0;
        }
        
        //Get the maxiumum norm.
        for (k=currentColumn; k<cMk; k++) {
            
            //Get frobenius norms for matrix.
            norms[k] = frobeniusNormByK(k, currentColumn);
            
            //If this is the largest norm, select this column.
            if (norms[k] >= maxNorm) {
                maxCol = k;
                maxNorm = norms[k];
            }
        }
	//Print the norms.
	/*
	printf("Norms for %d:\n", currentColumn);
	for (int k=0; k<cMk; k++) {
		printf("norms[%i] = %4.4f\n",k, norms[k]);
	}

	printf("\n");
	    printf("Max norm = %f\n\n", maxNorm);
	*/
	delete norms;
    }

    /*
    //Print the norms.
    printf("Norms for %d:", currentColumn);
    for (int k=0; k<cMk; k++) {
        printf(" %4.4f", norms[k]);
    }
    printf("\n");
    */    
    //printf("maxCol %i = %i\n", currentColumn, maxCol);
    //Switch the columns.
    int temp = columnList[maxCol];
    columnList[maxCol] = columnList[currentColumn];    
    columnList[currentColumn] = temp;
}
    
/**
 * This method gets whether the current column exceeds the percent identity threshold with any
 * of the previous columns.
 *
 * @param   currentColumn   The column in the matrix that should be checked.
 * @return  1 if the column does exceed the percent identity threshold, 0 if it does not.
 */
int SequenceQR::isSequenceAboveIdentityCutoff(int currentColumn) {
    
    //See if this column exceeds the percent identity with any of the previously selected columns.
    AlignedSequence* currentSequence = alignment->getSequence(columnList[currentColumn]);
    for (int k=0; k<currentColumn; k++) {
        AlignedSequence* sequence = alignment->getSequence(columnList[k]);
	//printf("Sequence %i: = %s\n;percentIdent=%f\n", columnList[k], sequence->toString(), currentSequence->getPercentIdentity(sequence));
	//printf("Sequence %i: = %f\n", k, currentSequence->getPercentIdentity(sequence));
        if (currentSequence->getPercentIdentity(sequence) >= identityCutoff) {
            return 1;
        }
    }
    
    return 0;
}


// frobeniusNormSeq
//   Get the frobenius norm for the matrix corresponding
//   to the data for one sequence
//   frobeniusNorm(A) = sqrt( sum( all Aij ) );
float SequenceQR::frobeniusNormByK(int k, int currentRow) {

    float fNorm = 0;
    
    for (int i=currentRow; i<cMi; i++) {
        for (int j=0; j<cMj; j++) {
            fNorm += matrix[i][j][columnList[k]] * matrix[i][j][columnList[k]];
        }
    }
    
    return sqrt(fNorm);
}


// frobeniusNormByJ
//   Frobenius norm of one coordinate layer j over every alignment row and
//   every sequence: sqrt( sum over i, k of matrix[i][j][columnList[k]]^2 ).
float SequenceQR::frobeniusNormByJ(int j) {

    float sumOfSquares = 0;

    for (int row = 0; row < cMi; ++row) {
        float* layer = matrix[row][j];
        for (int col = 0; col < cMk; ++col) {
            const float value = layer[columnList[col]];
            sumOfSquares += value * value;
        }
    }

    return sqrt(sumOfSquares);
}


// scaleGapData
//   Scale the gap matrix elements to appropriate values so that
//   the QR algorithm is not biased towards or against the gaps.
//   scale*fNorm(G) = fNorm(X) + fNorm(Y) + fNorm(Z)
void SequenceQR::scaleGapData() {

    //Calculate the gap norm.
    float gapNorm = frobeniusNormByJ(23);
    if (gapNorm != 0) {
        
        //Calculate the scaling value.
        float value = 0.0;
		int j=0;
        for (j=0; j<23; j++)
            value += frobeniusNormByJ(j);
        value /= gapNorm*23.0;
        value *= gapScaleParameter;
        
        //Apply the scaling value to all of the gaps.
        j=23;
        for (int i=0; i<cMi; i++)
            for (int k=0; k<cMk; k++)
                matrix[i][j][columnList[k]] *= value;
    }
}


///////////////////////////////////////////////////////////////////////////

// Constructor (binary variant)
//   Builds a two-coordinate data matrix. For alignment position i and sequence k,
//   matrix[i][0][k] is 1 for a residue and 0 for a gap; matrix[i][1][k] is the reverse.
//   NOTE(review): the value of the `binary` argument is ignored; this constructor
//   always selects the binary representation (this->binary = 1).
SequenceQR::SequenceQR(SequenceAlignment *alignment, float identityCutoff, int preserveCount, int performGapScaling, float gapScaleParameter, int binary) {

    this->alignment = alignment;
    this->identityCutoff = identityCutoff;
    this->preserveCount = preserveCount;
    this->performGapScaling = performGapScaling;
    this->gapScaleParameter = gapScaleParameter;
    cMi = alignment->getLength();
    cMj = 2;
    cMk = alignment->getSequenceCount();
    this->binary = 1;

    // One residue row and one gap row per alignment position.
    matrix = new float**[cMi];
    for (int pos = 0; pos < cMi; ++pos) {
        float** coords = new float*[cMj];
        matrix[pos] = coords;
        float* residueRow = new float[cMk];
        coords[0] = residueRow;
        float* gapRow = new float[cMk];
        coords[1] = gapRow;

        for (int seq = 0; seq < cMk; ++seq) {
            AlignedSequence* sequence = alignment->getSequence(seq);
            int gapFlag = sequence->getAlphabet()->isGap(sequence->getSymbol(pos));
            residueRow[seq] = gapFlag ? 0 : 1;
            gapRow[seq] = gapFlag ? 1 : 0;
        }
    }

    // Start from the identity ordering of sequences.
    columnList = new int[cMk];
    for (int seq = 0; seq < cMk; ++seq)
        columnList[seq] = seq;

    // Coordinate 1 is the gap coordinate in the binary layout.
    if (performGapScaling)
        scaleGapData(1);
}


// scaleGapData
//   Scale the gap coordinate so that the QR algorithm is not biased towards or
//   against gaps:
//     scale = gapScaleParameter * (sum_{j < jval} fNorm(j)) / (jval * fNorm(jval))
//   The gap layer jval is multiplied by that scale. Nothing changes when the
//   gap layer's norm is zero.
//
// params: jval: index of the gap coordinate; coordinates 0..jval-1 are the
//               non-gap coordinates averaged against it
//
void SequenceQR::scaleGapData(int jval) {

    const float gapNorm = frobeniusNormByJ(jval);
    if (gapNorm == 0)
        return;

    // Average norm of the non-gap layers, relative to the gap layer's norm.
    float scale = 0.0;
    for (int coord = 0; coord < jval; ++coord)
        scale += frobeniusNormByJ(coord);
    scale /= gapNorm * static_cast<float>(jval);
    scale *= gapScaleParameter;

    for (int row = 0; row < cMi; ++row) {
        float* gapLayer = matrix[row][jval];
        for (int col = 0; col < cMk; ++col)
            gapLayer[columnList[col]] *= scale;
    }
}

// printColumns
//   Debug helper. Writes each sequence's data to stdout in the current columnList
//   order: a "Sequence N" header (N is the columnList position), then one line
//   per alignment position with cMj values, then a blank line.
void SequenceQR::printColumns() {
    for (int col = 0; col < cMk; ++col) {
        const int seqIndex = columnList[col];
        printf("Sequence %d\n", col);
        for (int pos = 0; pos < cMi; ++pos) {
            for (int coord = 0; coord < cMj; ++coord)
                printf("%1.2f ", matrix[pos][coord][seqIndex]);
            printf("\n");
        }
        printf("\n");
    }
}


