# Free energy and MFPT calculation for milestoning

import sys
import numpy
import shutil
import numpy as np
from numpy import linalg
import math
import scipy
from scipy import optimize

#
#   configuration: input directories, cell (image) range, target milestone, output files
#
# ---- run configuration ------------------------------------------------------
merge_list = ['.']              # trajectory directories whose statistics are merged
stop_step = ['1000000']         # step tag read from each directory (one entry per merge_list item)
startimage = 0
endimage = 19
maximage = endimage - startimage + 1      # total number of cells (images)
target_cell1 = 9                          # the target milestone is the one between
target_cell2 = 10                         #   these two (original) cell indices
remove_cell_array = []                    # cells excluded from the analysis
len_remove_cell = len(remove_cell_array)

# ---- work arrays (float64 zeros) ---------------------------------------------
filemat = np.zeros((maximage, maximage))
Tmat = np.zeros(maximage)
ratemat = np.zeros((maximage, maximage))
# population system: one balance row per retained cell plus a normalisation row
coe = np.zeros((maximage + 1 - len_remove_cell, maximage - len_remove_cell))
res = np.zeros(maximage + 1 - len_remove_cell)
popu = np.zeros(maximage - len_remove_cell)
popu1 = np.zeros(maximage - len_remove_cell)
F_energy = np.zeros(maximage - len_remove_cell)

# ---- output files ------------------------------------------------------------
population_file = 'population.dat'
free_file = 'Free_energy.dat'
main_MFPT_file = 'main_MFPTs.dat'

# Accumulate, over every trajectory directory in merge_list, the per-cell
# transition counts (filemat) and time (Tmat) from count_<i>.dat, reading the
# record tagged with that directory's stop_step entry, and derive the per-cell
# rate estimates ratemat[i] = filemat[i] / Tmat[i].
for merge_num, trj_dir in enumerate(merge_list):
    # Marker of the record to read; it only depends on the directory.
    # NOTE(review): this is a substring test, so '## time: 1000000' would also
    # match a '## time: 10000000' line -- confirm such lines cannot occur.
    time_marker='## time: {0}'.format(stop_step[merge_num])
    for i in range(maximage):
        # `with` closes the file even if a malformed line makes parsing raise.
        with open("{0}/count_{1}.dat".format(trj_dir, i), "r") as f:
            # readflag: 0 = searching for the marker line,
            #           1 = marker seen, this line holds the counts,
            #           2 = counts consumed, stop reading.
            readflag=0
            for rl in f:
                if readflag == 2:
                    break
                if readflag == 1:
                    temp=rl.split()
                    for j in range(maximage):
                        filemat[i][j]=filemat[i][j]+float(temp[j])
                    # uses counts/time accumulated over all directories so far
                    ratemat[i]=filemat[i]/Tmat[i]
                    readflag=2
                if time_marker in rl:
                    tl=rl.split()
                    # NOTE(review): field 5 of the marker line is taken as this
                    # directory's time contribution -- confirm the file format.
                    Tmat[i]=Tmat[i]+float(tl[5])
                    readflag=1
# Assemble the linear system for the stationary cell populations p over the
# retained cells (those not in remove_cell_array), with n = their count:
#   row k < n : sum_l ratemat[cell_l][cell_k] * p_l  -  p_k * sum_l ratemat[cell_k][cell_l] = 0
#   row n     : sum_k p_k = 1   (normalisation)
kept_cells = [c for c in range(maximage) if c not in remove_cell_array]
for k, i in enumerate(kept_cells):
    # inflow coefficients: column l carries the rate from cell j into cell i
    for l, j in enumerate(kept_cells):
        coe[k][l] = ratemat[j][i]
    # total outflow rate of cell i, subtracted on the diagonal
    outflow = 0
    for j in kept_cells:
        outflow = outflow - ratemat[i][j]
    coe[k][k] = coe[k][k] + outflow
    res[k] = 0
for k, j in enumerate(kept_cells):
    coe[maximage-len_remove_cell][k] = 1
res[maximage-len_remove_cell] = 1

def __residual1(popu, coe, res):
    """Return the residual res - coe . |popu| minimised by scipy.optimize.leastsq.

    Only the magnitudes of the trial populations enter the model, so the
    fitted parameters must be passed through abs() to obtain populations.
    """
    predicted = np.dot(coe, np.abs(popu))
    return res - predicted
# Solve coe . |popu| = res in the least-squares sense. The residual only sees
# |popu|, so leastsq's raw parameters may come back with arbitrary signs.
# NOTE: without full_output, scipy.optimize.leastsq returns (x, ier); p_cov1
# is the integer status flag, not a covariance matrix.
popu, p_cov1 = scipy.optimize.leastsq(__residual1, popu, args=(coe, res))
# Recover the populations actually fitted. Without this, a negative entry
# would be saved as a "population" and math.log below would raise
# "math domain error".
popu = np.abs(popu)

# Plain (unconstrained) least-squares solution, printed for comparison only.
popu1=linalg.lstsq(coe,res)[0]

print(popu)
print(popu1)

np.savetxt(population_file,popu)

# Free energy of each retained cell: F_i = -kT ln p_i.
# NOTE(review): 0.593 is presumably kT in kcal/mol at ~298 K -- confirm.
print("Fn:")
for i in range(maximage-len_remove_cell):
    F_energy[i]=-1*0.593*math.log(popu[i])
    print("{0} {1}".format(i, F_energy[i]))
np.savetxt(free_file, F_energy)

### begin mean first passage time calculation ###

# Per-cell milestone statistics, summed over the merge_list directories:
#   Nijmat[i][a*maximage+b] : maximage*maximage values read from Nij_<i>.dat
#                             (indexed later as [cell][mile1*maximage+mile2],
#                             presumably transitions mile1 -> mile2 within cell i)
#   Rmat[i][j]              : maximage values read from RTi_<i>.dat
Nijmat=np.zeros(shape=(maximage,maximage*maximage),dtype=float)
Rmat=np.zeros(shape=(maximage,maximage),dtype=float)

for merge_num, trj_dir in enumerate(merge_list):
    # Only the first line containing this tag is used from each file.
    # NOTE(review): substring test -- 'TIME1000000' would also match 'TIME10000000'.
    time_tag='TIME{0}'.format(stop_step[merge_num])
    for i in range(maximage):
        # `with` closes both files even if parsing raises (or if the second
        # open fails after the first succeeded).
        with open("{0}/Nij_{1}.dat".format(trj_dir, i),"r") as f:
            with open("{0}/RTi_{1}.dat".format(trj_dir, i),"r") as fT:
                for rl in f:
                    if time_tag in rl:
                        # drop the first token (presumably the TIME tag itself)
                        tl=rl.split()[1:]
                        for j in range(maximage*maximage):
                            Nijmat[i][j]=Nijmat[i][j]+float(tl[j])
                        break
                for rtl in fT:
                    if time_tag in rtl:
                        tl=rtl.split()[1:]
                        for j in range(maximage):
                            Rmat[i][j]=Rmat[i][j]+float(tl[j])
                        break
# Count the milestones: one for every pair of retained cells a < b for which
# Rmat is non-zero in at least one direction.
kept_cells = [c for c in range(maximage) if c not in remove_cell_array]
num_milestone = 0
for pos, a in enumerate(kept_cells):
    for b in kept_cells[pos+1:]:
        if Rmat[a][b] != 0 or Rmat[b][a] != 0:
            num_milestone += 1

print("number of milestones:")
print(num_milestone)

# Per-milestone bookkeeping:
#   milestone_dic[m]      -> the milestone's two cells as compact indices (removed cells skipped)
#   real_milestone_dic[m] -> the same two cells as original indices
milestone_dic = np.zeros((num_milestone, 2), dtype=int)
real_milestone_dic = np.zeros((num_milestone, 2), dtype=int)
qij = np.zeros((num_milestone, num_milestone))
Nij_ave = np.zeros((num_milestone, num_milestone))
RTi_ave = np.zeros(num_milestone)
Tms_mat = np.zeros(maximage - len_remove_cell)
# The target milestone is eliminated from the MFPT equations, hence "- 1".
qres = np.zeros(num_milestone - 1)
qcoe = np.zeros((num_milestone - 1, num_milestone - 1))
tau_MFPT = np.zeros(num_milestone - 1)
tau_MFPT1 = np.zeros(num_milestone - 1)
main_MFPT = np.zeros(maximage - 1 - len_remove_cell)

# Assign an index to each milestone, in lexicographic order of its cell pair.
# j/k run over compact cell indices (removed cells skipped); j_real/k_real are
# the matching original indices. A pair is a milestone when Rmat is non-zero in
# either direction (the same criterion used to count num_milestone).
target_milestone=-1  # sentinel; replaced when (target_cell1, target_cell2) is found
i=0
j_real=0
k_real=0
for j in range(maximage-len_remove_cell):
    while j_real in remove_cell_array:
        j_real=j_real+1
    k_real=j_real+1
    for k in range(j+1,maximage-len_remove_cell):
        while k_real in remove_cell_array:
            k_real=k_real+1
        if (Rmat[j_real][k_real]!=0) or (Rmat[k_real][j_real]!=0):
            milestone_dic[i][0]=j
            milestone_dic[i][1]=k
            real_milestone_dic[i][0]=j_real
            real_milestone_dic[i][1]=k_real
            if (j_real==target_cell1) and (k_real==target_cell2) :
                print("target_milestone is {0}".format(i))
                target_milestone=i
            i=i+1
        k_real=k_real+1
    j_real=j_real+1
print("milestone_dic:")
print(milestone_dic)
print("real_milestone_dic:")
print(real_milestone_dic)
print("check j_real and k_real {0} == {1}".format(j_real-1, k_real-1))
# Every later step eliminates the target milestone from the linear system;
# stop here with a clear message instead of a NameError further down.
if target_milestone < 0:
    sys.exit("Error: no milestone found between cells {0} and {1}; check "
             "target_cell1/target_cell2 and remove_cell_array".format(target_cell1, target_cell2))
#initiate finish

# Stay time of each retained cell: Tms_mat[i] accumulates Rmat[j][k] over the
# retained cells k, where i is the compact index of original cell j. Tmat[j]
# (read from the count files) is printed alongside only as a cross-check.
kept_cells = [c for c in range(maximage) if c not in remove_cell_array]
for i, j in enumerate(kept_cells):
    for k in kept_cells:
        Tms_mat[i] = Tms_mat[i] + Rmat[j][k]
    print("check: image{0}, {1} compared with {2}".format(i, Tms_mat[i], Tmat[j]))
i = len(kept_cells)
if i != (maximage - len_remove_cell):
    print("check {} != {} Error!".format(i, maximage-len_remove_cell))
# calculate the q matrix
#
# For every milestone i (a pair of cells) accumulate:
#   RTi_ave[i]    : sum over its two cells alpha of
#                   popu[alpha] * Rmat[alpha_real][other_real] / Tms_mat[alpha]
#   Nij_ave[i][j] : for each other milestone j sharing a cell alpha with i,
#                   popu[alpha] * Nijmat[alpha_real][mile1*maximage+mile2] / Tms_mat[alpha],
#                   where mile1 / mile2 are the far cells of i / j (original indices).
# Milestones sharing no cell leave Nij_ave[i][j] at 0.
# Nij_flag counts how many of the four "shared cell" tests matched for (i, j).
Nij_flag=0

for i in range(num_milestone):
    # population-weighted contribution of both cells bordering milestone i
    temp=0
    temp=temp+popu[milestone_dic[i][0]]*(Rmat[real_milestone_dic[i][0]][real_milestone_dic[i][1]]/Tms_mat[milestone_dic[i][0]])
    temp=temp+popu[milestone_dic[i][1]]*(Rmat[real_milestone_dic[i][1]][real_milestone_dic[i][0]]/Tms_mat[milestone_dic[i][1]])
    RTi_ave[i]=RTi_ave[i]+temp
    
    # check RTi matrix
    # right now, not check
    # check RTi finished
    
    for j in range(num_milestone):
        if j == i :
            continue
        # Four ways milestones i and j can share a cell. Since
        # milestone_dic[m][0] < milestone_dic[m][1] for every m, two distinct
        # milestones share at most one cell, so at most one branch should fire.
        # Each branch records the shared cell (alpha / alpha_real) and the far
        # cells mile1 (of i) and mile2 (of j). A compact/real index mismatch
        # prints an error and abandons the remaining j for this milestone i.
        if (milestone_dic[i][0]==milestone_dic[j][1]):
            if (real_milestone_dic[i][0]!=real_milestone_dic[j][1]):
                print "Error in matching real_milestone_dic!"
                break
            Nij_flag=Nij_flag+1
            alpha = milestone_dic[i][0]
            alpha_real = real_milestone_dic[i][0]
            mile1 = real_milestone_dic[i][1]
            mile2 = real_milestone_dic[j][0]
        if (milestone_dic[i][0]==milestone_dic[j][0]):
            if (real_milestone_dic[i][0]!=real_milestone_dic[j][0]):
                print "Error in matching real_milestone_dic!"
                break
            Nij_flag=Nij_flag+1
            alpha = milestone_dic[i][0]
            alpha_real = real_milestone_dic[i][0]
            mile1 = real_milestone_dic[i][1]
            mile2 = real_milestone_dic[j][1]
        if (milestone_dic[i][1]==milestone_dic[j][0]): 
            if (real_milestone_dic[i][1]!=real_milestone_dic[j][0]):
                print "Error in matching real_milestone_dic!"
                break
            Nij_flag=Nij_flag+1
            alpha = milestone_dic[i][1]
            alpha_real = real_milestone_dic[i][1]
            mile1 = real_milestone_dic[i][0]
            mile2 = real_milestone_dic[j][1]
        if (milestone_dic[i][1]==milestone_dic[j][1]):
            if (real_milestone_dic[i][1]!=real_milestone_dic[j][1]):
                print "Error in matching real_milestone_dic!"
                break
            Nij_flag=Nij_flag+1
            alpha = milestone_dic[i][1]
            alpha_real = real_milestone_dic[i][1]
            mile1 = real_milestone_dic[i][0]
            mile2 = real_milestone_dic[j][0]
        # Defensive guard; note Nij_flag is not reset on this path.
        if (Nij_flag > 1):
            print "Error as Nij_flag>1"
            break
        # i and j share no cell: no direct transition between them
        if (Nij_flag == 0):
            continue
        print "mile1:{0}, mile2:{1}".format(mile1,mile2)
        print "alpha={}; alpha_real={}".format(alpha, alpha_real)
        Nij_flag=0
        Nij_0_flag=0
        temp=0
        # A zero entry only triggers a warning and contributes nothing.
        if Nijmat[alpha_real][mile1*maximage+mile2] == 0:
            Nij_0_flag = 1
            print "warning: empty transition between {0} and {1} in image {2} found".format(mile1,mile2,alpha_real)
        else :
            temp=temp+popu[alpha]*(Nijmat[alpha_real][mile1*maximage+mile2]/Tms_mat[alpha])
        Nij_ave[i][j]=Nij_ave[i][j]+temp
        # no-op: this is already the last statement of the loop body
        if Nij_0_flag == 1 :
            continue


# Milestone rate matrix: off-diagonal q_ij = Nij_ave[i][j] / RTi_ave[i], and
# each diagonal entry is minus the sum of its row's off-diagonal entries, so
# every row sums to zero (up to rounding).
for row in range(num_milestone):
    diag = 0
    for col in range(num_milestone):
        if col == row:
            continue
        qij[row][col] = Nij_ave[row][col]/RTi_ave[row]
        diag = diag - qij[row][col]
    qij[row][row] = diag


# Coefficient matrix for the MFPT equations: qij with the target milestone's
# row and column removed; the right-hand side is -1 for every remaining milestone.
others = [m for m in range(num_milestone) if m != target_milestone]
qcoe[:, :] = qij[np.ix_(others, others)]
qres[:] = -1


# solve equations:
def __residual2(tau_MFPT, qcoe, qres):
    """Return the residual qres - qcoe . |tau_MFPT| minimised by scipy.optimize.leastsq.

    Only the magnitudes of the trial MFPTs enter the model, so the fitted
    parameters must be passed through abs() to obtain the MFPTs.
    """
    model = np.dot(qcoe, np.abs(tau_MFPT))
    return qres - model
# Solve qcoe . |tau| = qres by least squares. As for the populations, the
# residual only sees |tau|, so take abs() of leastsq's raw parameters to get
# the MFPTs it actually fitted (otherwise negative entries can leak into
# main_MFPT and the printed check below).
# NOTE: without full_output, leastsq returns (x, ier); p_cov2 is the status flag.
tau_MFPT, p_cov2 = scipy.optimize.leastsq(__residual2, tau_MFPT, args=(qcoe, qres))
tau_MFPT = np.abs(tau_MFPT)
# Unconstrained least-squares solution, kept for comparison only.
tau_MFPT1 = linalg.lstsq(qcoe, qres)[0]


# Sanity check: adding up the columns of qij yields each row sum of qij, which
# should be ~0 by construction of the diagonal (kept in `test`, not printed).
test = np.zeros(num_milestone)
for column in qij.T:
    test = test + column
print("qcoe dot tau_MFPT:")
print(np.dot(qcoe, tau_MFPT))


# Extract the MFPTs of the "main" milestones, i.e. those between consecutive
# cells (indexk, indexk+1), walking the milestones in index order.
#   k      : next free slot in main_MFPT
#   j      : position in tau_MFPT, which has no entry for the target milestone
#   indexk : lower (original) cell index of the main milestone looked for next
k=0
j=0
indexk=0
for i in range(num_milestone) :
    # The target has no entry in tau_MFPT, so neither j nor k advances; indexk
    # still moves past it (assumes the target is the main milestone expected
    # at this point -- TODO confirm for other target choices).
    if i==target_milestone: 
        indexk=indexk+1
        continue
    print real_milestone_dic[i]
    print indexk
    # Skip over removed cells.
    # NOTE(review): for a removed cell c, indexk ends up at c+1, so a milestone
    # bridging the gap, (c-1, c+1), is never recorded here -- confirm intended.
    while indexk+1 in remove_cell_array:
        indexk=indexk+1
    if (indexk in remove_cell_array) and not((indexk+1) in remove_cell_array):
        indexk=indexk+1
    if (real_milestone_dic[i][0]==indexk) and (real_milestone_dic[i][1]==indexk+1) :
        main_MFPT[k]=tau_MFPT[j]
        print "main milestone {0},{1} (original {2},{3}) 's MFPT is {4}".format(indexk, indexk+1, real_milestone_dic[i][0], real_milestone_dic[i][1], main_MFPT[k])
        k=k+1
        indexk=indexk+1
    j=j+1
print k
# Append the wrap-around milestone between the first cell (0) and the last
# cell (maximage-1), if that last cell is retained and such a milestone exists,
# then write main_MFPT to disk.
for i in range(num_milestone) :
    if not ((real_milestone_dic[i][0]==0) and (real_milestone_dic[i][1]==maximage-1) and not((maximage-1) in remove_cell_array)) :
        continue
    if i == target_milestone:
        # the target has no unknown in the reduced system (no tau_MFPT entry),
        # mirroring how the consecutive-milestone loop skips it
        continue
    # tau_MFPT has the target milestone's entry removed, so milestones listed
    # after the target sit one position earlier in it.
    tau_index = i if i < target_milestone else i-1
    main_MFPT[k]=tau_MFPT[tau_index]
    print("main milestone {0},{1} (original {2},{3}) 's MFPT is {4}".format(milestone_dic[i][0], milestone_dic[i][1], real_milestone_dic[i][0], real_milestone_dic[i][1], main_MFPT[k]))
    k=k+1

# k counts the stored main milestones; +1 accounts for the target itself.
print("number of total main milestones: {0}, compare with {1}".format(k+1, maximage-len_remove_cell))
np.savetxt(main_MFPT_file,main_MFPT)

