import sys
from PDB import *
from geom import *
from out_put import *
from readFF import *
import top_db
#from readFF import *
#==helper==

def isdiff_float(a,b,small=1e-20, tol=1e-5):
  """Return True when a and b differ by more than a relative tolerance.

  The relative difference is |a-b| / max(|a|, |b|).  When both magnitudes
  are below `small` the values are treated as equal and False is returned.

  a, b  : numbers to compare (ints are accepted).
  small : magnitude below which a value counts as zero.
  tol   : relative tolerance; the result is True unless the relative
          difference is strictly below tol.
  """
  diff = abs(a - b)
  a = abs(a)
  b = abs(b)
  if a < small and b < small:
    return False
  # Here max(a, b) >= small, so the division is safe.  float() avoids
  # Python 2 integer division (e.g. 3 vs 4 would give 1/4 == 0 -> "same").
  return not (float(diff) / max(a, b) < tol)


def is_float_list_same(a_l, b_l):
  """Return True when every paired element of a_l and b_l is equal within
  isdiff_float's default tolerances.

  Only the common prefix is compared (pairs come from zip), so lists of
  different lengths are judged on their overlapping elements only.
  Stops at the first differing pair.
  """
  return all(not isdiff_float(x, y) for x, y in zip(a_l, b_l))

def isdiff_floatA (a,b,small=1e-20, tol=1e-5):
  """Return True when a and b differ by more than a relative tolerance.

  Variant of isdiff_float without the "both tiny -> equal" shortcut.  When
  the larger magnitude is below `small`, the values differ only if their
  absolute difference is at least `small`.  Otherwise the relative
  difference |a-b| / max(|a|, |b|) is compared with tol.
  """
  diff = abs(a - b)
  scale = max(abs(a), abs(b))
  if scale < small:
    return not (diff < small)
  # float() avoids Python 2 integer division for int inputs.
  return not (float(diff) / scale < tol)

def show_missing(c_set2FF, c_typ, c_nm, term, atyp):
 mis=0
 typnm=[i[0] for i in atyp]
 for i,j,k in zip(c_set2FF, c_typ, c_nm):
  if len(i)==0:
   print 'Warning! atoms',
   for l in k: print l,
   print '(type',
   for l in j: print l,
   print ') miss',term,'parameters'
   for l in j:

    print l, atyp[typnm.index(l)][2]
   mis+=1
   print '==================='
   if mis>=3: break
 return mis


def arg_w_opt(arg,opt,defa=None):
  """Return the token that follows the first occurrence of opt in arg.

  arg  : sequence of command-line tokens.
  opt  : option token to look for (e.g. '-i').
  defa : value returned when opt is absent or is the last token, i.e. when
         no value follows it.  Previously this last case raised IndexError.
  """
  if opt not in arg:
    return defa
  idx = arg.index(opt)
  if idx + 1 >= len(arg):
    return defa
  return arg[idx + 1]

def file_2_arg(fnm):
  """Read file fnm and return all of its whitespace-separated tokens in order.

  Line breaks are treated like any other whitespace.  str.split() strips
  newlines itself, so a final line without a trailing newline keeps its
  last character; the old rl[:-1] slice cut it off.
  """
  tokens = []
  with open(fnm, 'r') as f:
    for line in f:
      tokens.extend(line.split())
  return tokens

def check_donor_dist(system,k,acc_list,\
                     excluSig='',excluList=[],cut=3.2):
 
 for iacc in acc_list:
  d=dist(k, iacc[1])
  if d<cut:
   iatom = system.atoms[iacc[0]]
   Sig=  iatom[8]+':'+iatom[2]+str(iatom[4])
   #print Sig, excluSig, iatom[1]
   if Sig==excluSig:
    if iatom[1] in excluList: continue
   print Sig+':'+iatom[1]+'(%.3fA)'%(d),
 

def show_opt_res(opt):
 for j,i in enumerate(opt):
  print j,':',i[1]

def choose_opt(nopt):
  """Read lines from stdin until one holds a valid option index.

  nopt : number of options; a valid answer is an integer in [0, nopt).
  Lines rejected by is_idx() or out of range are silently ignored.

  Returns the chosen index.
  Raises ValueError if nopt < 1, since no answer could ever be accepted and
  the old code looped forever.
  Raises EOFError if stdin ends before a valid index is read; the old code
  spun forever because readline() keeps returning ''.
  """
  if nopt < 1:
    raise ValueError('choose_opt: no options to choose from')
  while True:
    line = sys.stdin.readline()
    if not line:
      raise EOFError('choose_opt: end of input before a valid option was chosen')
    # strip() rather than [:-1]: a last line without '\n' keeps its final char.
    token = line.strip()
    if not is_idx(token):
      continue
    choice = int(token)
    if 0 <= choice < nopt:
      return choice


def reset_chain(system, t_p_l=10):
 """Interactively split the system into new chains at chosen residues.

 Prints a table of all residues, t_p_l cells per row.  Each cell shows the
 residue's position in system.reslist plus fields 2 and 4 of the residue's
 first atom (reslist[i][0] is used as that atom's index).  Then reads one
 line from stdin of whitespace-separated residue indices at which new
 chains should start.

 Input handling:
   - an empty line, or EOF, accepts no indices, so the whole system
     becomes a single chain;
   - tokens for which is_idx() is false are skipped;
   - any index outside [0, len(reslist)) rejects the whole line and another
     line is read, with no message printed;
   - duplicate indices are ignored.

 Afterwards every atom is relabelled in place.  The atoms from one chain
 start up to the next get field 3 = 'A', 'B', ... (wrapping after 'Z')
 and field 8 = 'P1', 'P2', ...  Atom 0 always starts the first chain.

 t_p_l : number of residue cells printed per table row.
 Returns None; mutates system.atoms.
 """
 reslist=system.reslist
 # Table header: an 'idx' label over every cell of a row.
 for i in range(t_p_l):
  print '%6s        |'%('idx'),
 print
 for i in range(t_p_l):
  print '--------------+',
 print
 # One cell per residue: list index, then fields 2 and 4 of its first atom.
 for i,j in enumerate(reslist):
  print '%6d%4s%4d|'%(i,system.atoms[j[0]][2],system.atoms[j[0]][4]),
  if i%t_p_l == t_p_l-1:
   print
 print
 print 'Please input starting residue indices for new chains'
 ok=False
 while not ok:
  co= sys.stdin.readline()
  srl=co[:-1].split()
  c_st=[]
  # Empty line (or EOF, where readline returns '') means "no new chain starts".
  if len(srl)<1: break
  
  ok = True
  for i in srl:
   if is_idx(i):
    idx= int(i)
    # One out-of-range index rejects the whole line; the loop reads again.
    if idx<0 or idx>=len(reslist):
     ok=False
     break
    if not idx in c_st:
     c_st.append(idx)
 # Convert residue indices to first-atom indices.  Make sure atom 0 opens the
 # first chain, and close the last chain at the end of the atom list.
 c_atom_st=[reslist[i][0] for i in c_st]
 if not 0 in c_atom_st: c_atom_st.append(0)
 c_atom_st.sort()
 c_atom_st.append(len(system.atoms))
 for i in range(len(c_atom_st)-1):
  st = c_atom_st[i]
  en = c_atom_st[i+1]
  for j in range(st, en):
   # Chain ids cycle through A..Z; segment names are P1, P2, ...
   system.atoms[j][3] = chr(ord('A') + i%26)
   system.atoms[j][8]= 'P'+str(i+1)
 
def set_segname(system):
 
 ok =False
 while not ok:
  c_st = system.get_chains()
  print 'Please choose: [chain id][new segment name] ("quit" to quit)'
  co=sys.stdin.readline()
  if 'quit' in co: break
  srl=co[:-1].split()
  if len(srl)<2: continue
  if not is_idx(srl[0]): continue
  idx=int(srl[0])
  if idx<0: continue
  if idx>=len(c_st): continue
  idx-=1
  for j in range(c_st[idx], c_st[idx+1]):
   system.atoms[j][8] = srl[1]
   

def gen_cys_s_list(system):
  """
  Return list of pair [internal idx of SG atoms, coordinates of SG]

  A residue qualifies when its name (reslist entry[2]) contains 'CYS' as a
  substring.  Its atoms entry[0] .. entry[0]+entry[1]-1 are searched for
  atoms named 'SG' (system.atomnms); the coordinates come from system.xp.
  Pairs are returned in residue order, then atom order.
  """
  pairs = []
  for res in system.reslist:
    if 'CYS' not in res[2]:
      continue
    first = res[0]
    pairs.extend([atom_idx, system.xp[atom_idx]]
                 for atom_idx in range(first, first + res[1])
                 if system.atomnms[atom_idx] == 'SG')
  return pairs

def check_disu (system, cys_S_list, cut=4.0):

 if len(cys_S_list)<2: return []
 re=[]
 for k in range(len(cys_S_list)-1):
  i=cys_S_list[k]
  i_xp = i[1]
  i_idx = i[0]
  for j in range(k+1, len(cys_S_list)):
   l=cys_S_list[j]
   j_xp = l[1]
   j_idx = l[0]
   d = dist(i_xp, j_xp)
   if d<cut:
    sig1 = system.atoms[i_idx][8]+':'+str(system.atoms[i_idx][4])
    sig2 = system.atoms[j_idx][8]+':'+str(system.atoms[j_idx][4])
    print 'Distance between two sulfurs of ',sig1,'and',sig2,'is %.3f'%(d)
    print 'Do you want to build a disulfide bond between them?(y/n)'
   # co = sys.stdin.readline()
   # if 'n' in co or 'N' in co:
   #  continue
    print 'A disulfide bond found between',sig1, 'and',sig2
    re.append([sig1,sig2])
 return re

#interactively choose parameters for FF entry
#to see if parameters are in FF already
#also type is defined
def input_param(FF, atyp):

# interact_rule={\
# type   maptoFF    tobedef   alreadyHave 
# 'BOND':[c12_set2FF, c12_typ, FF.bonds,[2,2]],\
# 'ANGLE':[c13_set2FF, c13_typ, FF.angles,[3,2]],\
# 'DIHE':[c14_set2FF, c14_typ, FF.dihes,[4,3]],\
# 'IMPR':[cim_set2FF, cim_typ, FF.imprs,[4,3]]\
# }
 interact_rule = top_db.load_param_charmm_rule(FF)
 typnm=[i[0] for i in atyp]
 cmd=''
 while not 'quit' in cmd:
  print 'Enter: BONDS/ANGLES/DIHEDRALS/IMPROPER  typee1 type2 [type3 type4] parameter...'
  print 'Use wild card * to show available FF entries'
  print '"quit" to quit'
  cmd=sys.stdin.readline()
  if 'quit' in cmd: return
  srl=cmd[:-1].split() 
  if not srl[0] in interact_rule: continue
  c_rule=interact_rule[srl[0]]
  flist = c_rule[0]
  param = c_rule[2]
  col_field1 = c_rule[1] #column field for name
  col_field2 = len(c_rule[3]) #for param
  if '*' in cmd:
   if len(srl)-1<col_field1:
    print 'Not enough info'
    continue
   t_l=[srl[i] for i in range(1,1+col_field1)]
   mlist=match_FF_list_wildcard(t_l, flist)
   if len(mlist)==0:
    print 'Nothing found'
   else:
    for i in mlist:
     print 'type',flist[i],'parameter',param[i] 
   continue
  if len(srl)-1<col_field1+col_field2:
   print 'Not enough info'
   continue
  t_l=[srl[i] for i in range(1,1+col_field1)]
  mlist = match_FF_list(t_l, flist)
  if len(mlist)>0:
   print 'type',t_l,'already defined'
   continue
  flist.append(t_l)
  d_typ=c_rule[4]
  p_l=[]
  for i,j in enumerate(d_typ):
   ss=srl[1+col_field1+i]
   if j==0:
    p_l.append(int(ss))
   else:
    p_l.append(float(ss))
  param.append(p_l)


def len_comment(srl, c_list=(';','!','#')):
  """Return how many leading tokens of srl come before a comment.

  A token starts a comment when its first character is in c_list.  Returns
  len(srl) when no token does, so 0 for an empty list.  Empty-string tokens
  are treated as ordinary tokens instead of raising IndexError.  The default
  is a tuple so no shared mutable default exists; any container works.
  """
  for i, tok in enumerate(srl):
    if tok and tok[0] in c_list:
      return i
  return len(srl)
