#! /bin/env python

#===============================================================================#
#                    REQUIRED DEFINITION & IMPORTATION                          #
#===============================================================================#
#Module importation .............................................................
try  :
	 import sys, os, getopt
	 from math import sqrt, cos, fabs
	 from UserDict import UserDict
except :
	sys.exit("> Importation error : sys | os | getopt | math | UserDict\n")

try    : from MDTools import md
except : sys.exit("> Importation error : MDTools.md\n")

#Tricky Functions ...............................................................
def makedict(**kwargs):
	"""Build a dictionary whose keys are the keyword-argument names.

	Allows dictionaries to be written as ``makedict(ALA='A', CYS='C')``
	instead of quoting every key."""
	mapping = dict(kwargs)
	return mapping

def gplot(kwlist=(), kwid=(), datafile=None, gpfile=None, ColByHeader=1, ColId=1) :
	"""Write one gnuplot ``plot`` command to gpfile.

	kwlist      : sequence of curve titles, indexed by the values of kwid
	kwid        : indices (into kwlist) of the curves to plot, in plot order
	datafile    : data file name without its '.txt' extension
	gpfile      : writable file-like object receiving the gnuplot commands
	ColByHeader : number of data columns stored per header
	ColId       : 1-based rank of the wanted column inside a header group

	Column 1 of the data file is used as abscissa, so the column plotted
	for index i is ``i*ColByHeader + ColId + 1``.
	Curves are separated by a gnuplot line continuation; the separator is
	chosen by position, so a repeated index in kwid no longer terminates
	the plot command early."""
	gpfile.write("\nplot\\\n")
	curves = []
	for i in kwid :
		col = i*ColByHeader + ColId + 1
		curves.append("   '%s.txt' using 1:%d title '%s' with lines"%(datafile, col, kwlist[i]))
	#nothing more to write when no curve was requested
	if curves :
		gpfile.write(",\\\n".join(curves) + "\n")

#Default atoms parameters ......................................................
#Residue names accepted as peptide residues : ScriptConfig.getseq() keeps only
#the residues whose name belongs to this tuple.
alphabet = ('ALA', 'CYS', 'ASP', 'GLU', 'PHE', 'GLY', 'HIS', 'HSD', 'HSE', 'ILE', 'LYS', 'LEU',
            'MET', 'ASN', 'PRO', 'GLN', 'ARG', 'SER', 'THR', 'VAL', 'TRP', 'TYR', 'UNK', 'XXX')

#Three-letter -> one-letter residue code. HIS, HSD and HSE all map to 'H';
#UNK and XXX map to 'X' (also the fallback used by ScriptConfig.getseq()).
three2one_letter_code = makedict(
	ALA='A', CYS='C', ASP='D', GLU='E', PHE='F', GLY='G', HIS='H', ILE='I',
	LYS='K', LEU='L', MET='M', ASN='N', PRO='P', GLN='Q', ARG='R', SER='S',
	THR='T', VAL='V', TRP='W', TYR='Y', HSD='H', HSE='H', UNK='X', XXX='X')

#van der Waals diameter per atom name. Grid.set() halves the value to get the
#radius projected on the grid; atom names missing from this table count as 0.0.
#NOTE(review): values are presumably in angstrom, like the grid size -- confirm.
vdWdiam = makedict(
	C12 =2.060, H21 =0.700, H22 =0.700, H23 =0.700, C13 =2.060, H31 =0.700, H32 =0.700, H33 =0.700,
	C14 =2.060, H41 =0.700, H42 =0.700, H43 =0.700, C11 =2.275, H11 =0.700, H12 =0.700, N   =1.850,
	C15 =2.010, H51 =1.340, H52 =1.340, P1  =2.150, O1  =1.770, O2  =1.770, O3  =1.700, O4  =1.700,
	C1  =2.010, HA  =1.340, HB  =1.340, C2  =2.275, HS  =1.320, C21 =2.000, O21 =1.770, O22 =1.700,
	C22 =2.010, H2R =1.340, H2S =1.340, C23 =2.206, C24 =2.206, C25 =2.206, C26 =2.206, C27 =2.206,
	C28 =2.206, C29 =2.105, C210=2.105, C211=2.206, C212=2.206, C213=2.206, C214=2.206, C215=2.206,
	C216=2.206, C217=2.206, C218=2.206, C3  =2.010, HX  =1.340, HY  =1.340, C31 =2.000, O31 =1.770,
	O32 =1.700, C32 =2.010, H2X =1.340, H2Y =1.340, C33 =2.206, C34 =2.206, C35 =2.206, C36 =2.206,
	C37 =2.206, C38 =2.206, C39 =2.206, C310=2.206, C311=2.206, C312=2.206, C313=2.206, C314=2.206,
	C315=2.206, C316=2.206, C   =1.992, O   =1.700, H   =1.320)

#Named groups of atom names. In this file they are only referenced by the
#commented-out (head / glycerol / acyl) selections of the main section.
AtomGroup = makedict(
	backbone  = ('CA'  ,'C'    ,'O'    ,'N'    ,'HN'  ,'HT1'   ,'HT2' ,'HT3'   ,'OT1'  ,'OT2'  ,'HA' ,'HA1','HA2'),
	amine     = ('N'   ,'C12'  ,'H21'  ,'H22'  ,'H23' ,'C13'   ,'H31' ,'H32'   ,'H33'  ,'C14'  ,'H41','H42','H43'),
	phosphate = ('O1'  ,'O2'   ,'O3'   ,'O4'   ,'P1'),
	glycerol  = ('C2'  ,'C21'  ,'O21'  ,'O22'  ,'HS'   ,'C3'   ,'C31'  ,'O31'  ,'O32'  ,'HX'   ,'HY'),
	ch2       = ('C22' , 'C23' , 'C24' , 'C25' , 'C26' , 'C27' , 'C28' , 'C29' , 'C210', 'C211',
	             'C212', 'C213', 'C214', 'C215', 'C216', 'C217', 'C32' , 'C33' , 'C34' , 'C35' ,
	             'C36' , 'C37' , 'C38' , 'C39' , 'C310', 'C311', 'C312', 'C313', 'C314', 'C315'),
	ch3       = ('C218', 'C316'),
	oleyl     = ('C21' , 'C22' , 'C23' , 'C24' , 'C25' , 'C26' , 'C27' , 'C28' , 'C29' , 'C210', 'C211',
	             'C212', 'C213', 'C214', 'C215', 'C216', 'C217', 'C218'),
	palmitoyl = ('C31' , 'C32' , 'C33' , 'C34' , 'C35' , 'C36' , 'C37' , 'C38' , 'C39' , 'C310', 'C311',
	             'C312', 'C313', 'C314', 'C315', 'C316'),
	water     = ('H1'  ,'H2'   ,'OH2'))

#Groups of residue ids. layer1 (ids 1..81) and layer2 (ids 82..162) select the
#"POUN" residues of each layer in the main section, and len(layer1) is the
#divisor used there for the per-lipid area. The other lists are empty and not
#referenced by the code visible in this file.
IdGroup = makedict(
	layer1 = range(1 , 82), layer2 = range(82, 163),
	hphobe = [], hphile = [], gly = [], pro = [],
	hlx = [], nothlx=[])

#Script Header and Usage ........................................................
#NAME  : script tag, left-justified and padded to 15 characters, used as prefix
#        of the messages printed by the script.
#USAGE : help text handed to sys.exit() by ScriptConfig.setopt() on -h|--help.
#NOTE(review): getopt in setopt() also accepts --offset, which is neither
#              documented here nor handled there -- confirm whether it is obsolete.
NAME  = "%-15s"%"[ NAMDarea ]"
USAGE = """
SYNOPSIS :
==========
   - This script provide methods to compute lipid area through a CHARMM/NAMD
     trajectory. Computed values are time and molecule average.

   - MDtools Python module is required! Sources available at :
     http://www.ks.uiuc.edu/~jim/mdtools/

   - This script is awfully coded, do not hesitate to improve it!

USAGE :
=======
     NAMDarea.py [OPTION] --psf=<PSF file> --pdb=<PDB file> --dcd=<DCD file>

OPTION :
========
   -h|--help          Print this help message

   -p|--project       Name of the project used for file naming

   -m|--modulo        increment through the trajectory.
                      Default is 1.

   -s|--size          Grid dimension, i.e. <length x width x height>. Default
                      is <70 x 70 x 100> angstrom.
                      !! BE CAREFULL, grid size have to be broader than Cell   !!
                      !! dimensions to vaoid truncation effect. Default values !!
                      !! are suitable for a cell of 60x60x90                   !!

   --res              Grid resolution use to compute the atom area.
                      Default is 0.5 A/node.

   -o|--outdir        Output directory. Default is current directory

   --psf              Topology parameter file in CHARMM or XPLOR format

   --pdb              Initial coordinates of the system in PDB format

   --dcd              Trajectory file in CHARMM format (binary file)
"""
#===============================================================================#
#                        INTERNAL OBJECT DEFINITION                             #
#===============================================================================#
class ScriptConfig : #==========================================================#
	""" ScriptConfig holds the script configuration :
	    - paths to the input files (PSF, PDB, DCD)
	    - output name and directory
	    - script-specific computation parameters : increment, time-step, ..."""

	def __init__(self) : #..................................................#
		"""Initialize script parameters with their default values"""
		#Parameters that can be defined by the user
		self.tstep     = 200.0 #fs between 2 frames; setopt() converts it to ns per analyzed frame
		self.modulo    = 1
		self.project   = "foo"
		self.outdir    = os.getcwd()
		self.xscdir    = "%s/XSC"%os.getcwd()
		self.boxsize   = [70.0, 70.0, 100.0]
		self.res       = 0.5
		#input files : mandatory, their presence is checked by setopt()
		self.psffile   = None
		self.pdbfile   = None
		self.dcdfile   = None
		#static parameters
		self.nframe    = 0
		self.nres      = 0
		self.filetemp  = "foo"
		self.sequence  = []
		self.seqstring = ""
		self.pid       = os.getpid()

	def setopt(self, argv) : #..............................................#
		"""Set user-defined options from a command line.

		argv : complete argument vector, program name first (e.g. sys.argv)

		Every usage error (unknown option, malformed value, missing input
		file, output path that is not a directory) ends the script through
		sys.exit() with an explanatory message."""
		if len(argv) <= 1 :
			sys.exit("%s -- Bad usage : try [-h|--help] option for help\n\n"%NAME)
		try:
			#parse the vector received as argument, not the global sys.argv
			opts, args = getopt.getopt(argv[1:], "hs:o:p:m:r:",
			             ["help", "project=", "size=", "modulo=",
			              "res=", "offset=", "outdir=", "pdb=", "psf=", "dcd="])
		except getopt.GetoptError as error:
			sys.exit("%s Bad option usage : %s\n"%(NAME, error))

		for o, a in opts:
			try :
				if   o in ("-h", "--help")    : sys.exit(USAGE)
				elif o in ("-m", "--modulo")  :
					self.modulo = int(a)
					#the increment is used as a divisor and as a frame stride
					if self.modulo < 1 : raise ValueError(a)
				elif o in ("-p", "--project") : self.project = str(a)
				elif o in ("-o", "--outdir")  : self.outdir  = os.path.normpath(str(a))
				elif o in ("-r", "--res")     : self.res     = float(a)
				elif o == "--psf"             : self.psffile = os.path.normpath(str(a))
				elif o == "--pdb"             : self.pdbfile = os.path.normpath(str(a))
				elif o == "--dcd"             : self.dcdfile = os.path.normpath(str(a))
				elif o in ("-s", "--size")    :
					#<length>x<width>x<height>; trailing dimensions may be omitted
					words = a.split('x')
					if len(words) > len(self.boxsize) : raise ValueError(a)
					for i, word in enumerate(words) :
						self.boxsize[i] = float(word)
				#options known to getopt but not handled here (--offset)
				else : sys.exit("> Bad option usage : [%s %s], try [-h|--help] option for help\n"%(o, a))
			except ValueError :
				sys.exit("> Bad option usage : [%s %s], try [-h|--help] option for help\n"%(o, a))

		#the three input files are mandatory
		missing = []
		for option, path in (("--psf", self.psffile), ("--pdb", self.pdbfile), ("--dcd", self.dcdfile)) :
			if path is None : missing.append(option)
		if missing :
			sys.exit("%s -- Missing mandatory option(s) : %s\n"%(NAME, ", ".join(missing)))

		#set file name and path
		if not os.path.isdir(self.outdir) : sys.exit("\t- Output is not a directory : <%s>\n"%self.outdir)
		self.outdir   = os.path.abspath(self.outdir)
		self.psffile  = os.path.abspath(self.psffile)
		self.pdbfile  = os.path.abspath(self.pdbfile)
		self.dcdfile  = os.path.abspath(self.dcdfile)
		self.filetemp = "%s_area"%self.project
		#convert time step in ns (time elapsed between 2 analyzed frames)
		self.tstep = self.tstep * 1e-6 * self.modulo

	def loadtraj(self) : #..................................................#
		"""Load all required files (PDB, PSF and DCD).

		Returns the tuple (system definition, trajectory) built by
		md.Molecule and md.DCD, and sets self.nframe to the number of
		frames to analyze (NSET // modulo, so that frame index
		i*modulo stays inside the trajectory for every i < nframe).
		Exits through sys.exit() when a file cannot be read or loaded."""
		#check that every input file is readable, without leaking handles
		for path in (self.psffile, self.pdbfile, self.dcdfile) :
			try :
				open(path).close()
			except IOError as error :
				sys.exit("%s -- File can't be opened : <%s>\n"%(NAME, error.filename))

		try :
			MyDef = md.Molecule(pdb=self.pdbfile, psf=self.psffile)
		except Exception :
			sys.exit("%s -- Structure definition can't be loaded :\n pdb = <%s>\n psf = <%s>\n"%(NAME,self.pdbfile,self.psffile))

		try :
			MyTraj = md.DCD(self.dcdfile)
		except Exception :
			sys.exit("%s -- Trajectory can't be loaded :\n%s dcd = <%s>\n"%(NAME, NAME, self.dcdfile))

		self.nframe = int(MyTraj.NSET // self.modulo)
		print("%s %d frames %% modulo %d = %d frames to analyze"%(NAME, MyTraj.NSET, self.modulo, self.nframe))
		#self.tstep has been converted to ns by setopt() : report it in that unit
		print("%s time step between 2 frames = %g ns"%(NAME, self.tstep))
		print("%s Number of atoms = %d"%(NAME, MyTraj.numatoms))
		return MyDef, MyTraj

	def getseq(self, MyDef) : #.............................................#
		"""Get the peptide sequence from a System Definition.

		Residues whose name belongs to `alphabet` are kept. Fills
		self.sequence (three-letter names), self.seqstring (one-letter
		codes, 'X' for names without a code) and self.nres."""
		Peptide = MyDef.rsel(lambda res: res.name in alphabet)
		self.sequence  = [residue.name for residue in Peptide.residues]
		self.seqstring = "".join([three2one_letter_code.get(name, 'X') for name in self.sequence])
		self.nres      = len(self.sequence)


class Node : #==================================================================#
	"""A class defining a node belonging to a Grid object"""
	def __init__(self, x=0.00, y=0.00) :
		"""Place the node at (x, y) and mark it as not occupied"""
		self.x         = x #coordinates on the grid
		self.y         = y
		self.occupancy = 0 #occupancy : 0=not occupied | 1=occupied

	def getdist(self, atom) :
		"""Return the distance, in the XY plane, between the node and the
		center of atom (any object exposing .x and .y)"""
		dx = atom.x - self.x
		dy = atom.y - self.y
		return sqrt(dx*dx + dy*dy)


class Grid : #==================================================================#
	"""A class defining the grid used to compute lipid area"""
	def __init__(self, l=0.0, w=0.0, res=0.0):
		"""Object initialization. A grid is characterized by :
			- a size : length l along x, width w along y
			- an origin (Ox, Oy) = (-round(l/2), -round(w/2))
			- a resolution res, i.e. the spacing between 2 nodes

		   A grid is composed of nx * ny nodes characterized by :
			- a couple of grid coordinates
			- their occupancy by a van der waals radius (0 or 1)

		   Raises ValueError when res is not strictly positive."""
		if res <= 0 :
			raise ValueError("Grid resolution must be strictly positive, got %r"%(res,))

		self.res    = res    #....... grid resolution
		self.pixel  = res*res #...... area covered by one node
		self.l      = l
		self.w      = w
		self.Ox     = -round(l/2, 0)
		self.Oy     = -round(w/2, 0)
		self.nx     = int(round(l/res, 0)) + 1
		#number of nodes along y follows the width, not the length
		self.ny     = int(round(w/res, 0)) + 1
		self.node   = [[Node(x=(self.Ox + i*self.res), y=(self.Oy + j*self.res))
		                for j in range(self.ny)]
		               for i in range(self.nx)]

	def set(self, atomlist, vdw):
		"""Fill the grid with atom vdW radius.

		atomlist : iterable of atoms exposing .name, .x and .y
		vdw      : mapping atom name -> vdW diameter; names missing from
		           the mapping have a null radius and are skipped

		Every free node lying within one radius of an atom center is
		marked as occupied. Returns the area newly covered by this call
		(number of newly occupied nodes * node area)."""
		npix = 0.0
		for atom in atomlist:
			radius = vdw.get(atom.name, 0.0) / 2.0
			#an atom without radius covers no area
			if radius <= 0.0 : continue

			#index window enclosing the disc : one extra node on each side,
			#then clipped to the grid. Sized on the radius so that large
			#atoms are not truncated.
			xmin = max(int(round((atom.x - radius - self.Ox) / self.res, 0)) - 1, 0)
			xmax = min(int(round((atom.x + radius - self.Ox) / self.res, 0)) + 1, self.nx)
			ymin = max(int(round((atom.y - radius - self.Oy) / self.res, 0)) - 1, 0)
			ymax = min(int(round((atom.y + radius - self.Oy) / self.res, 0)) + 1, self.ny)

			for x in range(xmin, xmax):
				for y in range(ymin, ymax):
					node = self.node[x][y]
					if node.occupancy == 0 and node.getdist(atom) <= radius :
						node.occupancy = 1
						npix += 1.0

		return npix * self.pixel

	def unset(self):
		"""Reset grid : mark every node as not occupied"""
		for column in self.node :
			for node in column :
				node.occupancy = 0

	def getarea(self):
		"""Compute area of occupied nodes.

		Prints the detail of the computation and returns the area
		(number of occupied nodes * node area)."""
		npix = 0.0
		for column in self.node :
			for node in column :
				if node.occupancy != 0 : npix += 1.0

		print("%d x %f = %f "%(npix, self.pixel, npix*self.pixel))

		return npix*self.pixel


class TrajTracking : #==========================================================#
	"""A class defining metrics to compute through a trajectory on an atom selection"""

	def __init__(self, sel=None, nframe=0) :
		"""Object initialization.

		sel    : atom selection to track (must expose .atoms)
		nframe : number of frames that will be analyzed"""
		self.sel   = sel
		self.area  = [0.000] * nframe #area per residue, one value per frame
		self.avg   = [0.000, 0.000]   #[average area, standard deviation]


	def setstep(self, step=0, grid=None, vdw=None, nres=1) :
		"""Compute the area per residue of the selection for one frame :
			1 - Atoms are projected in the XY plane
			2 - Area is computed from the grid nodes covered
			    by the vdW disc of each atom
			3 - vdw gives the vdW diameter of each atom name

		The covered area is divided by nres and stored in
		self.area[step]; the grid is reset afterwards."""

		self.area[step] = grid.set(self.sel.atoms, vdw) / float(nres)
		grid.unset()


	def getaverage(self):
		"""Compute the average over the frames of the area per residue and
		its associated (population) standard deviation.

		Results are stored in self.avg as [average, standard deviation];
		both are 0.0 when no frame has been recorded."""

		nframe = len(self.area)
		if nframe == 0 :
			#nothing to average : avoid the division by zero
			self.avg[0] = 0.0
			self.avg[1] = 0.0
			return

		mean     = sum(self.area, 0.0) / nframe
		variance = sum([(value - mean) * (value - mean) for value in self.area], 0.0) / nframe

		self.avg[0] = mean
		self.avg[1] = sqrt(variance)



#===============================================================================#
#                                   MAIN                                        #
#===============================================================================#
#Configure the script ...........................................................
conf = ScriptConfig()
conf.setopt(sys.argv)

#Load system structure and trajectory ...........................................
print("%s Loading trajectory %s"%(NAME, '.'*25))
SysDef, SysTraj = conf.loadtraj()

#Get protein sequence ...........................................................
conf.getseq(SysDef)
print("%s sequence is including %d residues :\n%s %s\n%s"%(NAME, conf.nres, NAME, conf.seqstring, NAME))


#Select residues of interest ....................................................
#headerlist gives the selections written to the output file, in column order
print("%s Atom Selection "%(NAME))
headerlist = ['layer1','layer2']
atomselect = makedict(
	layer1 = SysDef.rsel(lambda res: res.name == "POUN" and res.id in IdGroup['layer1']),
	layer2 = SysDef.rsel(lambda res: res.name == "POUN" and res.id in IdGroup['layer2'])
	#head1  = SysDef.rsel(lambda res: res.name == "POUN" and res.id in IdGroup['layer1']).asel(lambda atm : atm.name not in AtomGroup['glycerol'] and atm.name not in AtomGroup['ch2'] and atm.name not in AtomGroup['ch3']),
	#head2  = SysDef.rsel(lambda res: res.name == "POUN" and res.id in IdGroup['layer2']).asel(lambda atm : atm.name not in AtomGroup['glycerol'] and atm.name not in AtomGroup['ch2'] and atm.name not in AtomGroup['ch3'])
	#gly1   = SysDef.rsel(lambda res: res.name == "POUN" and res.id in IdGroup['layer1']).asel(lambda atm : atm.name in AtomGroup['glycerol']),
	#gly2   = SysDef.rsel(lambda res: res.name == "POUN" and res.id in IdGroup['layer2']).asel(lambda atm : atm.name in AtomGroup['glycerol']),
	#acyl1  = SysDef.rsel(lambda res: res.name == "POUN" and res.id in IdGroup['layer1']).asel(lambda atm : atm.name in AtomGroup['ch2'] or atm.name in AtomGroup['ch3']),
	#acyl2  = SysDef.rsel(lambda res: res.name == "POUN" and res.id in IdGroup['layer2']).asel(lambda atm : atm.name in AtomGroup['ch2'] or atm.name in AtomGroup['ch3'])
)

for sel in headerlist :
	natom = len(atomselect[sel].atoms)
	print("%s %-9s = %9d atoms "%(NAME, sel, natom))

#SysRef = SysDef.rsel(lambda res: res.name == "POUN").asel(lambda atm : atm in AtomGroup['glycerol'])


#VARIABLE INITIALIZATION ========================================================
#one tracker per selection, sized for the number of frames to analyze
tracked = {}
for selected in atomselect :
	tracked[selected] = TrajTracking(sel=atomselect[selected], nframe=conf.nframe)

#Set grid : only the length and the width of the box are used
mygrid = Grid(l=conf.boxsize[0], w=conf.boxsize[1], res=conf.res)
print("%s grid lenght     = %.3f A"%(NAME, mygrid.l))
print("%s grid width      = %.3f A"%(NAME, mygrid.w))
print("%s grid resolution = %.3f A/point"%(NAME, mygrid.res))


#TRAJECTORY ANALYSIS ============================================================
print("%s Trajectory analysis %s"%(NAME, '.'*24))
SysDef.saveframe()
#every selection is normalized by the number of residue ids of layer1
nlipid = len(IdGroup['layer1'])
for i in range(conf.nframe) :
	frame = i*conf.modulo

	SysDef.getframe(SysTraj[frame])
	for tracker in tracked.values() :
		tracker.setstep(step=i, grid=mygrid, vdw=vdWdiam, nres=nlipid)

	print("%s Frame %06d : completed"%(NAME, frame))


#avg AREA COMPUTATION ===========================================================
print("%s area/lipid computation %s"%(NAME, '.'*27))
for tracker in tracked.values() :
	tracker.getaverage()

# OUTPUT ========================================================================
#Writing text file ..............................................................
print("%s Writing Output Files %s"%(NAME, '.'*39))
print("%s output directory : <%s>"%(NAME, conf.outdir))
#the file holds the area per lipid of each selection, one line per frame
print("%s area per lipid reported in <%s.txt>"%(NAME, conf.filetemp))
outfile = open(os.path.join(conf.outdir, "%s.txt"%conf.filetemp), 'w')
try :
	#header : time, then one column per selection of headerlist
	outfile.write("#%8s "%('time'))
	for name in headerlist :
		outfile.write("%9s "%name)
	#values : time (frame index * time step) followed by the area of each selection
	for i in range(conf.nframe):
		outfile.write("\n%9.6f "%(i*conf.tstep))
		for name in headerlist :
			outfile.write("%9.3f "%(tracked[name].area[i]))
	outfile.write("\n")
finally :
	#release the file even if a write fails
	outfile.close()

#Writing gnuplot file ...........................................................
print("%s gnuplot file : <%s.gp>\n"%(NAME, conf.filetemp))

#One panel per metric : (title, origin in the multiplot, (upper, lower) selections).
#A panel is written only when both selections are columns of the data file
#(i.e. listed in headerlist), and the plotted columns are derived from their
#rank in headerlist : column 1 is the time, selection k is column k+2.
panels = (
	('area per lipid'     , '0, 0'      , ('layer1', 'layer2')),
	('area per head group', '0.75, 0.0' , ('head1' , 'head2' )),
	('area per glycerol'  , '0.0, 0.75' , ('gly1'  , 'gly2'  )),
	('area per acyl chain', '0.75, 0.75', ('acyl1' , 'acyl2' )),
)

gpfile = open(os.path.join(conf.outdir, "%s.gp"%conf.filetemp), 'w')
try :
	gpfile.write("set terminal postscript solid colour eps\n")
	gpfile.write("set key below\n")
	gpfile.write("set xzeroaxis 1\n")
	gpfile.write("set grid ytics\n")
	gpfile.write("\nset output '%s.eps'\n"%conf.filetemp)

	#Plot area per lipid ....................................................
	gpfile.write("set xlabel 'time (ns)'\n")
	gpfile.write("set ylabel 'area (angstrom^2)'\n")
	gpfile.write("set origin 0, 0\n")
	gpfile.write("set size 1.5, 1.5\n")
	gpfile.write("set multiplot\n")
	for title, origin, (upper, lower) in panels :
		#skip the panels whose data have not been computed
		if upper not in headerlist or lower not in headerlist :
			continue
		gpfile.write("set title '%s'\n"%title)
		gpfile.write("set origin %s\n"%origin)
		gpfile.write("set size 0.75, 0.75\n")
		gpfile.write("\nplot\\\n")
		gpfile.write("   '%s.txt' using 1:%d title '%s' with lines,\\\n"%(conf.filetemp, headerlist.index(upper) + 2, 'upper'))
		gpfile.write("   '%s.txt' using 1:%d title '%s' with lines\n"%(conf.filetemp, headerlist.index(lower) + 2, 'lower'))
	gpfile.write("unset multiplot\n")
finally :
	#release the file even if a write fails
	gpfile.close()
