import sys

import top_db
import connectivity

class PSFdata:
    """Topology parsed from a PSF file.

    Each data line is handed to the handler that ``top_db.load_psf_rule``
    registers for the current section (section header lines carry a tag such
    as '!NATOM', '!NNB' or '!NGRP').  After parsing, the NNB section is
    unpacked, implicit 1-4 pairs are added, pairs that coincide with 1-2/1-3
    neighbours or explicit exclusions are dropped, and ``exclusionList`` is
    built.

    All atom indices stored on the instance are zero-based.
    """

    def __init__(self, fnm, modpairff=None, nnbmodpair=False):
        """Load and post-process the PSF file *fnm*.

        Args:
            fnm: path of the PSF file to read.
            modpairff: pair parameter table kept as ``self.modpff``
                (presumably handed to ``load_modp`` as ``params[0]`` by the
                top_db rules - TODO confirm).  ``None`` (the default) means an
                empty table; a fresh list is created per instance so that no
                two objects share a mutable default.
            nnbmodpair: if true, NNB entries are stored as pairs (with empty
                parameters); otherwise they are stored as exclusions.

        Exits the process with status 1 on malformed input.
        """
        # Per-atom records: [segnm, resid, resnm, atomnm, atomtype, charge, mass].
        self.atoms = []
        # Bonded terms and their parameter lists; presumably filled through
        # load_tuple by the top_db rules (zero-based index lists).
        self.bonds = []
        self.bond_params = []
        self.angles = []
        self.angle_params = []
        self.dihes = []
        self.dihe_params = []
        self.imprs = []
        self.impr_params = []
        # Non-bonded pairs (zero-based) and their parallel parameter list.
        self.pairs = []
        self.pair_params = []
        # Atom names in file order.
        self.__anm = []
        # Explicit exclusions (zero-based index pairs).
        self.exclusions = []
        # Atom types in atom order.
        self.typ = []
        self.modpff = [] if modpairff is None else modpairff
        self.nnbmodpair = nnbmodpair

        # Raw NNB integers, still 1-based, accumulated by load_nnb.
        self.__pairs = []
        self.__pair_params = []

        self._read_sections(fnm)
        self._apply_nnb(nnbmodpair)

        conn = connectivity.connect(None, False, None, None, self)
        self._add_implicit_14(conn)
        self._drop_ring_pairs(conn)
        self._build_exclusion_list(conn)

    def _read_sections(self, fnm):
        """Dispatch every data line of *fnm* to the current section handler."""
        # Maps a section tag to a (handler, handler_args) pair; a handler of
        # None means the section's data lines are ignored.
        rules = top_db.load_psf_rule(self)
        current = None
        with open(fnm, 'r') as f:
            for raw in f:
                # Strip only a real newline so a final line without one keeps
                # its last character.
                line = raw[:-1] if raw.endswith('\n') else raw
                tokens = line.split()
                if len(tokens) > 1:
                    # Header lines carry the tag as the second token
                    # (e.g. '  123 !NATOM'), matched on its first 5 chars.
                    if tokens[1][:5] in rules:
                        current = rules[tokens[1][:5]]
                        continue
                    # The '!NGRP' header has the tag as its third token.
                    if len(tokens) > 2 and tokens[2] in rules:
                        current = rules[tokens[2]]
                        continue
                # Drop any trailing '!' comment from a data line.
                if '!' in line:
                    line = line[:line.index('!')]
                if not line.split() or current is None:
                    continue
                if current[0] is not None:
                    current[0](line, current[1])

    def _apply_nnb(self, nnbmodpair):
        """Unpack the NNB data into self.pairs (nnbmodpair) or self.exclusions.

        The NNB integers are two packed arrays: ``nnbnum`` 1-based partner
        atoms followed by one cumulative end offset per atom, so the last
        value equals ``nnbnum``.
        """
        packed = self.__pairs
        if not packed:
            # No NNB data was read: there are no explicit entries to add.
            return
        nnbnum = packed[-1]
        if len(packed) != len(self.atoms) + nnbnum:
            print('Incorrect NNB section')
            sys.exit(1)
        partners = packed[:nnbnum]
        ends = packed[nnbnum:]

        if nnbmodpair:
            dest, dest_params = self.pairs, self.pair_params
        else:
            dest, dest_params = self.exclusions, None

        start = 0
        for atom_idx, end in enumerate(ends):
            for j in range(start, end):
                # atom_idx is zero-based, partners[] is 1-based.
                dest.append([atom_idx, partners[j] - 1])
                if dest_params is not None:
                    dest_params.append([])
            start = end

    def _add_implicit_14(self, conn):
        """Append each end-to-end pair of conn.c14 not already in self.pairs."""
        known = set(tuple(p) for p in self.pairs)
        for path in conn.c14:
            a, b = path[0], path[-1]
            if (a, b) in known or (b, a) in known:
                continue
            lo, hi = min(a, b), max(a, b)
            self.pairs.append([lo, hi])
            self.pair_params.append([])
            known.add((lo, hi))

    def _drop_ring_pairs(self, conn):
        """Remove every pair whose ends also appear in c12, c13 or exclusions.

        In small rings an atom can be both a 1-4 and a 1-2/1-3 neighbour; such
        pairs must not survive.  All matching pairs are removed, in either
        orientation and including duplicates.
        """
        banned = set()
        for entry in conn.c12 + conn.c13 + self.exclusions:
            banned.add((entry[0], entry[-1]))
            banned.add((entry[-1], entry[0]))
        keep = [k for k, p in enumerate(self.pairs) if tuple(p) not in banned]
        # Slice assignment keeps the list objects other code may reference.
        self.pairs[:] = [self.pairs[k] for k in keep]
        self.pair_params[:] = [self.pair_params[k] for k in keep]

    def _build_exclusion_list(self, conn):
        """Build self.exclusionList.

        For each atom index a, store the sorted unique partners b >= a taken
        from c12, c13, exclusions and the remaining pairs.
        """
        partners = [set() for _ in self.atoms]
        for entry in conn.c12 + conn.c13 + self.exclusions + self.pairs:
            a, b = entry[0], entry[-1]
            partners[min(a, b)].add(max(a, b))
        self.exclusionList = [sorted(s) for s in partners]

    def load_modp(self, rl, params):
        """Handle one modified-pair line: '<i> <j> [<param index>]' (1-based).

        params[0] is the pair parameter table.  When it is empty (or None) the
        pair gets empty parameters; otherwise the first two entries of row
        ``<param index>`` are stored.
        """
        ff = params[0]
        srl = rl.split()
        if not srl:
            return
        self.pairs.append([int(srl[0]) - 1, int(srl[1]) - 1])
        if not ff:
            self.pair_params.append([])
        else:
            self.pair_params.append(ff[int(srl[2]) - 1][:2])

    def load_nnb(self, rl, params):
        """Accumulate the raw (1-based) NNB integers of one line."""
        self.__pairs.extend(int(tok) for tok in rl.split())

    def unique_anm(self):
        """Make atom names unique by appending '_<n>' (n counts per name)."""
        counts = {}
        for atom in self.atoms:
            name = atom[3]
            counts[name] = counts.get(name, 0) + 1
            atom[3] = name + '_' + str(counts[name])

    def get_nmtuple(self, idxtuple):
        """Return the atom names for the (zero-based) indices in *idxtuple*."""
        return [self.atoms[int(str(i))][3] for i in idxtuple]

    def load_atom(self, rl, params):
        """Append one atom line to self.atoms, self.__anm and self.typ."""
        # Record layout: 0 segnm, 1 resid, 2 resnm, 3 atomnm, 4 atomtype,
        # 5 charge, 6 mass (file column 0 is skipped).
        srl = rl.split()
        self.atoms.append([
            srl[1], int(srl[2]), srl[3],
            srl[4], srl[5], float(srl[6]), float(srl[7]),
        ])
        self.__anm.append(srl[4])
        self.typ.append(srl[5])

    def load_tuple(self, rl, params):
        """Split a line of 1-based indices into zero-based index tuples.

        params is (tuple_length, destination_list, parameter_list); one empty
        parameter entry is appended per tuple.  Exits with status 1 when the
        number of indices is not a multiple of tuple_length.
        """
        tuple_l, typtuple, par = params[0], params[1], params[2]
        srl = rl.split()
        sl = len(srl)
        if sl % tuple_l != 0:
            print('Err in reading: %s' % rl)
            sys.exit(1)
        for start in range(0, sl, tuple_l):
            typtuple.append([int(tok) - 1 for tok in srl[start:start + tuple_l]])
            par.append([])

    def gettypetuple(self, fi):
        """Return the atom types for the (zero-based) indices in *fi*."""
        return [self.typ[i] for i in fi]


class PSFtypeidx(PSFdata):
    """PSFdata whose atom types are tagged with a per-residue suffix.

    After the PSF is loaded, every atom type gets a suffix 'xA', 'xB', ...
    that advances each time the residue id (atom record field 1) changes.
    The suffixes cycle through 2*nforward+1 values.  Both self.atoms[i][4]
    and self.typ are updated.
    """

    def __init__(self, fnm, nforward=4):
        self.__nf = nforward
        PSFdata.__init__(self, fnm, [])
        self.__indextype()

    def __indextype(self):
        """Append the cycling per-residue suffix to each atom type."""
        nforward = self.__nf
        # NOTE(review): for nforward > 12 the suffix letters run past 'Z'.
        suffex = ['x' + chr(ord('A') + i) for i in range(2 * nforward + 1)]

        self.typ = []
        # None can never equal a parsed (int) resid, so the first atom always
        # starts a new residue and receives suffex[0].  The previous sentinel
        # -1 collided with a genuine resid of -1.
        last_res = None
        suf_idx = -1
        for atom in self.atoms:
            c_res = atom[1]
            if c_res != last_res:
                suf_idx += 1
                if suf_idx >= len(suffex):
                    suf_idx = 0
                last_res = c_res
            atom[4] = atom[4] + suffex[suf_idx]
            self.typ.append(atom[4])
  
