#!/usr/bin/env python

import sys

def quit_early(fout, outfile):
  """
  Abort the conversion: close the output file, report on stderr and exit.

  fout     open output file object; closed before the process exits
  outfile  path of the output file; currently unused, the partially
           written file is left on disk (its removal is disabled)

  Always raises SystemExit with status 1 so that calling shells and
  scripts can detect the failure.
  """
  fout.close()
  # NOTE: outfile is not removed here; whatever was written so far stays on
  # disk for inspection.
  sys.stderr.write("terminating prematurely\n")
  # a bare sys.exit() would report success (status 0) for a failed run
  sys.exit(1)


class _PSFFormatError(Exception):
  """Raised internally when the input PSF cannot be converted."""


def _copy_to_header(fin, fout, tag, echo):
  """
  Copy lines from fin to fout up to the section header that carries tag
  at column 9 (right after the 8-column count) and return that header line.
  The header itself is written to fout only when echo is true.
  """
  end = 9 + len(tag)
  while True:
    line = fin.readline()
    if not line:
      # readline() returns "" at end of file; without this check a missing
      # section would make the search loop forever
      raise _PSFFormatError("reached end of file before %s section" % tag)
    if line[9:end] == tag:
      if echo:
        fout.write(line)
      return line
    fout.write(line)


def _read_record(fin, what):
  """Return the next line of fin, failing if the file ends early."""
  line = fin.readline()
  if not line:
    raise _PSFFormatError("reached end of file while reading %s" % what)
  return line


def _renumber_fields(line, nfields, amap, base):
  """
  Return line with its first nfields 8-column integers mapped through amap.
  base is 1 for 1-based atom indices and 0 when the value indexes amap
  directly; whatever follows those fields is kept unchanged.
  """
  fields = [ "%8d" % (amap[int(line[8*j:8*(j+1)]) - base] + base)
             for j in range(nfields) ]
  return "".join(fields) + line[8*nfields:]


def _renumber_section(fin, fout, amap, tag, per_entry, entries_per_line):
  """
  Copy everything up to and including the header of section tag, then
  rewrite the atom indices of its entries (per_entry indices each,
  entries_per_line entries on every full line).
  """
  header = _copy_to_header(fin, fout, tag, True)
  full, rest = divmod(int(header[:8]), entries_per_line)
  what = "%s section" % tag
  for _ in range(full):
    line = _read_record(fin, what)
    fout.write(_renumber_fields(line, per_entry*entries_per_line, amap, 1))
  if rest:
    # the trailing partial line holds per_entry indices for each entry left
    line = _read_record(fin, what)
    fout.write(_renumber_fields(line, per_entry*rest, amap, 1))


def _polarize_psf(fin, fout, q_d, m_d):
  """
  Stream the PSF from fin to fout, converting each TIP3 water.
  Raises _PSFFormatError when the input cannot be handled.
  """
  q_o = -0.669
  q_h = -0.5 * q_o
  qstr = [ " %10.6f" % (q_o - q_d), " %10.6f" % q_h, " %10.6f" % q_h ]
  watstr = [ "TIP3 OH2  OT  ", "TIP3 H1   HT  ", "TIP3 H2   HT  " ]
  newatstr = [ "PSPC OH2D ODRU", "PSPC H1D  HDRU", "PSPC H2D  HDRU" ]
  qdstr = " %10.6f" % q_d
  mdstr = " %13.5f" % m_d
  drudestr = "PSPC DRU  DRU "

  # ---- atoms, pass 1: locate the waters and build the renumbering map
  header = _copy_to_header(fin, fout, "!NATOM", False)
  natoms = int(header[:8])
  atom = []      # raw atom records
  amap = []      # old 0-based atom index -> new 0-based atom index
  watmol = []    # old 0-based index of the oxygen of every water
  offset = 0     # Drude particles inserted before the current atom
  slot = 0       # which of the three water records is expected next
  for i in range(natoms):
    rec = _read_record(fin, "atom records")
    atom.append(rec)
    amap.append(i + offset)
    if rec[19:33] == watstr[slot]:
      if slot == 2:
        watmol.append(i - 2)
        offset += 1
        slot = 0
      else:
        slot += 1
    elif slot > 0 or rec[19:23] == "TIP3":
      raise _PSFFormatError("improper format for TIP3 water")
  if slot != 0:
    raise _PSFFormatError("ran out of atoms while looking for last water")
  if natoms + offset >= 10000000:
    raise _PSFFormatError("too many atoms to store in PSF")
  fout.write("%8d" % (natoms + offset) + header[8:])

  # ---- atoms, pass 2: write renumbered records, one Drude after each H2
  slot = 0
  for i in range(natoms):
    rec = atom[i]
    if rec[19:33] != watstr[slot]:
      fout.write("%8d" % (amap[i] + 1) + rec[8:])
      continue
    if slot == 0:
      # the oxygen gives up the mass carried by the Drude particle
      mstr = " %13.4f" % (float(rec[45:58]) - m_d)
    else:
      mstr = rec[44:58]
    fout.write("%8d" % (amap[i] + 1) + rec[8:19]
        + newatstr[slot] + qstr[slot] + mstr + rec[58:])
    if slot == 2:
      oxy = atom[i-2]
      fout.write("%8d" % (amap[i] + 2) + oxy[8:19]
          + drudestr + qdstr + mdstr + oxy[58:])
      slot = 0
    else:
      slot += 1

  # ---- bonds: renumber and add one oxygen-Drude bond per water
  header = _copy_to_header(fin, fout, "!NBOND", False)
  nbonds = int(header[:8])
  nwatmols = len(watmol)
  if nbonds + nwatmols >= 10000000:
    raise _PSFFormatError("too many bonds to store in PSF")
  fout.write("%8d" % (nbonds + nwatmols) + header[8:])

  bond = []
  while len(bond) < nbonds:
    line = _read_record(fin, "!NBOND section")
    # four bonds (eight indices) per full line
    for p in range(min(4, nbonds - len(bond))):
      bond.append([ int(line[16*p:16*p+8]) - 1,
                    int(line[16*p+8:16*p+16]) - 1 ])

  drbond = []
  n = 0
  m = 0
  # walk ALL bonds (not just those up to the last water) so bonds listed
  # after the waters, or a structure without water, are still written
  while n < nbonds:
    if m < nwatmols and bond[n][0] == watmol[m]:
      if nbonds - n < 2 or bond[n+1][0] != watmol[m]:
        raise _PSFFormatError("more bonds expected")
      drbond.append([ amap[bond[n][0]], amap[bond[n][1]] ])
      drbond.append([ amap[bond[n+1][0]], amap[bond[n+1][1]] ])
      # the Drude record sits right after H2, three slots past the oxygen
      drbond.append([ amap[watmol[m]], amap[watmol[m]] + 3 ])
      m += 1
      n += 2
    else:
      # amap gives the correct shift whatever the bond's position in the
      # list, consistent with the angle/dihedral/improper renumbering
      drbond.append([ amap[bond[n][0]], amap[bond[n][1]] ])
      n += 1
  if m != nwatmols:
    raise _PSFFormatError("did not find bonds for all water molecules")

  for start in range(0, len(drbond), 4):
    fout.write("".join([ "%8d%8d" % (b[0] + 1, b[1] + 1)
                         for b in drbond[start:start+4] ]) + "\n")

  # ---- angles, dihedrals, impropers: renumber atom indices only
  _renumber_section(fin, fout, amap, "!NTHETA", 3, 3)
  _renumber_section(fin, fout, amap, "!NPHI", 4, 2)
  _renumber_section(fin, fout, amap, "!NIMPHI", 4, 2)

  # ---- nonbonded exclusions
  # ignore donors and acceptors - don't know what to do with them!
  # (they are copied unchanged while searching for the !NNB header)
  header = _copy_to_header(fin, fout, "!NNB", True)
  # can't handle any but zero for now, bail on anything else
  if int(header[:8]) != 0:
    raise _PSFFormatError("unable to handle explicit nonbonded exclusions")
  # line following the header is copied as is
  # (presumably the exclusion list itself, empty since nnb == 0)
  fout.write(fin.readline())

  # one 8-column value per atom, eight per line; extend with a zero for
  # every added Drude particle
  # NOTE(review): values are passed through amap as before; with nnb == 0
  # they are expected to be all zero, which maps to zero
  full, rest = divmod(natoms, 8)
  for _ in range(full):
    line = _read_record(fin, "exclusion pointers")
    fout.write(_renumber_fields(line, 8, amap, 0))
  if rest:
    line = _read_record(fin, "exclusion pointers")
    tail = _renumber_fields(line, rest, amap, 0)[:8*rest]
  else:
    # nothing is read here, so the following line is copied verbatim below
    tail = ""
  used = rest
  for _ in range(nwatmols):
    if used == 8:
      fout.write(tail + "\n")
      tail = ""
      used = 0
    tail += "%8d" % 0
    used += 1
  if used:
    fout.write(tail + "\n")

  # copy everything else verbatim
  for line in fin:
    fout.write(line)


def polarize_tip3_water_structure(infile, outfile, q_d, m_d):
  """
  Convert PSF structure for each TIP3 water molecule into PSPC water
  with an added Drude particle, renumbering accordingly

  infile   path of the PSF file to read
  outfile  path of the PSF file to write
  q_d      charge written for each Drude particle; it is also subtracted
           from the oxygen charge (-0.669) of every converted water
  m_d      mass written for each Drude particle; it is also subtracted
           from the mass of the oxygen the particle is attached to

  Every run of three consecutive TIP3 records (OH2, H1, H2) is renamed to
  PSPC and followed by a new DRU record bonded to the oxygen.  Atom indices
  in the bond, angle, dihedral and improper sections are renumbered; the
  donor/acceptor sections and everything after the exclusion pointers are
  copied unchanged.

  On a format problem a message is written to stderr and quit_early() is
  called, which terminates the program.  Both files are always closed.
  """
  fin = open(infile, "r")
  try:
    fout = open(outfile, "w")
    try:
      try:
        _polarize_psf(fin, fout, q_d, m_d)
      except _PSFFormatError as err:
        sys.stderr.write("%s\n" % err)
        quit_early(fout, outfile)
    finally:
      fout.close()
  finally:
    fin.close()

  return


# Command-line entry point: expects exactly two arguments, the PSF file to
# read and the PSF file to write.
if __name__ == "__main__":
  if len(sys.argv) != 3:
    sys.stderr.write(
        "\n"
        "Convert PSF structure for each TIP3 water molecule into PSPC water\n"
        "with an added Drude particle, renumbering accordingly\n"
        "\n")
    sys.stderr.write("syntax: %s psf1 psf2\n" % sys.argv[0])
    # wrong usage is an error: exit non-zero so callers can detect it
    sys.exit(2)

  # eventually parse from command line options:
  arg_q_d = 2.08241   # passed as q_d (Drude particle charge)
  arg_m_d = 0.8       # passed as m_d (Drude particle mass)

  polarize_tip3_water_structure(sys.argv[1], sys.argv[2], arg_q_d, arg_m_d)
