import sys
import top_db
import PDB
# This module guesses force-field atom types from atom names and connectivity.


def match_list(nlist, cond):
 """Return True if nlist can be paired one-to-one with the entries of cond.

 Order does not matter. Each item of nlist first consumes an equal
 entry of cond; if none is left, it consumes a wildcard entry 'X'
 instead. The lists must have the same length, and every item must be
 consumed for the match to succeed.
 """
 if len(nlist) != len(cond):
  return False
 remaining = list(cond)
 for name in nlist:
  if name in remaining:
   remaining.remove(name)
  elif 'X' in remaining:
   remaining.remove('X')
  else:
   return False
 return True

# Interactive (stdin) assignment of atom types chosen by the user.
def choose_typ(atyp, mark, atomnms, ffatyp,newtypmapping):
 """Interactively assign atom types to atoms chosen by the user.

 Reads two lines from stdin:
   1. the atoms to (re)type, each given by atom name or by integer index
      into atomnms;
   2. one type name per atom, paired with the atoms by position.
 Unknown atom tokens are reported and skipped, but they keep their slot so
 the remaining types stay paired with the atoms as the user typed them.
 If fewer types than atom tokens are given, nothing is changed.

 A type name not present in ffatyp (entry [0] = type name) is treated as
 a new type. The user is asked for an existing type with the same first
 letter. A new ffatyp entry [new, old[1], old[2]] is appended, and
 [new, old] is appended to newtypmapping. If stdin reaches EOF before a
 valid choice, or no existing type shares that first letter, the function
 aborts and leaves the atom untouched.

 atyp, mark, ffatyp and newtypmapping are modified in place.
 Returns sum(mark), the number of atoms whose mark is set.
 """
 typnm = [ii[0] for ii in ffatyp]
 print('Please choose atoms for setting types (multiple indices separated by space):')
 # split() with no argument also removes the trailing newline; slicing off
 # the last character lost a real character when the line had no newline.
 chosen = []
 for tok in sys.stdin.readline().split():
  if tok in atomnms:
   chosen.append(atomnms.index(tok))
  elif tok.isdigit():
   chosen.append(int(tok))
  else:
   # Placeholder keeps the types below aligned with what the user typed.
   print('Ignored. unknown atom %s' % tok)
   chosen.append(None)

 print('Please choose types for chosen atoms (separated by space):')
 types = sys.stdin.readline().split()
 if len(types) < len(chosen):
  print('Aborted. provided types fewer than expected')
  return sum(mark)

 for ti, new in zip(chosen, types):
  if ti is None:
   continue
  if ti < 0 or ti >= len(mark):
   print('Ignored. wrong idx %s' % ti)
   continue
  if mark[ti] == 1:
   print('warning. atom %s originally type %s will be changed' % (ti, atyp[ti]))
  print('atom %s now has type %s' % (ti, new))
  # Register a new type before using it, so atyp never refers to a type
  # that has no ffatyp entry.
  if new not in typnm:
   if not _map_new_type(new, typnm, ffatyp, newtypmapping):
    print('Aborted. no old type chosen for %s' % new)
    return sum(mark)
  atyp[ti] = new
  mark[ti] = 1

 return sum(mark)


def _map_new_type(new, typnm, ffatyp, newtypmapping):
 """Ask the user for an existing type to base the new type `new` on.

 Only types sharing the first letter of `new` are accepted. Blank or
 invalid lines re-prompt. On success, appends to ffatyp, typnm and
 newtypmapping and returns True. Returns False on EOF, or when no
 candidate exists (otherwise the prompt could never be satisfied).
 """
 print('%s is a new type. Please choose an old type for it:' % new)
 candidates = [nm for nm in typnm if nm[:1] == new[:1]]
 if not candidates:
  print('No existing type starts with %s' % new[:1])
  return False
 while True:
  line = sys.stdin.readline()
  if not line:  # EOF
   return False
  toks = line.split()
  if toks and toks[0] in candidates:
   old = toks[0]
   break
 iss = typnm.index(old)
 ffatyp.append([new, ffatyp[iss][1], ffatyp[iss][2]])
 print('Mapping to %s %s' % (old, ffatyp[iss][2]))
 print('')
 typnm.append(new)
 newtypmapping.append([new, old])
 return True

# Guess a force-field atom type for every atom.
# Input: atom names and connectivity; conn[i] holds the indices of the atoms
# bonded directly to atom i.
# type_rule: the typing rules of the chosen force field (see docstring).
# ffatyp: the force field's atom-type table.
# Output: a list with one guessed atom-type name per atom.
def guess_atomtype(atomnms, conn, type_rule, ffatyp, newtypmapping):
 """Assign an atom type to every atom, automatically where possible.

 The typing rules are applied in passes until every atom is marked as
 assigned. When a pass assigns nothing new, the current state is printed
 and choose_typ asks the user (via stdin) to type atoms by hand.

 atomnms:   atom names; the first character is used as the element
            initial (anm_ini below).
 conn:      conn[i] lists the indices of atoms bonded to atom i; must have
            the same length as atomnms, otherwise the program exits.
 type_rule: dict keyed by hybridization label (as produced by
            get_hybrid_state). Each value is a list of rules tried in
            order; rule[0] selects the rule kind (0-5, see comments below).
 ffatyp:    force-field type table; entry [0] is the type name and entry
            [2] is printed next to it (presumably a description).
            choose_typ may append new types to it.
 newtypmapping: passed through to choose_typ, which may append to it.

 Returns a list of atom-type names, one per atom.
 Calls sys.exit() if atomnms and conn differ in length, or if no rule
 exists for an atom's hybridization label.
 """
 

 natom=len(atomnms)
 if natom!=len(conn):
  print 'atom names dont match connectivity'
  sys.exit()

 # hybridization label per atom, e.g. 'Csp3' (see get_hybrid_state)
 hyd_state= get_hybrid_state(atomnms, conn)
 # per atom: count of bonded O/N/P atoms for carbons, -1 for other atoms
 elewdr_n = get_elewdr_num(atomnms, conn) 
 anm_ini=[i[0] for i in atomnms]
 typnm = [i[0] for i in ffatyp]
 # NOTE(review): aidx is not used below.
 aidx=[i for i in range(natom)]
 atyp=['' for i in range(natom)]
 #for i,j in zip(hyd_state, elewdr_n): print i,j
 nassigned=0
 # mark[i]: 0 = not assigned yet, 1 = assigned
 mark = [0 for i in range(natom)]
  

 # Each pass tries the automatic rules on every unassigned atom; a pass that
 # assigns nothing falls back to user input. The loop (and the prompting)
 # repeats until every atom is assigned.
 while nassigned<natom:
  last_assigned=nassigned
#kernel  
  # [atom index, ru[2]] for atoms matched by a kind-2 (interactive) rule
  # during this pass; shown to the user if the pass stalls
  typ2=[]
  for i in range(natom):
   if mark[i]==0:
    c_hyd=hyd_state[i]
    # NOTE(review): c_elewdr is not used below; kind 3 reads the
    # neighbour's count instead.
    c_elewdr = elewdr_n[i]
    if not c_hyd in type_rule:
     print 'no such rule for', c_hyd
     sys.exit() 
    c_rule = type_rule[c_hyd]
    #priority for each type as order of appearance
    # (the first rule that yields a type wins; see the break further down)
    for ru in c_rule:
     r_type=ru[0]
     found=0
     re_typ=''
#based on init names
     # kind 0: ru[1] is a list of conditions on the neighbours' name
     # initials (compared with match_list, 'X' = wildcard); on the first
     # matching condition the type is ru[2]
     if r_type==0:
      nlist = get_nb_nm_list(i, conn, anm_ini)
      for cond in ru[1]:
       if match_list(nlist, cond):
        found=1
        re_typ=ru[2]
        break      
      pass

#based on hybridization state
     # kind 1: like kind 0, but the conditions are on the neighbours'
     # hybridization labels
     if r_type==1:
      nlist = get_nb_nm_list(i, conn, hyd_state)
      for cond in ru[1]:
       if match_list(nlist, cond):
        found=1
        re_typ=ru[2]
        break      
      pass

#based on interactive Q&A
     # kind 2: on a matching condition the atom is NOT assigned here;
     # [i, ru[2]] is recorded so the candidate types in ru[2] can be
     # listed to the user if this pass stalls
     if r_type==2:
      nlist = get_nb_nm_list(i, conn, hyd_state)
      for cond in ru[1]:
       if match_list(nlist, cond):
   
        typ2.append([i, ru[2]]) #used later
        break  
      pass

#hydbridization + elect withdraw count
     # kind 3: on a matching condition (neighbour hybridization labels),
     # the O/N/P count of the FIRST neighbour conn[i][0] indexes ru[2] to
     # give the type; a count beyond the end of ru[2] does not match, and
     # the next condition is tried
     if r_type==3:
      nlist = get_nb_nm_list(i, conn, hyd_state)
      #c_elewdr
      for cond in ru[1]:
       if match_list(nlist, cond):
        ci=conn[i][0]
        c_elewdr_c = elewdr_n[ci]
        if c_elewdr_c<len(ru[2]):
         found=1
         #print c_elewdr_c, ru[2]
         re_typ=ru[2][c_elewdr_c]
         break      
      pass

#based atome type
     # kind 5: placeholder, never assigns a type
     if r_type==5:
      pass #not implemented yet

#based on neibor of neibor
     # kind 4: ru[1] is ONE condition on the neighbours' hybridization
     # labels. ru[2] is a dict mapping a neighbour label to a list of
     # conditions on that neighbour's own neighbours. ru[3] is the
     # resulting type.
     if r_type==4:
      nlist = get_nb_nm_list(i, conn, hyd_state)
      cond = ru[1]
      if match_list(nlist, cond):
       #idx of nb atoms
       idxlist = conn[i]
       # m_k[ii] = 1 once neighbour ii has satisfied some sub-condition
       m_k=[0 for ii in conn[i]]
       rule_nb=ru[2]
       found_n=0
#all nb cond should be matched
       for j in cond:
        if j in rule_nb:
#in each nb cond, at least one condition should be satisffied
         conds_sub= rule_nb[j]
         found_sb=0       
#any nb atom can be considered, but no single one considered for two condition  
         for ii, nbi in enumerate(idxlist):
          if j==hyd_state[nbi] and m_k[ii]==0:
           nb_l = get_nb_nm_list(nbi, conn, hyd_state)
             
           for cs in conds_sub:
            if match_list(nb_l, cs):
             
             m_k[ii]=1
             found_sb=1
             break
          if found_sb==1:
           break 
         found_n+=found_sb
         
       # NOTE(review): found_n adds one per entry of cond that is a key of
       # rule_nb (a label repeated in cond can count more than once), while
       # len(rule_nb) counts distinct keys. Confirm this comparison is
       # intended.
       if found_n==len(rule_nb):
        found=1   
        re_typ=ru[3]           
      pass     



     if found==1: #found for this atom
      atyp[i]=re_typ 
      mark[i]=1
      nassigned+=1
      # stop at the first rule that produced a type
      break
#

  if nassigned<=last_assigned: #no new type assigend
  # to see if there is any unassigned types that can be
  #determined interactively
   print nassigned,'atoms have been assigned:'
   show_assigned(atomnms, mark, atyp, hyd_state)
   if len(typ2)>0:
    # list each kind-2 match with its neighbours and the candidate types
    # found in ffatyp, then let the user choose
    for i in typ2:
     print '-------------------------'
     print i[0],atomnms[i[0]],hyd_state[i[0]],' connected (',
     for j in conn[i[0]]: print j,hyd_state[j],
     print '):'
     for j in i[1]:
      if j in typnm:
       print j, ffatyp[typnm.index(j)][2]

    # choose_typ returns sum(mark), i.e. the new assigned count
    nassigned=choose_typ(atyp, mark, atomnms, ffatyp, newtypmapping)
     
   else:
    # no candidates: list every unassigned atom with its neighbours
    for i in range(natom):
     if mark[i]==0:
      print '--------------------------'
      print i,atomnms[i],hyd_state[i],' connected (',
      for j in conn[i]: print j,hyd_state[j],
      print ')'
    nassigned=choose_typ(atyp, mark, atomnms, ffatyp, newtypmapping)


 return atyp


def get_hybrid_state(atomnms, conn):
 """Return a hybridization label '<element>sp1/sp2/sp3' for each atom.

 The element is the first character of the atom name. For each element,
 top_db.atom_hybrid_rule() supplies a rule whose field [1] is the allowed
 number of connections. The shortfall rule[1] - len(conn[i]) selects the
 label: <= 0 -> sp3, 1 -> sp2, 2 -> sp1.

 Calls sys.exit() when:
   - the element has no rule;
   - rule[0] == 'n' and the atom has more connections than rule[1];
   - the shortfall is larger than 2.

 Nitrogen is refined afterwards: a 3-connected N may be sp2 or sp3, so
 it is relabelled 'Nsp2' when a bonded atom is already sp2.
 """
 hyd_rule = top_db.atom_hybrid_rule()
 t = []
 ncon = []  # per atom: rule[1] minus its actual number of connections
 for i, anm in enumerate(atomnms):
  if not anm[0] in hyd_rule:
   print('Err! Element %s not found in lib:' % anm[0])
   # the old hyd_rule.key() raised AttributeError here
   print(sorted(hyd_rule.keys()))
   sys.exit()
  this_rule = hyd_rule[anm[0]]
  ndir_con = this_rule[1] - len(conn[i])
  if ndir_con < 0 and this_rule[0] == 'n':
   print('There are more conn to %s %s than allowed %s' % (anm, i, this_rule[1]))
   sys.exit()
  if ndir_con > 2:
   print('Not enough conn for %s %s' % (anm, i))
   sys.exit()
  if ndir_con <= 0:
   hnm = anm[0] + 'sp3'
  elif ndir_con == 1:
   hnm = anm[0] + 'sp2'
  else:
   hnm = anm[0] + 'sp1'
  t.append(hnm)
  ncon.append(ndir_con)

 # For N, sp2 and sp3 both have 3 direct connections: call such an N sp2
 # when a neighbour is sp2. This must test THIS atom's shortfall
 # (ncon[i]); the old code reused the loop's leftover ndir_con from the
 # last atom, so the check hit every N or none depending on atom order.
 for i, anm in enumerate(atomnms):
  if anm[0] == 'N' and ncon[i] == 0:
   if any('sp2' in lab for lab in get_nb_nm_list(i, conn, t)):
    t[i] = 'Nsp2'

 # Second pass over every 3-connected N that is still not sp2, using the
 # labels as updated by the pass above.
 for i, anm in enumerate(atomnms):
  if anm[0] == 'N' and 'sp2' not in t[i] and len(conn[i]) == 3:
   if any('sp2' in t[j] for j in conn[i]):
    t[i] = 'Nsp2'

 return t
  
def get_elewdr_num(atomnms, conn):
 """Count bonded O/N/P atoms for every carbon.

 Returns one integer per atom. For atoms whose name starts with 'C' it
 is the number of bonded atoms whose names start with 'O', 'N' or 'P'
 (presumably the electron-withdrawing neighbours). For every other atom
 it is -1.
 """
 counts = []
 for idx, name in enumerate(atomnms):
  if name[0] != 'C':
   counts.append(-1)
  else:
   neighbours = [atomnms[nb] for nb in conn[idx]]
   counts.append(sum(1 for nb_name in neighbours if nb_name[0] in ('O', 'N', 'P')))
 return counts


def get_nb_nm_list(i, conn, atomnms):
 """Return the labels of the atoms bonded to atom i.

 conn[i] lists the indices bonded to atom i. atomnms can be any per-atom
 sequence (names, hybridization labels, types). The result holds
 atomnms[j] for each j in conn[i], in that order.
 """
 return [atomnms[nb] for nb in conn[i]]


def show_assigned(atomnms, mark, atyp, hyd_state):
 """Print all atoms, five per row, with their current type.

 Each cell is '%5d%5s%5s%5s' of (index, name, hybridization label, type).
 The type is shown as '***' when mark is 0 (not assigned yet). Cells in a
 row are separated by one space, each row is preceded by a newline, and a
 final newline ends the table.
 """
 cells = []
 for idx, name in enumerate(atomnms):
  shown = '***' if mark[idx] == 0 else atyp[idx]
  cells.append('%5d%5s%5s%5s' % (idx, name, hyd_state[idx], shown))
 rows = [' '.join(cells[k:k + 5]) for k in range(0, len(cells), 5)]
 sys.stdout.write(''.join('\n' + row for row in rows) + '\n')


#to further define type for H to C next to positive charge
def correct_Hpositive_amber(t_typ, conn):
 """Retype as 'HP' each hydrogen whose bonded atom neighbours an 'N3' atom.

 For every atom whose type starts with 'H', take its first bonded
 neighbour conn[i][0]. If any atom bonded to that neighbour has type
 'N3' (presumably the positively charged N the header comment refers
 to), the hydrogen's type becomes 'HP'. t_typ is modified in place and
 nothing is returned.

 Atoms with an empty type string, and hydrogens with no bonded
 neighbour, are skipped; previously both raised IndexError.
 """
 for idx, (typ, nbrs) in enumerate(zip(t_typ, conn)):
  if not typ.startswith('H') or not nbrs:
   continue
  heavy = nbrs[0]
  if any(t_typ[k] == 'N3' for k in conn[heavy]):
   t_typ[idx] = 'HP'

