import sys
import readTop
import geom
import top_db

def is_idx(term):
    """Return True if *term* can be read as a non-negative integer index.

    Only the ASCII digits 0-9 and the space character are accepted, and at
    least one digit must be present.  The empty string, a string made only
    of spaces, and anything containing a sign, letter or other character
    give False.

    Callers in this module pass the result straight on to ``int()``; a
    blank field therefore has to be rejected here, because ``int('   ')``
    raises ValueError.
    """
    digits = set('0123456789')
    has_digit = False
    for ch in term:
        if ch in digits:
            has_digit = True
        elif ch != ' ':
            return False
    return has_digit

#all_system: contains everything (pro, sol, lip, ...)
#sys_list: list of pairs [object for the subsystem, [list of resnames in this subsystem]]
#residues with a matching resname are appended to the corresponding subsystem
def seperate_system(all_system, sys_list):
    """Distribute the residues of *all_system* over the subsystems in *sys_list*.

    all_system -- object exposing ``reslist`` ([start, length, resname]
                  entries) and ``atoms`` (list of atom records).
    sys_list   -- list of pairs ``[subsystem, resnames]``; ``subsystem``
                  exposes ``atoms`` plus ``build_atomnms``, ``res_list`` and
                  ``build_xp``; ``resnames`` is a list of residue names.

    Every residue whose name is listed for a subsystem has its atoms
    appended to that subsystem.  For an unknown residue name the user is
    asked on stdin which subsystem it belongs to; the name is then added to
    that subsystem's ``resnames`` list (the list is modified in place).  Any
    answer that is not a valid index discards that residue name for the
    rest of the run without asking again.

    Finally the lookup tables of every non-empty subsystem are rebuilt.
    Returns None.
    """
    discarded = set()
    for res in all_system.reslist:
        start, length, resnm = res[0], res[1], res[2]

        target = None
        for subsys in sys_list:
            if resnm in subsys[1]:
                target = subsys
                break

        if target is None:
            if resnm in discarded:
                continue
            target = _ask_subsystem(resnm, all_system.atoms[start][4], sys_list)
            if target is None:
                print('%s disgarded' % (resnm,))
                discarded.add(resnm)
                continue
            # remember the choice so the same resname is not asked again
            target[1].append(resnm)

        for j in range(start, start + length):
            target[0].atoms.append(all_system.atoms[j])

    for subsys in sys_list:
        if len(subsys[0].atoms) > 0:
            subsys[0].build_atomnms()
            subsys[0].res_list()
            subsys[0].build_xp()


def _ask_subsystem(resnm, resid, sys_list):
    """Ask on stdin which entry of *sys_list* should receive *resnm*.

    Returns the chosen ``[subsystem, resnames]`` pair, or None when the
    answer is empty, not a plain non-negative integer, or out of range.
    """
    print('%s %s is not chosen for any class of segments' % (resnm, resid))
    print('We have following groups of residue names for different classes:')
    for kk, k in enumerate(sys_list):
        print('%s : %s' % (kk, k[1]))
    print('Please choose one or this type of residue is all disgarded')

    # strip() instead of dropping the last character: the answer may arrive
    # without a trailing newline (EOF), and a blank answer must not reach int()
    answer = sys.stdin.readline().strip()
    try:
        opt = int(answer) if is_idx(answer) else -1
    except ValueError:
        opt = -1
    if opt < 0 or opt >= len(sys_list):
        return None
    return sys_list[opt]

#def show_temp_atom(x,anm):
# print_pdb_line(99999, anm, 'POPC', 'O',\
#                        3, x[0],x[1],x[2],'O66')

def print_pdb_atom(atom):
    """Print one atom record (9-field sequence) as a PDB ATOM line."""
    aid, anm, rnm, cnm, rid = atom[0], atom[1], atom[2], atom[3], atom[4]
    x, y, z, snm = atom[5], atom[6], atom[7], atom[8]
    print_pdb_line(aid, anm, rnm, cnm, rid, x, y, z, snm)

def pdb_atom_string(atom):
    """Return the PDB ATOM line for one atom record (9-field sequence)."""
    aid, anm, rnm, cnm, rid = atom[0], atom[1], atom[2], atom[3], atom[4]
    x, y, z, snm = atom[5], atom[6], atom[7], atom[8]
    return pdb_string(aid, anm, rnm, cnm, rid, x, y, z, snm)

def pdb_string(aid, anm, rnm, cnm, rid, x, y, z, snm):
    """Format one atom as a fixed-width PDB ATOM line (no trailing newline).

    aid -- atom serial; values above 99999 are written as 99999
    anm -- atom name; names shorter than 3 characters get one trailing space
    rnm -- residue name; names shorter than 4 characters get one trailing space
    cnm -- chain name
    rid -- residue id
    x, y, z -- coordinates, written with three decimals
    snm -- segment name
    """
    atom_name = anm if len(anm) >= 3 else anm + ' '
    res_name = rnm if len(rnm) >= 4 else rnm + ' '
    serial = 99999 if aid > 99999 else aid

    head = "ATOM  %5d%5s %4s%1s" % (serial, atom_name, res_name, cnm)
    body = '%4d    %8.3f%8.3f%8.3f' % (rid, x, y, z)
    tail = '%18s%3s' % (' ', snm)
    return ''.join((head, body, tail))

def print_pdb_line(aid, anm, rnm, cnm, rid, x, y, z, snm):
    """Print one atom to stdout as a fixed-width PDB ATOM line.

    The arguments are atom serial, atom name, residue name, chain name,
    residue id, the three coordinates and the segment name.  The text is
    produced by pdb_string(), so printed lines and lines written to files
    always share the same layout.
    """
    print(pdb_string(aid, anm, rnm, cnm, rid, x, y, z, snm))

def n_res(atoms, st_idx, end_idx):
    """Count the residues among atoms[st_idx:end_idx] (end_idx excluded).

    A new residue starts whenever the residue id (field 4) or the residue
    name (field 2) differs from that of the previous atom.
    """
    count = 0
    previous = (-1, 'none')
    for idx in range(st_idx, end_idx):
        record = atoms[idx]
        current = (record[4], record[2])
        if current != previous:
            count += 1
            previous = current
    return count

def pdb_line(rl,autoAssign=False, ass_idx=0):
    """
    Read a line and parse it into PDB components
    [0]: atom number
    [1]: atom name
    [2]: residue name
    [3]: chain name
    [4]: resid
    [5-7]: x,y,z
    [8]: segname

    rl         -- one line of a PDB file
    autoAssign -- when true the serial written in the line is ignored and
                  ass_idx is stored as the atom number instead
    ass_idx    -- atom number to store when autoAssign is on

    Returns the 9-element list described above, or the string 'None' (not
    the None object) when the line does not start with 'ATOM'.
    Raises ValueError when the resid or a coordinate field is not numeric.
    """
    if rl[0:4] != 'ATOM':
        return 'None'

    if autoAssign:
        serial = ass_idx
    else:
        # A serial that is blank or not purely numeric (e.g. '*****') is
        # replaced by the placeholder 99999 instead of crashing in int().
        field = rl[6:11].strip(' ')
        if field and all(ch in '0123456789' for ch in field):
            serial = int(field)
        else:
            serial = 99999

    return [
        serial,
        rl[12:17].strip(),
        rl[17:21].strip(),
        rl[21:22],
        int(rl[22:26]),
        float(rl[30:38]),
        float(rl[38:46]),
        float(rl[46:54]),
        # columns 73-76; slicing copes with short lines, so a 1-2 character
        # segname on an unpadded line is kept rather than dropped
        rl[72:76].strip(),
    ]

def load_pdb_xyz(atom):
    """Return a new list [x, y, z] taken from fields 5-7 of an atom record."""
    return [atom[5], atom[6], atom[7]]

#load both atom nms and coords in certain range
def load_x_list_range(atoms, st_idx, en_idx):
    """Collect coordinates and atom names for atoms[st_idx:en_idx].

    Returns a pair (coords, names): coords is a list of fresh [x, y, z]
    lists, names the matching list of atom names (field 1).
    """
    selected = [atoms[k] for k in range(st_idx, en_idx)]
    coords = [[rec[5], rec[6], rec[7]] for rec in selected]
    names = [rec[1] for rec in selected]
    return coords, names

#compute the position of a new atom from
#the atom info in atoms
#only atoms in range(st_idx, en_idx) are searched
#returns one coordinate
#add_hv_nmlist ['A','B','C']: the three atoms named 'A', 'B' and 'C'
#in range(st_idx, en_idx) of atoms are used to
#construct the new atom
def cal_new_atom(atoms, st_idx, en_idx,
                 add_hv_nmlist, add_hv_b, add_hv_a, add_hv_d):
    """Build the coordinate of a new atom from three named reference atoms.

    The atoms named in add_hv_nmlist are looked up by name among
    atoms[st_idx:en_idx] (first match wins; ValueError if a name is
    missing).  The first three of them, together with add_hv_b, add_hv_a
    and add_hv_d, are handed to geom.calxyz, whose result is returned.
    """
    coords, names = load_x_list_range(atoms, st_idx, en_idx)
    refs = [coords[names.index(nm)] for nm in add_hv_nmlist]
    first, second, third = refs[0], refs[1], refs[2]
    return geom.calxyz(first, second, third,
                       add_hv_b, add_hv_a, add_hv_d)

#add H atoms
#return list of coords
def add_new_H(atoms, st_idx, en_idx,
              add_h_type, add_h_nmlist):
    """Return a list of coordinates for hydrogens to be added.

    The atoms named in add_h_nmlist are looked up by name among
    atoms[st_idx:en_idx] (ValueError if a name is missing).

    add_h_type == 1 -- one coordinate from geom.add1H_pyra, using the first
                       four reference atoms
    add_h_type == 2 -- two coordinates from geom.calxyz, using the first
                       three reference atoms with the parameters
                       (1.0, 110.0, -120.0) and (1.0, 110.0, 120.0)
    any other value -- an empty list
    """
    coords, names = load_x_list_range(atoms, st_idx, en_idx)
    refs = [coords[names.index(nm)] for nm in add_h_nmlist]

    if add_h_type == 1:
        return [geom.add1H_pyra(refs[0], refs[1], refs[2], refs[3])]
    if add_h_type == 2:
        minus = geom.calxyz(refs[0], refs[1], refs[2],
                            1.0, 110.0, -120.0)
        plus = geom.calxyz(refs[0], refs[1], refs[2],
                           1.0, 110.0, 120.0)
        return [minus, plus]
    return []
#==========================
#load atoms, including structural info and coordinates,
#from the file named fnm
#return a list of atom records, each holding all info for one atom
def load_PDB_atoms(fnm):
    """Read the ATOM and HETATM records of the PDB file *fnm*.

    Reading stops at the first record whose name starts with 'END' (this
    includes 'ENDMDL').  Atom numbers are assigned consecutively from 0 in
    file order; the serials written in the file are ignored.

    Returns a list of atom records as produced by pdb_line().
    """
    atoms = []
    with open(fnm, "r") as f:
        for rl in f:
            # Test the record name at the start of the line only, so that a
            # REMARK or a residue/segment name containing 'END' or 'ATOM'
            # is not mistaken for a record of that type.
            if rl.startswith('END'):
                break
            if rl.startswith('HETATM'):
                # pdb_line() only accepts 'ATOM' records; HETATM records are
                # parsed with the same column slices.
                rl = 'ATOM  ' + rl[6:]
            elif not rl.startswith('ATOM'):
                continue
            atoms.append(pdb_line(rl, True, len(atoms)))
    return atoms


def load_gaussian_atoms(fnm):
    """Read atom names and coordinates from the Gaussian output file *fnm*.

    The parser skips to the second line containing 'N-N', then joins all
    following lines (whitespace-stripped) up to and including the first
    line containing '@'.  In that text, with every backslash replaced by
    '|', the section between the third and the fourth '||' is split on '|';
    its first entry is skipped and every further entry is read as
    'name,x,y,z'.

    Returns a list of atom records
    [idx, name, 'NON', 'X', 1, x, y, z, 'X'] with idx counting from 0.
    Prints a message and calls sys.exit() when the expected markers are
    not found.
    """
    with open(fnm, 'r') as f:
        n_marker = 0
        for rl in f:
            if 'N-N' in rl:
                n_marker += 1
                if n_marker > 1:
                    break
        if n_marker <= 1:
            print('Wrong Gaussian output format')
            sys.exit()

        # strip() rather than rl[:-1]: a final line without a newline must
        # not lose its last character
        pieces = []
        for rl in f:
            pieces.append(rl.strip())
            if '@' in rl:
                break

    rtxt = ''.join(pieces).replace('\\', '|')

    # locate the 3rd and 4th double separator
    st = -1
    en = -1
    n_sep = 0
    for i in range(len(rtxt) - 1):
        if rtxt[i:i + 2] == '||':
            n_sep += 1
            if n_sep == 3:
                st = i + 2
            elif n_sep == 4:
                en = i
                break

    if st < 0 or en < 0:
        print('Wrong Guassian output format')
        sys.exit()

    atoms = []
    # the first entry of the section is not an atom and is skipped
    for entry in rtxt[st:en].split('|')[1:]:
        ss = entry.split(',')
        atoms.append([len(atoms), ss[0], 'NON', 'X', 1,
                      float(ss[1]), float(ss[2]), float(ss[3]), 'X'])
    return atoms

class PDBdata:
 def __init__(self, fnm=None, gauss=False):
  if fnm is None:
   self.atoms=[]
   return
  if gauss:
   self.atoms=load_gaussian_atoms(fnm)
  else:
   self.atoms=load_PDB_atoms(fnm)
  self.build_atomnms()

  self.res_list()
  self.build_xp()

 def show_pdb(self,fnm=None,exclu=[]):
  atoms=self.atoms
  if fnm!=None:
   f=open(fnm,'w')
  for i in atoms:
   if not i[1][0] in exclu:
    if fnm!=None:
     f.write(pdb_atom_string(i)+'\n') 
    else:
     print_pdb_atom(i)
  
  if fnm!=None:
   f.write('END\n')
   f.close()
  else:
   print 'END'

 def atoms_update(self, natoms):
  del self.reslist
  del self.xp
  del self.atomnms
  del self.atoms
  self.atoms=natoms
  
  self.atomnms=[]
  for i in self.atoms:
   self.atomnms.append(i[1])
  self.res_list()
  self.build_xp()


 def center(self):
  t=[0.,0.,0.]
  for x in self.xp:
   for dim,xi in enumerate(x):
    t[dim]+=  xi
  n=len(self.xp)
  for dim in range(3):
   t[dim]/=float(n)

  del self.xp
  for i in self.atoms:
   i[5]-=t[0]
   i[6]-=t[1]
   i[7]-=t[2]
  self.build_xp()


 def update_conn(self):
  readTop.construct_conn(self, 'pdb', self.atoms)
  #print self.conn
  readTop.construct_atomSig(self)
  readTop.sort_sig(self.atomSigs)
  readTop.construct_uniq(self)


 def build_atomnms(self):
  self.atomnms=[]
  for i in self.atoms:
   self.atomnms.append(i[1])

 def build_xp(self):
  self.xp=[]
  for i in self.atoms:
   self.xp.append(load_pdb_xyz(i))

 def integrity_check(self):
  readTop.construct_conn(self,'pdb',self.atoms)
  readTop.construct_atomSig(self,None,False)
  if len(self.atomSigs[0])<len(self.atomnms)*3:
   print 'Warning! Some part of molecule is disconnected from the other.'

#construct a residue list[[st_idx, len, name]]
 def res_list(self):
  self.reslist=[]
  atoms=self.atoms
  l_idx=-1
  for i,ai in enumerate(atoms):
   if ai[4]!=l_idx:
    if l_idx!=-1:
     self.reslist.append([st_idx, res_len, atoms[st_idx][2]])
    l_idx=ai[4]
    st_idx=i
    res_len=0
   res_len+=1
  self.reslist.append([st_idx, res_len, atoms[st_idx][2]])

#assuming xp is already defined
#this is only by reference
 def extract(self,list_nm):
  ext=[]
  for i in list_nm:
   i_idx= self.atomnms.index(i)
   ext.append(self.xp[i_idx])

  return ext

 def ext_xp_byidx(self, st, en):
  if st<0 or en>len(self.atoms):
   print 'Wrong idx for selecting coordinates'
   sys.exit()
  xp=self.xp
  t_xp=[]
  for i in range(st, en):
   t_xp.append([xp[i][0], xp[i][1], xp[i][2]])

  return t_xp

 def extract_atompair_byres(self, list_pair):
  ext=[]
  atomnms=self.atomnms
  for i in self.reslist:
   t_nm=[]
   for j in range(i[0],i[0]+i[1]):
    t_nm.append(atomnms[j])
   for k in list_pair:
    if k[0] in t_nm and k[1] in t_nm:
     k0 = t_nm.index(k[0])+ i[0]
     k1 = t_nm.index(k[1])+ i[0]
     ext.append([[k0, self.xp[k0]],[k1, self.xp[k1]]])
  return ext

 def extract_bypart(self, list_nm, st_idx, seglen):
  ext=[]
  t_list=[self.atomnms[i] for i in range(st_idx, st_idx+seglen)]
#  print t_list
#  print list_nm

  for i in list_nm:
   if not i in t_list: continue
   i_idx=t_list.index(i) + st_idx
   ext.append(self.xp[i_idx])
  return ext

#return number of differet chains
#and staring internal idx for each chain
 def get_chains(self,verbose=True):
  c_st=[]
  n_c=0
  l_c='$$'
  for i in self.reslist:
   st_idx=i[0]
   c_c = self.atoms[st_idx][3]
   #print c_c
   if l_c!=c_c:
    
    l_c=c_c
    n_c+=1
    if verbose:
     print 'Chain',n_c,'starts at',\
     self.atoms[st_idx][8]+':'+self.atoms[st_idx][2]+str(self.atoms[st_idx][4])
    c_st.append(st_idx)
  c_st.append(len(self.atoms))
  return c_st

#get both chain name and chain st_idx
#return a dict:  'chain name': [st_idx, end_idx, st_residx, end_residx]
 def get_chain_info(self):
  c_st=self.get_chains(False)
  idx_serie=[]
  chain_name=[]
  res_id=0
  for i in range(len(c_st)-1):
   e_id=res_id + n_res(self.atoms, c_st[i], c_st[i+1])
   idx_serie.append([c_st[i], c_st[i+1],res_id, e_id])
   res_id = e_id
   chain_name.append(self.atoms[c_st[i]][3])
  return dict([(i,j) for i,j in zip(chain_name, idx_serie)])
   
 def get_segments(self,verbose=True):
  c_st=[]
  n_c=0
  l_c='$$'
  for i in self.reslist:
   st_idx=i[0]
   c_c = self.atoms[st_idx][8]
   #print c_c
   if l_c!=c_c:
    
    l_c=c_c
    n_c+=1
    if verbose:
     print 'Chain',n_c,'starts at',\
     self.atoms[st_idx][8]+':'+self.atoms[st_idx][2]+str(self.atoms[st_idx][4])
    c_st.append(st_idx)
  c_st.append(len(self.atoms))
  return c_st

 def build_struct(self):
  if len(self.atoms)==0: return
  self.build_atomnms()
  self.res_list()
  self.build_xp()

 def tcharge(self):
  ch_res=top_db.charge_res_db()
  reslist = self.reslist
  tc=0
  for i in reslist:
   if i[2] in ch_res:
    tc+=ch_res[i[2]]
  return tc

 def volume(self):
  box = self.box()
  return box[0]*box[1]*box[2]

 def box(self):
  xp=self.xp
  max, min = geom.find_max_min(xp)
  box_c=[]
  for i,j in zip(max, min):
   box_c.append(i-j)
  return box_c

#end of PDBdata class

#CGing AA system (only with small molecules) to CG
#return a obj of PDBdata class
def coarse_graining_system_small(s_sys,seg_init='S',database=top_db.aa2cg_db()):
    """Coarse-grain an all-atom system of small molecules.

    s_sys    -- source system; its ``reslist`` is walked residue by residue
    seg_init -- first character(s) of every generated segment name
    database -- pair (aa2cg, aa2cg_list): aa2cg_list holds the known
                residue names, aa2cg the mapping rule at the same index

    Residues whose name is not in aa2cg_list are skipped with a warning.
    Every other residue is converted by coarse_graining_residue() and
    numbered consecutively from 1.  After 9999 residues the numbering
    restarts at 1 and the segment counter advances; the segment name is
    seg_init followed by readTop.cod2char of the counter's quotient and
    remainder on division by 36.

    Returns a new PDBdata object holding the coarse-grained atoms.
    """
    sys_cg = PDBdata()
    aa2cg, aa2cg_list = database
    seg_no = 0
    # incremented before use, so the first converted residue gets resid 1,
    # the same as the first residue after every wrap-around
    resid = 0

    for res in s_sys.reslist:
        st, length, resnm = res[0], res[1], res[2]
        if resnm not in aa2cg_list:
            print('Warning! %s starting at  %s is disgarded' % (resnm, st))
            continue

        resid += 1
        if resid > 9999:
            resid = 1
            seg_no += 1

        high, low = divmod(seg_no, 36)
        segnm = seg_init + readTop.cod2char(high) + readTop.cod2char(low)

        aa2cg_rule = aa2cg[aa2cg_list.index(resnm)]
        coarse_graining_residue(s_sys, st, length, sys_cg,
                                aa2cg_rule, resid, segnm)

    if len(sys_cg.atoms) > 0:
        sys_cg.build_atomnms()
        sys_cg.res_list()
        sys_cg.build_xp()

    return sys_cg


#CG AA to CG residue according to rule in aa2cg_rule
def  coarse_graining_residue(s_sys,st,len_s,sys_cg,\
                             aa2cg_rule,resid,segnm,chain='S'):
    """Convert one all-atom residue into coarse-grained atoms.

    s_sys      -- source system; atoms st .. st+len_s-1 form the residue
    sys_cg     -- target system; new atom records are appended to its
                  ``atoms`` list
    aa2cg_rule -- rule whose element 1 is the CG residue name and element 2
                  a list of [CG atom name, list of source atom names]
    resid, segnm, chain -- stored unchanged in every new atom record

    Each CG atom is placed at geom.centroid of the source atoms found by
    s_sys.extract_bypart.  A CG atom for which no source atom is found is
    skipped with a warning.  Atom numbers continue from the current length
    of sys_cg.atoms.
    """
    bead_resnm, bead_rules = aa2cg_rule[1], aa2cg_rule[2]
    serial = len(sys_cg.atoms)
    for bead in bead_rules:
        bead_nm = bead[0]
        members = s_sys.extract_bypart(bead[1], st, len_s)

        if len(members) == 0:
            print(' '.join(['Warning!', str(s_sys.atoms[st][2]),
                            'has no atoms for', str(bead_nm),
                            'in', str(bead_resnm)]))
            continue

        serial += 1
        centre = geom.centroid(members)
        record = [serial, bead_nm, bead_resnm, chain, resid,
                  centre[0], centre[1], centre[2], segnm]
        sys_cg.atoms.append(record)
