"""AtomGroup Hierarchy of MDTools

RCS: $Id: md_AtomGroup.py,v 0.70 1996/11/16 21:30:21 jim Exp $

Class Hierarchy:
   AtomGroup -> ASel
             -> Residue
             -> ResidueGroup -> RSel
                             -> Segment
                             -> SegmentGroup -> Molecule
"""

# RCS identification string for this revision of the module.
_RCS = "$Id: md_AtomGroup.py,v 0.70 1996/11/16 21:30:21 jim Exp $"

# $Log: md_AtomGroup.py,v $
# Revision 0.70  1996/11/16 21:30:21  jim
# changed Numeric.Core to Numeric
#
# Revision 0.69  1996/10/12 20:10:54  jim
# added id field to Segment
#
# Revision 0.68  1996/10/12 18:52:09  jim
# update for NumPy1.0a4
#
# Revision 0.67  1996/08/28 20:51:23  jim
# Added masses(), charges(), and coordinate() methods.
#
# Revision 0.66  1996/08/28 19:27:59  jim
# Switched frames from lists of tuples to arrays.
#
# Revision 0.65  1996/05/24 01:20:46  jim
# Split into sub-modules, improved version reporting.
#

# Import-time banner, e.g. "- AtomGroup 0.70 Exp (1996/11/16)".
# The slices cut the RCS keyword labels and "$" delimiters out of each
# expanded keyword (and, for $Date$, the time of day), so the offsets
# depend on RCS's exact keyword format.
print "- AtomGroup "+"$Revision: 0.70 $"[11:-1]+"$State: Exp $"[8:-1]+"("+"$Date: 1996/11/16 21:30:21 $"[7:-11]+")"

import string
import math
import struct
import copy
import types
import tempfile
import os
import sys
import time
from Numeric import array, zeros

from md_local import pdbview
from md_Constants import *
from md_HomoCoord import *
from md_Trans import *

#
# AtomGroup class hierarchy:
#                                        AtomGroup -------------
#                                         |     |              |
#                                    Residue   ResidueGroup   ASel
#                                              |    |     |
#                                        Segment  RSel    SegmentGroup
#                                                              |
#                                                           Molecule
#

class AtomGroup:
	"""A group of atoms.

Data: atoms, [frames]

Methods:
   g = AtomGroup()
   g.atoms.append(a)
   g.masses() - array of atom masses
   g.tmass() - total mass
   g.charges() - array of atom charges
   g.tcharge() - total charge
   g.cgeom() - center of geometry
   g.cmass() - center of mass
   g.rgyration() - mass-weighted radius of gyration about cmass()
   g.saveframe([key]) - store coordinates() in the g.frames dictionary
   g.loadframe([key]) - restore coordinates stored under key
   g.delframe([key]) - discard coordinates stored under key
   frame = g.coordinates() - return array with one (x,y,z) row per atom
   g.putframe(frame) - copy coordinates into an existing array
   g.getframe(frame) - set coordinates of atom i from frame[i]
   g.getmolframe(frame) - set coordinates of each atom a from frame[a.id-1],
      i.e. from a frame covering the whole molecule (e.g. mol.coordinates())
   g.asel(func) - return atom selection based on filter function

See also: Molecule, ASel
"""
	def __init__(self):
		self.atoms = []
	def _attrlist(self,name):
		# The named attribute of every atom, in atom order.
		values = []
		for atom in self.atoms:
			values.append(getattr(atom,name))
		return values
	def _attrsum(self,name):
		# Left-to-right running sum of the named attribute, starting at 0.
		total = 0.
		for atom in self.atoms:
			total = total + getattr(atom,name)
		return total
	def _xyzlist(self):
		# One (x,y,z) tuple per atom, in atom order.
		rows = []
		for atom in self.atoms:
			rows.append((atom.x,atom.y,atom.z))
		return rows
	def masses(self):
		return array(self._attrlist('mass'))
	def tmass(self):
		return self._attrsum('mass')
	def charges(self):
		return array(self._attrlist('charge'))
	def tcharge(self):
		return self._attrsum('charge')
	def cgeom(self):
		# Accumulate on a Vector(0,0,0), then divide by the accumulated W.
		acc = Vector(0,0,0)
		for atom in self.atoms:
			acc = acc + atom
		return acc / acc.W
	def cmass(self):
		# Same as cgeom, but each atom is weighted by its mass.
		acc = Vector(0,0,0)
		for atom in self.atoms:
			acc = acc + atom.mass*atom
		return acc / acc.W
	def rgyration(self):
		center = self.cmass()
		weighted = 0.
		for atom in self.atoms:
			weighted = weighted + atom.mass*distsq(atom,center)
		return math.sqrt( weighted / self.tmass() )
	def _checkframes(self):
		# loadframe/delframe refuse to run until some frame has been saved.
		if hasattr(self,'frames') and len(self.frames):
			return
		raise "no frames saved internally"
	def saveframe(self,key=None):
		if not hasattr(self,'frames'):
			self.frames = {}
		self.frames[key] = self.coordinates()
	def loadframe(self,key=None):
		self._checkframes()
		self.getframe(self.frames[key])
	def delframe(self,key=None):
		self._checkframes()
		del self.frames[key]
	def getframe(self,frame):
		for index in range(len(self.atoms)):
			atom = self.atoms[index]
			(atom.x,atom.y,atom.z) = tuple(frame[index])
	def getmolframe(self,frame):
		# Rows are addressed by atom id (1-based), not by position in g.atoms.
		for atom in self.atoms:
			(atom.x,atom.y,atom.z) = tuple(frame[atom.id-1])
	def putframe(self,frame):
		frame[:,:] = self._xyzlist()
	def coordinates(self):
		return array(self._xyzlist())
	def asel(self,func):
		return ASel(self,func)
	def __repr__(self):
		count = repr(len(self.atoms))
		return '< ' + self.__class__.__name__ + ' with ' + count + ' atoms >'

class Residue(AtomGroup):
	"""A group of atoms forming one residue, with extra information.

Data: type, name, id, segment, prev, next

Methods:
   r = Residue()
   r.buildrefs() - set a.residue = r for every atom (done by Molecule)
   r.delrefs() - set a.residue = None to allow deletion (done by Molecule)
   r[name] - return the first atom with that name (like a dictionary)
   r.rotate(angle,[units]) - apply a Trans centered on CA with axis CB
      to every atom whose name is not in 'backbone'
   r.phipsi([units]) - returns (phi,psi); an angle that cannot be
      computed (no neighbor residue, missing atom, ...) is None

See also: Atom, Molecule, 'angles'
"""
	def __init__(self):
		AtomGroup.__init__(self)
		self.id = 0
		self.name = '???'
		self.type = '???'
		self.segment = None
		self.prev = None
		self.next = None
	def _tagatoms(self,owner):
		# Point every atom's back-reference at owner.
		for atom in self.atoms:
			atom.residue = owner
	def buildrefs(self):
		self._tagatoms(self)
	def delrefs(self):
		self._tagatoms(None)
	def __getitem__(self,name):
		for candidate in self.atoms:
			if candidate.name == name:
				return candidate
		raise "No such atom."
	def rotate(self,angle,units=angledefault):
		pivot = self['CA']
		direction = self['CB']
		move = Trans(center=pivot,axis=direction,angle=angle,units=units)
		for atom in self.atoms:
			if atom.name in backbone:
				continue
			move(atom)
	def phipsi(self,units=angledefault):
		# Any failure (e.g. prev/next is None, atom missing) yields None.
		try:
			phi = angle(self.prev['C'],self['N'],self['CA'],self['C'],units)
		except:
			phi = None
		try:
			psi = angle(self['N'],self['CA'],self['C'],self.next['N'],units)
		except:
			psi = None
		return (phi,psi)
	def __repr__(self):
		count = repr(len(self.atoms))
		return '< Residue ' + self.name + ' with ' + count + ' atoms >'

class ASel(AtomGroup):
	"""A group of atoms chosen from another group by a filter function.

Methods:
   s = ASel(base,func) - s.atoms holds the atoms of base.atoms that func
      accepts, in their original order

See also: RSel
"""
	def __init__(self,base,func):
		AtomGroup.__init__(self)
		source = base.atoms
		self.atoms = filter(func,source)

class ResidueGroup(AtomGroup):
	"""A group of residues.

Data: residues

Methods:
   g = ResidueGroup()
   g.buildlists() - regenerate g.atoms from the atoms of g.residues
   g.phipsi([units]) - returns list of (phi,psi), one per residue
   g.rsel(func) - returns residue selection based on filter function

See also: RSel
"""
	def __init__(self):
		AtomGroup.__init__(self)
		self.residues = []
	def buildlists(self):
		# Empty and refill the existing list object instead of rebinding it.
		del self.atoms[:]
		for res in self.residues:
			self.atoms.extend(res.atoms)
	def phipsi(self,units=angledefault):
		angles = []
		for res in self.residues:
			angles.append(res.phipsi(units))
		return angles
	def rsel(self,func):
		return RSel(self,func)
	def __repr__(self):
		nres = repr(len(self.residues))
		natoms = repr(len(self.atoms))
		return '< ' + self.__class__.__name__ + ' with ' + nres \
			+ ' residues, and ' + natoms + ' atoms >'

class RSel(ResidueGroup):
	"""A group of residues chosen from another residue group by a filter function.

Methods:
   s = RSel(base,func) - s.residues holds the residues of base.residues
      that func accepts; s.atoms is rebuilt from them

See also: ASel
"""
	def __init__(self,base,func):
		ResidueGroup.__init__(self)
		chosen = filter(func,base.residues)
		self.residues = chosen
		# derive the atom list from the chosen residues
		self.buildlists()

class Segment(ResidueGroup):
	"""A group of residues with extra information.

Data: name, id, molecule

Methods:
   s = Segment()
   s.buildrefs() - assigns segment for residues and links each residue to
      its neighbors through prev/next (done by Molecule)
   s.delrefs() - removes those references to allow deletion (done by Molecule)

See also: Residue, Molecule
"""
	def __init__(self):
		ResidueGroup.__init__(self)
		self.name = '???'
		self.id = 0
		# Owning molecule; Molecule.buildrefs() sets it.  It must be an
		# attribute so an unattached Segment still has s.molecule.
		self.molecule = None
	def buildrefs(self):
		for r in self.residues:
			r.segment = self
			r.buildrefs()
		# chain consecutive residues together
		for i in range(1,len(self.residues)):
			self.residues[i-1].next = self.residues[i]
			self.residues[i].prev = self.residues[i-1]
	def delrefs(self):
		for r in self.residues:
			r.segment = None
			r.delrefs()
		# break the prev/next chain so residues can be freed
		for i in range(1,len(self.residues)):
			self.residues[i-1].next = None
			self.residues[i].prev = None
	def __repr__(self):
		return '< Segment '+self.name+' with '\
			+repr(len(self.residues))+' residues, and '\
			+repr(len(self.atoms))+' atoms >'

class SegmentGroup(ResidueGroup):
	"""A group of segments.

Data: segments

Methods:
   g = SegmentGroup()
   g.buildlists() - rebuild each segment, then regenerate g.residues and
      g.atoms from the segments
"""
	def __init__(self):
		ResidueGroup.__init__(self)
		self.segments = []
	def buildlists(self):
		# Refill the existing residues list in place, segment by segment.
		del self.residues[:]
		for seg in self.segments:
			seg.buildlists()
			self.residues.extend(seg.residues)
		ResidueGroup.buildlists(self)
	def __repr__(self):
		nseg = repr(len(self.segments))
		nres = repr(len(self.residues))
		natoms = repr(len(self.atoms))
		return '< ' + self.__class__.__name__ + ' with ' + nseg \
			+ ' segments, ' + nres + ' residues, and ' + natoms + ' atoms >'

def _sround(x,n):
	"""Return x as a fixed-point string with n digits after the decimal point.

	Used for the numeric columns of pdb output.  For n >= 1 the value is
	formatted with '%.*f', so the result always has exactly n decimals and
	never uses exponent notation (str(round(1e20,3)) would be '1e+20').
	'%.*f' rounds the exact binary value, so a tie that is exactly
	representable (e.g. 0.0625 at n=3) may round to even.
	For n <= 0 the value is rounded to n places and shown with one
	decimal, e.g. _sround(35.,-1) == '40.0'.
	"""
	if n <= 0:
		return '%.1f' % round(x,n)
	return '%.*f' % (n,x)

def _Ftoi(s):
	"""Convert a fixed-width text field to an int (surrounding blanks ignored)."""
	return int(s,10)

def _Ftof(s):
	"""Convert a fixed-width text field to a float (surrounding blanks ignored)."""
	return float(s)

class Molecule(SegmentGroup):
	"""Complete interface for pdb/psf molecule files.

Data: pdbfile, psffile, pdbremarks, psfremarks
      _bonds, _angles, _dihedrals, _impropers, _donors, _acceptors
      optionally: bonds, angles, dihedrals, impropers, donors, acceptors

Methods:
   m = Molecule([pdb],[psf]) - read molecule from file(s)
   m.buildrefs() - assigns molecule for segments (done on creation)
   m.delrefs() - removes references to allow deletion (must be done by user)
   m.buildstructure() - adds structure lists to molecule and atoms
   m.writepdb([file]) - write pdb to file (pdbfile by default)
   m.view() - launch a pdb viewer with the current coordinates

See also: Segment, pdbdisplayfunction
"""
	def __init__(self,pdb=None,psf=None):
		SegmentGroup.__init__(self)
		self.pdbfile = pdb
		self.psffile = psf
		self.pdbremarks = []
		self.psfremarks = []
		self._bonds = []
		self._angles = []
		self._dihedrals = []
		self._impropers = []
		self._donors = []
		self._acceptors = []
		pdb = self.pdbfile
		psf = self.psffile
		if not ( pdb or psf ):
			raise "No data files specified."
		if pdb:
			pdbf = open(self.pdbfile,'r')
			pdbrec = pdbf.readline()
			while len(pdbrec) and pdbrec[0:6] == 'REMARK':
				self.pdbremarks.append(string.strip(pdbrec))
				print self.pdbremarks[-1]
				pdbrec = pdbf.readline()
		if psf:
			psff = open(self.psffile,'r')
			psfline = psff.readline()
			psfrec = string.split(psfline)
			while len(psfline) and not (len(psfrec) > 1 and psfrec[1] == '!NTITLE'):
				psfline = psff.readline()
				psfrec = string.split(psfline)
			nrecs = string.atoi(psfrec[0])
			for i in range(0,nrecs):
				psfrec = psff.readline()
				self.psfremarks.append(string.strip(psfrec))
				print self.psfremarks[-1]
			psfline = psff.readline()
			psfrec = string.split(psfline)
			while len(psfline) and not (len(psfrec) > 1 and psfrec[1] == '!NATOM'):
				psfline = psff.readline()
				psfrec = string.split(psfline)
			nrecs = string.atoi(psfrec[0])
		moretogo = 0
		if pdb:
			if len(pdbrec) and pdbrec[0:6] in ('ATOM  ','HETATM'): moretogo = 1
		if psf:
			psfrec = string.split(psff.readline())
			if nrecs > len(self.atoms): moretogo = 1
		curseg = None
		curres = None
		numread = 0
		while moretogo:
			moretogo = 0
			if psf:
				if (not curseg) or psfrec[1] != curseg.name:
					curseg = Segment()
					self.segments.append(curseg)
					curseg.name = psfrec[1]
					curseg.id = len(self.segments)
				if (not curres) or string.atoi(psfrec[2]) != curres.id:
					curres = Residue()
					curseg.residues.append(curres)
					curres.id = string.atoi(psfrec[2])
					curres.name = psfrec[3]
					curres.type = curres.name
			else:
				if (not curseg) or string.strip(pdbrec[67:]) != curseg.name:
					curseg = Segment()
					self.segments.append(curseg)
					curseg.name = string.strip(pdbrec[67:])
					curseg.id = len(self.segments)
				if (not curres) or _Ftoi(pdbrec[22:26]) != curres.id:
					curres = Residue()
					curseg.residues.append(curres)
					curres.id = _Ftoi(pdbrec[22:26])
					curres.name = string.strip(pdbrec[17:21])
					curres.type = curres.name
			curatom = Atom()
			curres.atoms.append(curatom)
			numread = numread + 1
			if pdb:
				curatom.name = string.strip(pdbrec[12:16])
				curatom.type = curatom.name
				curatom.id = _Ftoi(pdbrec[6:11])
				curatom.x = _Ftof(pdbrec[30:38])
				curatom.y = _Ftof(pdbrec[38:46])
				curatom.z = _Ftof(pdbrec[46:54])
				curatom.q = _Ftof(pdbrec[54:60])
				curatom.b = _Ftof(pdbrec[60:66])
				pdbrec = pdbf.readline()
				if len(pdbrec) and pdbrec[0:6] in ('ATOM  ','HETATM'): moretogo = 1
			if psf:
				curatom.name = psfrec[4]
				curatom.type = psfrec[5]
				curatom.id = string.atoi(psfrec[0])
				curatom.mass = string.atof(psfrec[7])
				curatom.charge = string.atof(psfrec[6])
				psfrec = string.split(psff.readline())
				if nrecs > numread: moretogo = 1
		if pdb: pdbf.close()
		if psf:
			while len(psfline) and not (len(psfrec) > 1 and psfrec[1][0:6] == '!NBOND'):
				psfrec = string.split(psff.readline())
			nrecs = string.atoi(psfrec[0])
			while ( nrecs ):
				psfrec = string.split(psff.readline())
				while ( len(psfrec) ):
					self._bonds.append((string.atoi(psfrec[0]),string.atoi(psfrec[1])))
					nrecs = nrecs - 1
					psfrec = psfrec[2:]
			psfrec = string.split(psff.readline())
			while len(psfline) and not (len(psfrec) > 1 and psfrec[1][0:7] == '!NTHETA'):
				psfrec = string.split(psff.readline())
			nrecs = string.atoi(psfrec[0])
			while ( nrecs ):
				psfrec = string.split(psff.readline())
				while ( len(psfrec) ):
					self._angles.append((string.atoi(psfrec[0]),
						string.atoi(psfrec[1]),string.atoi(psfrec[2])))
					nrecs = nrecs - 1
					psfrec = psfrec[3:]
			psfrec = string.split(psff.readline())
			while len(psfline) and not (len(psfrec) > 1 and psfrec[1][0:5] == '!NPHI'):
				psfrec = string.split(psff.readline())
			nrecs = string.atoi(psfrec[0])
			while ( nrecs ):
				psfrec = string.split(psff.readline())
				while ( len(psfrec) ):
					self._dihedrals.append((string.atoi(psfrec[0]),string.atoi(psfrec[1]),
						string.atoi(psfrec[2]),string.atoi(psfrec[3])))
					nrecs = nrecs - 1
					psfrec = psfrec[4:]
			psfrec = string.split(psff.readline())
			while len(psfline) and not (len(psfrec) > 1 and psfrec[1][0:7] == '!NIMPHI'):
				psfrec = string.split(psff.readline())
			nrecs = string.atoi(psfrec[0])
			while ( nrecs ):
				psfrec = string.split(psff.readline())
				while ( len(psfrec) ):
					self._impropers.append((string.atoi(psfrec[0]),string.atoi(psfrec[1]),
						string.atoi(psfrec[2]),string.atoi(psfrec[3])))
					nrecs = nrecs - 1
					psfrec = psfrec[4:]
			psfrec = string.split(psff.readline())
			while len(psfline) and not (len(psfrec) > 1 and psfrec[1][0:5] == '!NDON'):
				psfrec = string.split(psff.readline())
			nrecs = string.atoi(psfrec[0])
			while ( nrecs ):
				psfrec = string.split(psff.readline())
				while ( len(psfrec) ):
					self._donors.append((string.atoi(psfrec[0]),string.atoi(psfrec[1])))
					nrecs = nrecs - 1
					psfrec = psfrec[2:]
			psfrec = string.split(psff.readline())
			while len(psfline) and not (len(psfrec) > 1 and psfrec[1][0:5] == '!NACC'):
				psfrec = string.split(psff.readline())
			nrecs = string.atoi(psfrec[0])
			while ( nrecs ):
				psfrec = string.split(psff.readline())
				while ( len(psfrec) ):
					self._acceptors.append((string.atoi(psfrec[0]),string.atoi(psfrec[1])))
					nrecs = nrecs - 1
					psfrec = psfrec[2:]
			psff.close()
		self.buildlists()
		self.buildrefs()
	def buildrefs(self):
		for s in self.segments:
			s.molecule = self
			s.buildrefs()
	def delrefs(self):
		for s in self.segments:
			s.molecule = None
			s.delrefs()
	def buildstructure(self):
		self.bonds = []
		self.angles = []
		self.dihedrals = []
		self.impropers = []
		self.donors = []
		self.acceptors = []
		for a in self.atoms:
			a.bonds = []
			a.angles = []
			a.dihedrals = []
			a.impropers = []
			a.donors = []
			a.acceptors = []
		def mapfunc(id,list=self.atoms): 
			a = list[id-1]
			if ( a.id != id ):
				raise "Atom list indexes corrupted."
			return a
		def mapatom(t,func=mapfunc):
			return tuple(map(func,t))
		for b in self._bonds:
			s = mapatom(b)
			self.bonds.append(s)
			for a in s: a.bonds.append(s)
		for b in self._angles:
			s = mapatom(b)
			self.angles.append(s)
			for a in s: a.angles.append(s)
		for b in self._dihedrals:
			s = mapatom(b)
			self.dihedrals.append(s)
			for a in s: a.dihedrals.append(s)
		for b in self._impropers:
			s = mapatom(b)
			self.impropers.append(s)
			for a in s: a.impropers.append(s)
		for b in self._donors:
			s = mapatom(b)
			self.donors.append(s)
			for a in s: a.donors.append(s)
		for b in self._acceptors:
			s = mapatom(b)
			self.acceptors.append(s)
			for a in s: a.acceptors.append(s)
	def writepdb(self,pdbfile=None):
		if not pdbfile:
			pdbfile = self.pdbfile
		if not pdbfile:
			raise "No pdb file specified."
		f = open(pdbfile,'w')
		for r in self.pdbremarks:
			if r[0:6] == 'REMARK' :
				f.write(r+'\n')
			else:
				f.write('REMARK '+r+'\n')
		for a in self.atoms:
			f.write('ATOM  ')
			f.write(string.rjust(str(a.id),5)+' ')
			if len(a.name) > 3:
				f.write(string.ljust(a.name,4)+' ')
			else:
				f.write(' '+string.ljust(a.name,3)+' ')
			f.write(string.ljust(a.residue.name,4))
			f.write(' '+string.rjust(str(a.residue.id),4)+'    ')
			f.write(string.rjust(_sround(a.x,3),8))
			f.write(string.rjust(_sround(a.y,3),8))
			f.write(string.rjust(_sround(a.z,3),8))
			f.write(string.rjust(_sround(a.q,2),6))
			f.write(string.rjust(_sround(a.b,2),6))
			f.write(string.rjust(a.residue.segment.name,10))
			f.write('\n')
		f.write('END\n')
		f.close()
	def view(self):
		d = pdbview()
		self.writepdb(d.load())
		d.show()
		d.free()
	def __repr__(self):
		return '< Molecule with '\
			+`len(self.segments)`+' segments, '\
			+`len(self.residues)`+' residues, and '\
			+`len(self.atoms)`+' atoms >'


