/*
 * Copyright (C) 2004-2006 by David J. Hardy.  All rights reserved.
 *
 * setup.c
 *
 * Setup and cleanup Force object.
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include "force/intdefn.h"
#undef DEBUG_WATCH
#include "debug/debug.h"


/*
 * Initialize a Force object from its parameter, domain, and selection
 * objects.
 *
 *   f       - force object to set up; its entire contents are zeroed first
 *   fprm    - force parameters; retained by pointer in f->param.  The
 *             cutoff, elec_cutoff, and vdw_cutoff fields may be updated
 *             here (unset cutoffs inherit fprm->cutoff, and fprm->cutoff
 *             is raised to the largest of the three)
 *   fdom    - domain description handed to force_setup_domain()
 *   fsel    - atom/bond selection handed to force_setup_selection()
 *   initpos - initial atom positions handed to force_setup_domain()
 *
 * Returns 0 on success or FORCE_FAIL on invalid input, allocation failure,
 * or failure of any setup stage.  Nothing allocated up to the point of
 * failure is released by this function.
 */
int force_initialize(Force *f, ForceParam *fprm, ForceDomain *fdom,
    ForceSelect *fsel, const MD_Dvec initpos[])
{
  int32 natoms;
  int32 forcetypes;

  /* reject missing objects before any of them is dereferenced */
  if (f == NULL || fprm == NULL || fdom == NULL || initpos == NULL) {
    ERROR("force_initialize() received a NULL argument");
    return FORCE_FAIL;
  }
  natoms = fprm->atom_len;
  forcetypes = fprm->forcetypes;

  /* reset all memory for force object */
  memset(f, 0, sizeof(Force));

  /* check validity of param object? */
  f->param = fprm;

  /* the per-atom arrays and the domain setup both need at least one atom */
  if (natoms <= 0) {
    ERROR("atom_len must be positive");
    return FORCE_FAIL;
  }

  if (force_setup_selection(f, fsel)) {
    ERROR("force_setup_selection()");
    return FORCE_FAIL;
  }

  /* allocate and reset poswrap and trpos arrays */
  f->poswrap = (MD_Dvec *) calloc(natoms, sizeof(MD_Dvec));
  if (f->poswrap == NULL) {
    ERROR("unable to allocate poswrap array");
    return FORCE_FAIL;
  }
  f->trpos = (MD_Dvec *) calloc(natoms, sizeof(MD_Dvec));
  if (f->trpos == NULL) {
    ERROR("unable to allocate trpos array");
    return FORCE_FAIL;
  }

  VEC(initpos[0]);
  if (force_setup_domain(f, fdom, initpos, f->abset_sel, f->abset_sel_len)) {
    ERROR("force_setup_domain()");
    return FORCE_FAIL;
  }

  if (forcetypes & FORCE_BRES) {
    if (force_setup_bres(f)) {
      ERROR("force_setup_bres()");
      return FORCE_FAIL;
    }
  }

  if (forcetypes & FORCE_PAIRWISE) {
    const int32 is_elec = (forcetypes & FORCE_ELEC) != 0;
    const int32 is_vdw = (forcetypes & FORCE_VDW) != 0;
    const int32 is_gridcells = (fprm->elecopts & FORCE_ELEC_GRIDCELLS)
      || (fprm->vdwopts & FORCE_VDW_GRIDCELLS);
    const int32 is_pairlists = (fprm->elecopts & FORCE_ELEC_PAIRLISTS)
      || (fprm->vdwopts & FORCE_VDW_PAIRLISTS);

    if (force_setup_exclusions(f)) {
      ERROR("force_setup_exclusions()");
      return FORCE_FAIL;
    }

    if (is_elec) {
      /* an unset electrostatic cutoff inherits the general cutoff */
      if (fprm->elec_cutoff == 0.0) fprm->elec_cutoff = fprm->cutoff;
      if (fprm->dielectric == 0.0) {
        ERROR("dielectric is zero");
        return FORCE_FAIL;
      }
      f->elec_const = fprm->elec_const / fprm->dielectric;
      f->elec_cutoff2 = fprm->elec_cutoff * fprm->elec_cutoff;
      if (f->elec_cutoff2 == 0.0) {
        ERROR("elec_cutoff is zero");
        return FORCE_FAIL;
      }
      f->inv_elec_cutoff2 = 1.0 / f->elec_cutoff2;
      f->ewald_grad_coef = 2.0 / sqrt(M_PI) * fprm->ewald_coef;
    }

    if (is_vdw) {
      /* an unset van der Waals cutoff inherits the general cutoff */
      if (fprm->vdw_cutoff == 0.0) fprm->vdw_cutoff = fprm->cutoff;
      f->vdw_cutoff2 = fprm->vdw_cutoff * fprm->vdw_cutoff;
      if (fprm->vdwopts & (FORCE_VDW_SWITCHED | FORCE_VDW_SWITCHBUCK
            | FORCE_VDW_SWITCHBUCKND | FORCE_VDW_SWITCHBUCKSAFE)) {
        /* switchdist below the cutoff keeps the denominator nonzero */
        if (fprm->switchdist >= fprm->vdw_cutoff) {
          ERROR("switchdist must be less than vdw_cutoff");
          return FORCE_FAIL;
        }
        f->switch_dist2 = fprm->switchdist * fprm->switchdist;
        f->inv_denom_switch = 1.0 /
          ((f->vdw_cutoff2 - f->switch_dist2) *
           (f->vdw_cutoff2 - f->switch_dist2) *
           (f->vdw_cutoff2 - f->switch_dist2));
      }
      if (force_setup_vdwparams(f)) {
        ERROR("force_setup_vdwparams()");
        return FORCE_FAIL;
      }
    }

    /* store the largest of elec_cutoff and vdw_cutoff into cutoff */
    if (fprm->cutoff < fprm->elec_cutoff) fprm->cutoff = fprm->elec_cutoff;
    if (fprm->cutoff < fprm->vdw_cutoff) fprm->cutoff = fprm->vdw_cutoff;

    if (is_gridcells || is_pairlists) {
      double mincellsize;
      double outer_cutoff;

      if (fprm->delta_dis < 0.0) {
        ERROR("delta_dis is negative");
        return FORCE_FAIL;
      }

      f->delta_dis2 = fprm->delta_dis * fprm->delta_dis;
      outer_cutoff = fprm->cutoff + 2*fprm->delta_dis;
      mincellsize = fprm->cutoff_ratio * outer_cutoff + fprm->cell_margin;

      FLT(fprm->delta_dis);
      FLT(outer_cutoff);
      FLT(fprm->cutoff_ratio);
      FLT(mincellsize);
      INT(fprm->shell_width);
      FLT(fprm->shell_width * mincellsize);

      if (fprm->shell_width < 1 || fprm->shell_width > FORCE_SHELLMAX) {
        ERROR("shell_width is out-of-range");
        return FORCE_FAIL;
      }
      else if (outer_cutoff > fprm->shell_width * mincellsize) {
        ERROR("cutoff extends beyond the designated shell of grid cells");
        return FORCE_FAIL;
      }

      f->outer_cutoff2 = outer_cutoff * outer_cutoff;

      /* must setup gridcells regardless */
      if (force_setup_gridcells(f, mincellsize, fprm->shell_width)) {
        ERROR("force_setup_gridcells()");
        return FORCE_FAIL;
      }

      /* pairlists are indexed over the combined A and B set selection */
      if (is_pairlists && force_setup_pairlists(f, outer_cutoff,
            f->abset_sel, f->abset_sel_len)) {
        ERROR("force_setup_pairlists()");
        return FORCE_FAIL;
      }

    }

  }

  return 0;
}


/*
 * Disabled fragment (never compiled because of the "#if 0" guard).
 *
 * It is a function body without a function header: it refers to "f", "p",
 * and "d", none of which are declared inside the fragment.  The text
 * validates parameter fields reached through "p", resolves the
 * electrostatic and van der Waals cutoffs, and then sets up bonded,
 * nonbonded, and boundary data plus some per-atom buffers.
 *
 * NOTE(review): presumably retained for reference from an earlier version
 * of the initialization routine -- confirm before deleting.
 */
#if 0
  int32 is_elec, is_vdw;
  int32 n;

  ASSERT(f != NULL);
  if (p == NULL || d == NULL) return FORCE_FAIL;

  /* check validity of options */
  if ((p->forcetypes & ~FORCE_MASK_FTYPES)
      || (p->sepforcetypes & ~FORCE_MASK_STYPES)
      || (p->energytypes & ~FORCE_MASK_ETYPES)
      || (p->exclpolicy <= 0 || p->exclpolicy >= FORCE_MARKER_EXCL)
      || (p->elecopts <= 0 || p->elecopts >= FORCE_MARKER_ELEC)
      || (p->vdwopts <= 0 || p->vdwopts >= FORCE_MARKER_VDW)
      || (p->bdresopts <= 0 || p->bdresopts >= FORCE_MARKER_BDRES)) {
    TEXT("check validity failed");
    return FORCE_FAIL;
  }

  /* check for positive number of atoms */
  if (p->atom_len <= 0) return FORCE_FAIL;

  /* check validity of nonbonded params */
  if (p->cutoff < 0.0 || p->elec_cutoff < 0.0 || p->vdw_cutoff < 0.0
      || p->switchdist < 0.0 || p->elec_const < 0.0 || p->dielectric < 0.0) {
    return FORCE_FAIL;
  }

  is_elec = (p->forcetypes & FORCE_ELEC)
    || (p->sepforcetypes & FORCE_ELEC)
    || (p->energytypes & (FORCE_ELEC | FORCE_EPOT));

  /* set FORCE_ELEC_NONE iff no electrostatics evaluation */
  if ((is_elec && p->elecopts == FORCE_ELEC_NONE)
      || (!is_elec && p->elecopts != FORCE_ELEC_NONE)) {
    return FORCE_FAIL;
  }

  if (is_elec && p->elec_cutoff == 0.0) {
    if (p->elecopts != FORCE_ELEC_DIRECT
        && p->elecopts != FORCE_ELEC_EXCL_TRUNC
        && p->elecopts != FORCE_ELEC_EXCL_SMOOTH) {
      if (p->cutoff == 0.0) {
        return FORCE_FAIL;
      }
      p->elec_cutoff = p->cutoff;
    }
  }

  is_vdw = (p->forcetypes & FORCE_VDW)
    || (p->sepforcetypes & FORCE_VDW)
    || (p->energytypes & FORCE_VDW);

  /* set FORCE_VDW_NONE iff no van der Waals evaluation */
  if ((is_vdw && p->vdwopts == FORCE_VDW_NONE)
      || (!is_vdw && p->vdwopts != FORCE_VDW_NONE)) {
    return FORCE_FAIL;
  }

  if (is_vdw && p->vdw_cutoff == 0.0) {
    if (p->vdwopts != FORCE_VDW_DIRECT
        && p->vdwopts != FORCE_VDW_EXCL_TRUNC
        && p->vdwopts != FORCE_VDW_EXCL_SWITCH) {
      if (p->cutoff == 0.0) {
        return FORCE_FAIL;
      }
      p->vdw_cutoff = p->cutoff;
    }
  }

  if (is_vdw && p->atomprm_len <= 0) {
    return FORCE_FAIL;
  }

  if (is_elec || is_vdw) {
    if (p->elec_cutoff > p->vdw_cutoff) p->cutoff = p->elec_cutoff;
    else p->cutoff = p->vdw_cutoff;
    ASSERT(p->cutoff > 0.0 
        || p->elecopts == FORCE_ELEC_DIRECT
        || p->elecopts == FORCE_ELEC_EXCL_TRUNC
        || p->elecopts == FORCE_ELEC_EXCL_SMOOTH
        || p->vdwopts == FORCE_VDW_DIRECT
        || p->vdwopts == FORCE_VDW_EXCL_TRUNC
        || p->vdwopts == FORCE_VDW_EXCL_SWITCH);
  }

  if (is_elec && (p->elec_const <= 0.0 || p->dielectric < 1.0)) {
    return FORCE_FAIL;
  }

  if (p->vdwopts == FORCE_VDW_SWITCH
      && (p->switchdist == 0.0 || p->switchdist >= p->vdw_cutoff)) {
    return FORCE_FAIL;
  }

  /* check validity of boundary constraint constants */
  if (p->radius1 < 0.0 || p->radius2 < 0.0
      || p->length1 < 0.0 || p->length2 < 0.0
      || p->exp1 < 0 || p->exp2 < 0) {
    return FORCE_FAIL;
  }

  /* save pointers */
  f->param = p;

  /* setup bonded data structures */
  if (((p->forcetypes | p->sepforcetypes | p->energytypes) & FORCE_BONDED)
      && force_setup_bonded(f)) {
    return FORCE_FAIL;
  }

  /* setup nonbonded data structures */
  if (is_elec || is_vdw) {
    f->fnonb = force_nonbonded_create(p, d);
    if (f->fnonb == NULL) return FORCE_FAIL;
  }

  /* setup boundary constraints */
  if (((p->forcetypes | p->sepforcetypes | p->energytypes) & FORCE_BDRES)
      && force_setup_boundary(f)) {
    return FORCE_FAIL;
  }

  /* allocate alternate storage buffer for force and atomic potentials */
  /* will we always need these here? */
  f->alt_f_buffer = (MD_Dvec *) calloc(f->param->atom_len, sizeof(MD_Dvec));
  if (f->alt_f_buffer == NULL) return FORCE_FAIL;
  f->alt_u_buffer = (double *) calloc(f->param->atom_len, sizeof(double));
  if (f->alt_u_buffer == NULL) return FORCE_FAIL;

  /* allocate and populate index identity mapping array */
  f->all_atom_index = (int32 *) calloc(f->param->atom_len, sizeof(int32));
  if (f->all_atom_index == NULL) return FORCE_FAIL;
  for (n = 0;  n < f->param->atom_len;  n++) {
    f->all_atom_index[n] = n;
  }
#endif


/*
 * Release everything that force_initialize() set up for a Force object.
 *
 *   f - force object to clean up; a NULL pointer is ignored
 *
 * The per-type cleanup routines run according to the force types and
 * options recorded in f->param.  If f->param is NULL, only the
 * unconditional parts (position arrays, selection, domain) are released.
 * The position array pointers are reset to NULL after being freed so that
 * a repeated call does not free them twice.
 */
void force_cleanup(Force *f)
{
  int32 forcetypes;

  if (f == NULL) return;

  /* without a parameter object no force type is known to be set up */
  forcetypes = (f->param != NULL ? f->param->forcetypes : 0);

  free(f->poswrap);
  f->poswrap = NULL;
  free(f->trpos);
  f->trpos = NULL;

  force_cleanup_selection(f);
  force_cleanup_domain(f);
  if (forcetypes & FORCE_BRES) force_cleanup_bres(f);
  if (forcetypes & FORCE_PAIRWISE) {
    const int32 is_gridcells = (f->param->elecopts & FORCE_ELEC_GRIDCELLS)
      || (f->param->vdwopts & FORCE_VDW_GRIDCELLS);
    const int32 is_pairlists = (f->param->elecopts & FORCE_ELEC_PAIRLISTS)
      || (f->param->vdwopts & FORCE_VDW_PAIRLISTS);
    force_cleanup_exclusions(f);
    if (forcetypes & FORCE_VDW) force_cleanup_vdwparams(f);
    /* grid cells are set up whenever either option is enabled */
    if (is_gridcells || is_pairlists) force_cleanup_gridcells(f);
    if (is_pairlists) force_cleanup_pairlists(f);
  }
}


int force_setup_selection(Force *f, ForceSelect *s)
{
  ForceParam *p = f->param;

  int32 *map;          /* temp mapping buffer that points to mapnb */

  int32 *sel = NULL;   /* points to an atom selection */
  int32 *asel = NULL;  /* points to the atom pairwise set A selection */
  int32 *bsel = NULL;  /* points to the atom pairwise set B selection */

  int32 sel_len;       /* length of sel array */
  int32 asel_len;      /* length of asel array */
  int32 bsel_len;      /* length of bsel array */

  const int32 natoms = p->atom_len;  /* number of atoms in system */
  int32 idmaplen;      /* length of ID map (might be > natoms) */

  int32 cnt;           /* counts number of atoms involved in bond selection */
  int32 i, j, k, m;    /* loop index variables */

  /* retain pointer to selection object */
  f->select = s;

  /* the idmap needs to be as long as the longest atom or bond array */
  idmaplen = natoms;
  if (idmaplen < p->bond_len) idmaplen = p->bond_len;
  if (idmaplen < p->angle_len) idmaplen = p->angle_len;
  if (idmaplen < p->dihed_len) idmaplen = p->dihed_len;
  if (idmaplen < p->impr_len) idmaplen = p->impr_len;

  /* allocate and initialize the idmap */
  f->idmap = (int32 *) malloc(idmaplen * sizeof(int32));
  if (f->idmap == NULL) return FORCE_FAIL;
  for (i = 0;  i < idmaplen;  i++) {
    f->idmap[i] = i;
  }

  /* buffer for potentials needs to be as long as idmap */
  f->e_buffer = (double *) calloc(idmaplen, sizeof(double));
  if (f->e_buffer == NULL) return FORCE_FAIL;

  /* buffer for forces needs to be as long as atom array */
  f->f_buffer = (MD_Dvec *) calloc(natoms, sizeof(MD_Dvec));
  if (f->f_buffer == NULL) return FORCE_FAIL;

  f->mapnb = (int32 *) calloc(natoms, sizeof(int32));
  if (f->mapnb == NULL) return FORCE_FAIL;

  /* use this as buffer space to tally atoms involved in some force */
  map = f->mapnb;

  /* build list of atoms involved in selected spring bonds */
  sel = (s ? s->bond_sel : f->idmap);
  sel_len = (s ? s->bond_sel_len : p->bond_len);
  cnt = 0;
  for (k = 0;  k < sel_len;  k++) {
    j = sel[k];
    for (m = 0;  m < 2;  m++) {
      i = p->bond[j].atom[m];
      if ((map[i] & FORCE_INDEX_BOND) == 0) {
        map[i] |= FORCE_INDEX_BOND;
        cnt++;
      }
    }
  }
  if (cnt == natoms) {
    f->atom_bond_sel = f->idmap;
    f->atom_bond_sel_len = natoms;
  }
  else if (cnt != 0) {
    f->atom_bond_sel = (int32 *) malloc(cnt * sizeof(int32));
    if (f->atom_bond_sel == NULL) return FORCE_FAIL;
    f->is_alloc_sel |= FORCE_INDEX_BOND;
    f->atom_bond_sel_len = cnt;
    cnt = 0;
    for (i = 0;  i < natoms;  i++) {
      if (map[i] & FORCE_INDEX_BOND) {
        f->atom_bond_sel[cnt] = i;
        cnt++;
      }
    }
    ASSERT(cnt == f->atom_bond_sel_len);
  }

  /* build list of atoms involved in selected angles */
  sel = (s ? s->angle_sel : f->idmap);
  sel_len = (s ? s->angle_sel_len : p->angle_len);
  cnt = 0;
  for (k = 0;  k < sel_len;  k++) {
    j = sel[k];
    for (m = 0;  m < 3;  m++) {
      i = p->angle[j].atom[m];
      if ((map[i] & FORCE_INDEX_ANGLE) == 0) {
        map[i] |= FORCE_INDEX_ANGLE;
        cnt++;
      }
    }
  }
  if (cnt == natoms) {
    f->atom_angle_sel = f->idmap;
    f->atom_angle_sel_len = natoms;
  }
  else if (cnt != 0) {
    f->atom_angle_sel = (int32 *) malloc(cnt * sizeof(int32));
    if (f->atom_angle_sel == NULL) return FORCE_FAIL;
    f->is_alloc_sel |= FORCE_INDEX_ANGLE;
    f->atom_angle_sel_len = cnt;
    cnt = 0;
    for (i = 0;  i < natoms;  i++) {
      if (map[i] & FORCE_INDEX_ANGLE) {
        f->atom_angle_sel[cnt] = i;
        cnt++;
      }
    }
    ASSERT(cnt == f->atom_angle_sel_len);
  }

  /* build list of atoms involved in selected dihedrals */
  sel = (s ? s->dihed_sel : f->idmap);
  sel_len = (s ? s->dihed_sel_len : p->dihed_len);
  cnt = 0;
  for (k = 0;  k < sel_len;  k++) {
    j = sel[k];
    for (m = 0;  m < 4;  m++) {
      i = p->dihed[j].atom[m];
      if ((map[i] & FORCE_INDEX_DIHED) == 0) {
        map[i] |= FORCE_INDEX_DIHED;
        cnt++;
      }
    }
  }
  if (cnt == natoms) {
    f->atom_dihed_sel = f->idmap;
    f->atom_dihed_sel_len = natoms;
  }
  else if (cnt != 0) {
    f->atom_dihed_sel = (int32 *) malloc(cnt * sizeof(int32));
    if (f->atom_dihed_sel == NULL) return FORCE_FAIL;
    f->is_alloc_sel |= FORCE_INDEX_DIHED;
    f->atom_dihed_sel_len = cnt;
    cnt = 0;
    for (i = 0;  i < natoms;  i++) {
      if (map[i] & FORCE_INDEX_DIHED) {
        f->atom_dihed_sel[cnt] = i;
        cnt++;
      }
    }
    ASSERT(cnt == f->atom_dihed_sel_len);
  }

  /* build list of atoms involved in selected impropers */
  sel = (s ? s->impr_sel : f->idmap);
  sel_len = (s ? s->impr_sel_len : p->impr_len);
  cnt = 0;
  for (k = 0;  k < sel_len;  k++) {
    j = sel[k];
    for (m = 0;  m < 4;  m++) {
      i = p->impr[j].atom[m];
      if ((map[i] & FORCE_INDEX_IMPR) == 0) {
        map[i] |= FORCE_INDEX_IMPR;
        cnt++;
      }
    }
  }
  if (cnt == natoms) {
    f->atom_impr_sel = f->idmap;
    f->atom_impr_sel_len = natoms;
  }
  else if (cnt != 0) {
    f->atom_impr_sel = (int32 *) malloc(cnt * sizeof(int32));
    if (f->atom_impr_sel == NULL) return FORCE_FAIL;
    f->is_alloc_sel |= FORCE_INDEX_IMPR;
    f->atom_impr_sel_len = cnt;
    cnt = 0;
    for (i = 0;  i < natoms;  i++) {
      if (map[i] & FORCE_INDEX_IMPR) {
        f->atom_impr_sel[cnt] = i;
        cnt++;
      }
    }
    ASSERT(cnt == f->atom_impr_sel_len);
  }

  /* build list of atoms involved in combined A and B sets */
  asel = (s ? s->aset_sel : NULL);
  asel_len = (s ? s->aset_sel_len : 0);
  bsel = (s ? s->bset_sel : NULL);
  bsel_len = (s ? s->bset_sel_len : 0);
  if (asel == bsel) {
    f->abset_sel = f->idmap;
    f->abset_sel_len = natoms;
    for (i = 0;  i < natoms;  i++) {
      map[i] |= (FORCE_INDEX_ASET | FORCE_INDEX_BSET);
    }
  }
  else {
    if (asel_len + bsel_len == natoms) {
      f->abset_sel = f->idmap;
      f->abset_sel_len = natoms;
    }
    else {
      cnt = asel_len + bsel_len;
      f->abset_sel = (int32 *) malloc(cnt * sizeof(int32));
      if (f->abset_sel == NULL) return FORCE_FAIL;
      f->is_alloc_sel |= (FORCE_INDEX_ASET | FORCE_INDEX_BSET);
      f->abset_sel_len = cnt;
    }
    cnt = 0;
    /* assume that A and B set indices are in sorted order */
    for (i = 0, j = 0;  i < asel_len && j < bsel_len; ) {
      if (asel[i] < bsel[j]) {
        k = asel[i];
        i++;
        map[k] |= FORCE_INDEX_ASET;
      }
      else {
        k = bsel[j];
        j++;
        map[k] |= FORCE_INDEX_BSET;
      }
      f->abset_sel[cnt] = k;
      cnt++;
    }
    for ( ;  i < asel_len;  i++) {
      k = asel[i];
      map[k] |= FORCE_INDEX_ASET;
      f->abset_sel[cnt] = k;
      cnt++;
    }
    for ( ;  j < bsel_len;  j++) {
      k = bsel[j];
      map[k] |= FORCE_INDEX_BSET;
      f->abset_sel[cnt] = k;
      cnt++;
    }
    ASSERT(cnt == f->abset_sel_len);
  }

  /* tally which atoms are involved in boundary restraints */
  sel = (s ? s->bres_sel : f->idmap);
  sel_len = (s ? s->bres_sel_len : natoms);
  cnt = 0;
  for (k = 0;  k < sel_len;  k++) {
    j = sel[k];
    map[j] |= FORCE_INDEX_BRES;
    cnt++;
  }

  /* use recorded tally to build list of atoms accumulated to total */
  cnt = 0;
  for (i = 0;  i < natoms;  i++) {
    if (map[i]) cnt++;
  }
  if (cnt == natoms) {
    f->total_sel = f->idmap;
    f->total_sel_len = natoms;
  }
  else if (cnt != 0) {
    f->total_sel = (int32 *) malloc(cnt * sizeof(int32));
    if (f->total_sel == NULL) return FORCE_FAIL;
    f->is_alloc_sel |= FORCE_INDEX_TOTAL;
    f->total_sel_len = cnt;
    cnt = 0;
    for (i = 0;  i < natoms;  i++) {
      if (map[i]) {
        f->total_sel[cnt] = i;
        cnt++;
      }
    }
    ASSERT(cnt == f->total_sel_len);
  }

  /* reset tally to identify only A and B set membership */
  for (i = 0;  i < natoms;  i++) {
    map[i] &= (FORCE_INDEX_ASET | FORCE_INDEX_BSET);
  }

  return 0;
}


/*
 * Release the arrays created by force_setup_selection().
 *
 *   f - force object; a NULL pointer is ignored
 *
 * A selection list is freed only if its flag in f->is_alloc_sel says it
 * was allocated; otherwise it aliases idmap and must not be freed on its
 * own.  Afterwards every pointer handled here is reset to NULL and the
 * allocation flags owned by the selection setup are cleared, so that a
 * repeated cleanup or a later setup cannot free the same memory twice.
 * The *_sel_len fields are left as they were.
 */
void force_cleanup_selection(Force *f)
{
  if (f == NULL) return;

  /* free memory space for alternate storage buffers */
  free(f->e_buffer);
  f->e_buffer = NULL;
  free(f->f_buffer);
  f->f_buffer = NULL;

  /* check flags to see if these arrays were allocated */
  if (f->is_alloc_sel & FORCE_INDEX_BOND) free(f->atom_bond_sel);
  if (f->is_alloc_sel & FORCE_INDEX_ANGLE) free(f->atom_angle_sel);
  if (f->is_alloc_sel & FORCE_INDEX_DIHED) free(f->atom_dihed_sel);
  if (f->is_alloc_sel & FORCE_INDEX_IMPR) free(f->atom_impr_sel);
  if (f->is_alloc_sel & (FORCE_INDEX_ASET | FORCE_INDEX_BSET)) {
    free(f->abset_sel);
  }
  if (f->is_alloc_sel & FORCE_INDEX_TOTAL) free(f->total_sel);

  /* drop the lists, whether they were allocated or aliased idmap */
  f->atom_bond_sel = NULL;
  f->atom_angle_sel = NULL;
  f->atom_dihed_sel = NULL;
  f->atom_impr_sel = NULL;
  f->abset_sel = NULL;
  f->total_sel = NULL;

  /* clear only the flags that this cleanup is responsible for */
  f->is_alloc_sel &= ~(FORCE_INDEX_BOND | FORCE_INDEX_ANGLE
      | FORCE_INDEX_DIHED | FORCE_INDEX_IMPR
      | FORCE_INDEX_ASET | FORCE_INDEX_BSET | FORCE_INDEX_TOTAL);

  /* free memory space for ID map and pairwise set membership */
  free(f->idmap);
  f->idmap = NULL;
  free(f->mapnb);
  f->mapnb = NULL;
}


/*
 * Return v scaled by the reciprocal of its Euclidean length.
 * The caller must pass a nonzero vector.
 */
static MD_Dvec setup_domain_unit(MD_Dvec v)
{
  const double inv_len = 1.0 / sqrt(v.x*v.x + v.y*v.y + v.z*v.z);

  v.x *= inv_len;
  v.y *= inv_len;
  v.z *= inv_len;
  return v;
}


/*
 * Find the componentwise minimum (lo) and maximum (hi) of trpos[] over the
 * atoms listed in sel[0..sel_len-1].  An empty selection yields lo and hi
 * both equal to the zero vector instead of reading sel[0].
 */
static void setup_domain_bounds(const MD_Dvec trpos[], const int32 sel[],
    int32 sel_len, MD_Dvec *lo, MD_Dvec *hi)
{
  const MD_Dvec zero = { 0.0, 0.0, 0.0 };
  int32 i, j;

  if (sel_len <= 0) {
    *lo = zero;
    *hi = zero;
    return;
  }
  *lo = trpos[sel[0]];
  *hi = trpos[sel[0]];
  for (i = 1;  i < sel_len;  i++) {
    j = sel[i];
    if (lo->x > trpos[j].x) lo->x = trpos[j].x;
    else if (hi->x < trpos[j].x) hi->x = trpos[j].x;
    if (lo->y > trpos[j].y) lo->y = trpos[j].y;
    else if (hi->y < trpos[j].y) hi->y = trpos[j].y;
    if (lo->z > trpos[j].z) lo->z = trpos[j].z;
    else if (hi->z < trpos[j].z) hi->z = trpos[j].z;
  }
}


/*
 * Wrap one scaled coordinate *s into [0,1) by whole periods, adding the
 * matching multiple of the cell basis vector v to the wrap vector w.
 */
static void setup_domain_wrap(double *s, MD_Dvec *w, const MD_Dvec *v)
{
  if (*s < 0.0) {
    do {
      *s += 1.0;
      w->x += v->x;
      w->y += v->y;
      w->z += v->z;
    } while (*s < 0.0);
  }
  else if (*s >= 1.0) {
    do {
      *s -= 1.0;
      w->x -= v->x;
      w->y -= v->y;
      w->z -= v->z;
    } while (*s >= 1.0);
  }
}


/*
 * Set up the domain geometry of a Force object and compute the initial
 * scaled positions and periodic wrapping vectors.
 *
 *   fobj    - force object; poswrap and trpos are assumed to be allocated
 *             and cleared
 *   fdom    - domain object (retained in fobj->domain); a zero basis
 *             vector marks that direction as nonperiodic
 *   pos     - atom positions, indexed by atom number
 *   sel     - indices of the atoms to process
 *   sel_len - length of sel
 *
 * Stores in fobj: the periodicity flags, normalized and full cell basis
 * vectors with their lengths, the orthogonality flag, the table of
 * periodic offsets, the cell volume, the scaled transformation rows
 * ta1..ta3, the center and lower corner, and for every selected atom the
 * scaled position (trpos) and wrapping vector (poswrap).
 *
 * Returns 0 on success, or FORCE_FAIL if the basis vectors are (nearly)
 * linearly dependent.
 */
int force_setup_domain(Force *fobj, ForceDomain *fdom,
    const MD_Dvec pos[], const int32 sel[], int32 sel_len)
{
  MD_Dvec *wrap = fobj->poswrap;
  MD_Dvec *trpos = fobj->trpos;

  MD_Dvec v1 = fdom->v1;  /* cell basis vectors */
  MD_Dvec v2 = fdom->v2;
  MD_Dvec v3 = fdom->v3;

  MD_Dvec nv1 = { 0.0, 0.0, 0.0 };  /* normalized cell basis vectors */
  MD_Dvec nv2 = { 0.0, 0.0, 0.0 };
  MD_Dvec nv3 = { 0.0, 0.0, 0.0 };

  MD_Dvec ta1 = { 0.0, 0.0, 0.0 };  /* row vectors of transformation */
  MD_Dvec ta2 = { 0.0, 0.0, 0.0 };
  MD_Dvec ta3 = { 0.0, 0.0, 0.0 };

  MD_Dvec lo = { 0.0, 0.0, 0.0 };   /* extent of transformed system */
  MD_Dvec hi = { 0.0, 0.0, 0.0 };

  double inv_lv1, inv_lv2, inv_lv3;
  double len;

  MD_Dvec trlow;  /* transformed lower corner */
  MD_Dvec off;    /* accumulate periodic offsets */

  int32 i, j, k, indx;

  /* retain pointer to domain object */
  fobj->domain = fdom;

  /*
   * For nonperiodic directions, how far apart should the domain walls
   * be as a multiple of the max distance between particles?
   */
  fobj->domain_len_frac = DOMAIN_LEN_FRAC;

  /*
   * For nonperiodic system, how often should domain be re-centered
   * with respect to atomic coordinates?
   */
  fobj->max_steps = MAX_STEPS;

  /*
   * Use ForceDomain object to determine periodicity.
   * Fill in any zero basis vectors (to indicate nonperiodic "direction")
   * with normalized, well-conditioned basis vectors spanning R^3.
   * Nonzero vectors are assumed to be linearly independent.  (?)
   */
  fobj->is_periodic = FORCE_NONPERIODIC;
  fobj->is_periodic |= !(v1.x == 0.0 && v1.y == 0.0 && v1.z == 0.0) ?
    FORCE_X_PERIODIC : FORCE_NONPERIODIC;
  fobj->is_periodic |= !(v2.x == 0.0 && v2.y == 0.0 && v2.z == 0.0) ?
    FORCE_Y_PERIODIC : FORCE_NONPERIODIC;
  fobj->is_periodic |= !(v3.x == 0.0 && v3.y == 0.0 && v3.z == 0.0) ?
    FORCE_Z_PERIODIC : FORCE_NONPERIODIC;

  if (fobj->is_periodic & FORCE_X_PERIODIC) {
    /* normalize the v1 direction */
    nv1 = setup_domain_unit(v1);
  }
  else {
    /*
     * need to choose a direction for nonperiodic domain
     * choose x-direction
     */
    nv1.x = 1.0;
  }

  if (fobj->is_periodic & FORCE_Y_PERIODIC) {
    /* normalize the v2 direction */
    nv2 = setup_domain_unit(v2);
  }
  else {
    /*
     * need to choose a direction for nonperiodic domain
     * choose either (0,0,1) cross nv1 or (1,0,0) cross nv1,
     * depending on which vector that nv1 is more orthogonal to
     *
     * the closer (0,0,1) dot nv1 is to zero,
     * then the closer that these are to being orthogonal
     *
     * note that for nv1==(1,0,0) we get nv2==(0,1,0)
     */
    if (fabs(nv1.z) < 0.9) {
      /* set nv2 to be (0,0,1) cross nv1 */
      nv2.x = -nv1.y;
      nv2.y = nv1.x;
    }
    else {
      /* set nv2 to be (1,0,0) cross nv1 */
      nv2.y = nv1.z;
      nv2.z = -nv1.y;
    }
    nv2 = setup_domain_unit(nv2);
  }

  if (fobj->is_periodic & FORCE_Z_PERIODIC) {
    /* normalize the v3 direction */
    nv3 = setup_domain_unit(v3);
  }
  else {
    /*
     * need to choose a direction for nonperiodic domain
     * choose nv1 cross nv2
     *
     * note that for nv1==(1,0,0) and nv2==(0,1,0) we get nv3==(0,0,1)
     */
    nv3.x = nv1.y * nv2.z - nv1.z * nv2.y;
    nv3.y = nv1.z * nv2.x - nv1.x * nv2.z;
    nv3.z = nv1.x * nv2.y - nv1.y * nv2.x;
    nv3 = setup_domain_unit(nv3);
  }

  /* save the normalized domain basis vectors */
  fobj->nv1 = nv1;
  fobj->nv2 = nv2;
  fobj->nv3 = nv3;
  VEC(nv1);
  VEC(nv2);
  VEC(nv3);

  /*
   * set is_orthogonal flag if basis vectors align with coordinate axes
   * (i.e. if transformation is diagonal)
   */
  fobj->is_orthogonal = (nv1.x != 0.0 && nv1.y == 0.0 && nv1.z == 0.0
      && nv2.x == 0.0 && nv2.y != 0.0 && nv2.z == 0.0
      && nv3.x == 0.0 && nv3.y == 0.0 && nv3.z != 0.0);

  /*
   * compute transformation using normalized basis vectors
   * this transforms the parallelepiped domain preserving lengths
   * of cell basis vectors (shears the system)
   */
  if (fobj->is_orthogonal) {
    /* each of these entries is either 1 or -1 */
    ta1.x = nv1.x;
    ta2.y = nv2.y;
    ta3.z = nv3.z;
  }
  else {
    /* row vectors of inverse transformation */
    MD_Dvec a, b, c;

    /* for computing determinant of inverse transformation matrix */
    double d1, d2, d3, det, inv_det;

    /* take transpose matrix for row vectors */
    a.x = nv1.x;
    b.x = nv1.y;
    c.x = nv1.z;
    a.y = nv2.x;
    b.y = nv2.y;
    c.y = nv2.z;
    a.z = nv3.x;
    b.z = nv3.y;
    c.z = nv3.z;

    /* compute determinant */
    d1 = b.y * c.z - b.z * c.y;
    d2 = b.z * c.x - b.x * c.z;
    d3 = b.x * c.y - b.y * c.x;
    det = a.x * d1 + a.y * d2 + a.z * d3;

    /* make sure cell basis vectors aren't (nearly) linearly dependent */
    if (det*det < TOL_CELL_BASIS) {
      ERROR("cell basis vectors are (nearly) linearly dependent");
      return FORCE_FAIL;
    }
    inv_det = 1.0 / det;

    /* compute transformation matrix */
    /* by directly inverting the 3x3 inverse transformation matrix */
    ta1.x = d1 * inv_det;
    ta1.y = (a.z * c.y - a.y * c.z) * inv_det;
    ta1.z = (a.y * b.z - a.z * b.y) * inv_det;
    ta2.x = d2 * inv_det;
    ta2.y = (a.x * c.z - a.z * c.x) * inv_det;
    ta2.z = (a.z * b.x - a.x * b.z) * inv_det;
    ta3.x = d3 * inv_det;
    ta3.y = (a.y * c.x - a.x * c.y) * inv_det;
    ta3.z = (a.x * b.y - a.y * b.x) * inv_det;
  }
  VEC(ta1);
  VEC(ta2);
  VEC(ta3);
  INT(fobj->is_orthogonal);

  /*
   * transform system coordinates aligning lengths with x, y, z axes
   *
   * (if not fully periodic, we will need to determine extent of system
   * to derive a length along nonperiodic cell directions)
   */
  if (fobj->is_orthogonal) {
    /* no transformation needed (identity matrix) so just copy */
    for (i = 0;  i < sel_len;  i++) {
      j = sel[i];
      trpos[j] = pos[j];
    }
  }
  else {
    /* transform atoms in system */
    for (i = 0;  i < sel_len;  i++) {
      j = sel[i];
      trpos[j].x = ta1.x*pos[j].x + ta1.y*pos[j].y + ta1.z*pos[j].z;
      trpos[j].y = ta2.x*pos[j].x + ta2.y*pos[j].y + ta2.z*pos[j].z;
      trpos[j].z = ta3.x*pos[j].x + ta3.y*pos[j].y + ta3.z*pos[j].z;
    }
  }
  VEC(trpos[0]);

  /*
   * determine lengths of cell basis vectors
   *
   * (must look at transformed system along nonperiodic cell directions
   * to determine those vector lengths)
   */
  if (!(fobj->is_periodic & FORCE_X_PERIODIC)
      || !(fobj->is_periodic & FORCE_Y_PERIODIC)
      || !(fobj->is_periodic & FORCE_Z_PERIODIC)) {
    setup_domain_bounds(trpos, sel, sel_len, &lo, &hi);
  }

  if (fobj->is_periodic & FORCE_X_PERIODIC) {
    fobj->lv1 = sqrt(v1.x*v1.x + v1.y*v1.y + v1.z*v1.z);
  }
  else {
    /* derive length from extent of system along the x-direction */
    len = (hi.x - lo.x) * fobj->domain_len_frac;
    v1.x = nv1.x * len;
    v1.y = nv1.y * len;
    v1.z = nv1.z * len;
    fobj->lv1 = len;
  }

  if (fobj->is_periodic & FORCE_Y_PERIODIC) {
    fobj->lv2 = sqrt(v2.x*v2.x + v2.y*v2.y + v2.z*v2.z);
  }
  else {
    /* derive length from extent of system along the y-direction */
    len = (hi.y - lo.y) * fobj->domain_len_frac;
    v2.x = nv2.x * len;
    v2.y = nv2.y * len;
    v2.z = nv2.z * len;
    fobj->lv2 = len;
  }

  if (fobj->is_periodic & FORCE_Z_PERIODIC) {
    fobj->lv3 = sqrt(v3.x*v3.x + v3.y*v3.y + v3.z*v3.z);
  }
  else {
    /* derive length from extent of system along the z-direction */
    len = (hi.z - lo.z) * fobj->domain_len_frac;
    v3.x = nv3.x * len;
    v3.y = nv3.y * len;
    v3.z = nv3.z * len;
    fobj->lv3 = len;
  }

  /* save cell basis vectors */
  fobj->v1 = v1;
  fobj->v2 = v2;
  fobj->v3 = v3;

  /* initialize offset table */
  for (k = -1;  k <= 1;  k++) {
    for (j = -1;  j <= 1;  j++) {
      for (i = -1;  i <= 1;  i++) {
        off.x = -i*v1.x + -j*v2.x + -k*v3.x;
        off.y = -i*v1.y + -j*v2.y + -k*v3.y;
        off.z = -i*v1.z + -j*v2.z + -k*v3.z;
        indx = OFFSET_INDEX(i, j, k);
        fobj->offset_table[indx] = off;
      }
    }
  }

  /* calculate volume of domain (only meaningful for full periodicity) */
  fobj->volume = (v1.y * v2.z - v1.z * v2.y) * v3.x
    + (v1.z * v2.x - v1.x * v2.z) * v3.y
    + (v1.x * v2.y - v1.y * v2.x) * v3.z;

  /*
   * scale transformation to map periodic cell into unit cell
   *
   * NOTE(review): a nonperiodic direction along which all selected atoms
   * share the same coordinate gives a length of zero, so the division
   * below is by zero -- confirm how callers expect this to be handled.
   */
  inv_lv1 = 1.0 / fobj->lv1;
  fobj->ta1.x = ta1.x * inv_lv1;
  fobj->ta1.y = ta1.y * inv_lv1;
  fobj->ta1.z = ta1.z * inv_lv1;

  inv_lv2 = 1.0 / fobj->lv2;
  fobj->ta2.x = ta2.x * inv_lv2;
  fobj->ta2.y = ta2.y * inv_lv2;
  fobj->ta2.z = ta2.z * inv_lv2;

  inv_lv3 = 1.0 / fobj->lv3;
  fobj->ta3.x = ta3.x * inv_lv3;
  fobj->ta3.y = ta3.y * inv_lv3;
  fobj->ta3.z = ta3.z * inv_lv3;

  /* save center */
  fobj->center = fdom->center;

  /* compute lower left-hand corner of cell */
  fobj->lowerc.x = fobj->center.x - 0.5 * (v1.x + v2.x + v3.x);
  fobj->lowerc.y = fobj->center.y - 0.5 * (v1.y + v2.y + v3.y);
  fobj->lowerc.z = fobj->center.z - 0.5 * (v1.z + v2.z + v3.z);

  /*
   * finish computing trpos and poswrap
   *
   * the affine transformation is:
   *   s_i = A^{-1} (r_i - l + w_i)
   * where s_i is transformed position
   *       r_i is atomic coordinate
   *       l   is lower left-hand corner
   *       w_i is wrapping into periodic cell for r_i
   *       A = [ v1 v2 v3 ]
   *
   * we have already computed A^{-1} r_i up to a scaling factor,
   * i.e. B^{-1} r_i where B = [ v1/|v1| v2/|v2| v3/|v3| ]
   */
  trlow.x = fobj->ta1.x * fobj->lowerc.x + fobj->ta1.y * fobj->lowerc.y
    + fobj->ta1.z * fobj->lowerc.z;
  trlow.y = fobj->ta2.x * fobj->lowerc.x + fobj->ta2.y * fobj->lowerc.y
    + fobj->ta2.z * fobj->lowerc.z;
  trlow.z = fobj->ta3.x * fobj->lowerc.x + fobj->ta3.y * fobj->lowerc.y
    + fobj->ta3.z * fobj->lowerc.z;

  for (i = 0;  i < sel_len;  i++) {
    j = sel[i];
    trpos[j].x *= inv_lv1;
    trpos[j].y *= inv_lv2;
    trpos[j].z *= inv_lv3;
    trpos[j].x -= trlow.x;
    trpos[j].y -= trlow.y;
    trpos[j].z -= trlow.z;

    /* only periodic directions are wrapped back into the unit cell */
    if (fobj->is_periodic & FORCE_X_PERIODIC) {
      setup_domain_wrap(&trpos[j].x, &wrap[j], &v1);
    }
    if (fobj->is_periodic & FORCE_Y_PERIODIC) {
      setup_domain_wrap(&trpos[j].y, &wrap[j], &v2);
    }
    if (fobj->is_periodic & FORCE_Z_PERIODIC) {
      setup_domain_wrap(&trpos[j].z, &wrap[j], &v3);
    }

  } /* end finish computing trpos and poswrap */

  INT(sel_len);
  VEC(trpos[0]);

  return 0;
} /* end force_setup_domain() */


/*
 * Counterpart of force_setup_domain(); intentionally does nothing.
 */
void force_cleanup_domain(Force *f)
{
  /* domain setup allocated no memory, so there is nothing to release */
  (void) f;
}


/*
 * Set up (or disable) a fixed lattice that subdivides the domain cell.
 *
 *   f          - force object whose cell basis vectors v1, v2, v3 are set
 *   k1, k2, k3 - number of subdivisions along v1, v2, v3
 *
 * If all three counts are zero, the lattice is disabled: the counts, the
 * lattice vectors dv1..dv3, and the inverse rows tb1..tb3 are zeroed and
 * is_fixed_lattice is cleared.
 *
 * Otherwise all counts must be positive.  The lattice vectors are
 * dv_i = v_i / k_i, and tb1..tb3 receive the rows of the inverse of the
 * matrix whose columns are dv1, dv2, dv3.  The counts are stored in f and
 * is_fixed_lattice is set.
 *
 * Returns 0 on success.  Returns FORCE_FAIL, leaving f unmodified, if a
 * count is not positive or if the lattice vectors are linearly dependent
 * (zero or NaN determinant), which would otherwise divide by zero.
 */
int force_setup_lattice(Force *f, int32 k1, int32 k2, int32 k3)
{
  if (k1 == 0 && k2 == 0 && k3 == 0) {
    const MD_Dvec zero = { 0.0, 0.0, 0.0 };
    f->k1 = k1;
    f->k2 = k2;
    f->k3 = k3;
    f->dv1 = zero;
    f->dv2 = zero;
    f->dv3 = zero;
    f->tb1 = zero;
    f->tb2 = zero;
    f->tb3 = zero;
    f->is_fixed_lattice = 0;
  }
  else if (k1 <= 0 || k2 <= 0 || k3 <= 0) {
    return FORCE_FAIL;
  }
  else {
    /* lattice vectors */
    MD_Dvec dv1, dv2, dv3;

    /* row vectors of transformation */
    MD_Dvec a, b, c;

    /* row vectors of inverse transformation matrix */
    MD_Dvec tb1, tb2, tb3;

    /* for computing determinant of transformation matrix */
    double d1, d2, d3, det, inv_det;

    double inv_k1 = 1.0 / (double) k1;
    double inv_k2 = 1.0 / (double) k2;
    double inv_k3 = 1.0 / (double) k3;

    dv1.x = f->v1.x * inv_k1;
    dv1.y = f->v1.y * inv_k1;
    dv1.z = f->v1.z * inv_k1;
    dv2.x = f->v2.x * inv_k2;
    dv2.y = f->v2.y * inv_k2;
    dv2.z = f->v2.z * inv_k2;
    dv3.x = f->v3.x * inv_k3;
    dv3.y = f->v3.y * inv_k3;
    dv3.z = f->v3.z * inv_k3;

    /* take transpose matrix for row vectors */
    a.x = dv1.x;
    b.x = dv1.y;
    c.x = dv1.z;
    a.y = dv2.x;
    b.y = dv2.y;
    c.y = dv2.z;
    a.z = dv3.x;
    b.z = dv3.y;
    c.z = dv3.z;

    /* compute determinant */
    d1 = b.y * c.z - b.z * c.y;
    d2 = b.z * c.x - b.x * c.z;
    d3 = b.x * c.y - b.y * c.x;
    det = a.x * d1 + a.y * d2 + a.z * d3;

    /* the inverse exists only for linearly independent basis vectors; */
    /* this form of the test also rejects a NaN determinant */
    if (!(fabs(det) > 0.0)) {
      return FORCE_FAIL;
    }
    inv_det = 1.0 / det;

    /* compute inverse transformation matrix */
    /* by directly inverting the 3x3 transformation matrix */
    tb1.x = d1 * inv_det;
    tb1.y = (a.z * c.y - a.y * c.z) * inv_det;
    tb1.z = (a.y * b.z - a.z * b.y) * inv_det;
    tb2.x = d2 * inv_det;
    tb2.y = (a.x * c.z - a.z * c.x) * inv_det;
    tb2.z = (a.z * b.x - a.x * b.z) * inv_det;
    tb3.x = d3 * inv_det;
    tb3.y = (a.y * c.x - a.x * c.y) * inv_det;
    tb3.z = (a.x * b.y - a.y * b.x) * inv_det;

    /* commit results only now that nothing can fail */
    f->k1 = k1;
    f->k2 = k2;
    f->k3 = k3;

    f->dv1 = dv1;
    f->dv2 = dv2;
    f->dv3 = dv3;

    f->tb1 = tb1;
    f->tb2 = tb2;
    f->tb3 = tb3;

    f->is_fixed_lattice = 1;
  }

  return 0;
}
