#include "WKFUtils.h"
#include "OpenCLKernels.h"
#include "OpenCLUtils.h"

/*
 * Dump the inputs of an orbital grid calculation to a text file so that the
 * calculation can be replayed later by reading the file back.
 *
 * File layout, values separated by whitespace, one section per line:
 *   atom count
 *   atompos               (3 floats per atom)
 *   atom_basis            (1 int per atom)
 *   num_shells_per_atom   (1 int per atom)
 *   shell count
 *   num_prim_per_shell    (1 int per shell)
 *   shell_symmetry        (1 int per shell)
 *   basis count, then basis_array on the next line (2 floats per entry)
 *   origin                (3 floats)
 *   numvoxels             (3 ints)
 *   voxelsize             (1 float)
 *   coefficient count, then wave_f on the next line
 *
 * Floats are written with 9 significant digits (%.9g), which is enough for an
 * IEEE single-precision value to read back bit-identically; fixed-point %f
 * would flush small values such as 1e-7 to "0.000000".
 *
 * Returns 0 on success, or -1 if the file cannot be opened or any write
 * fails (including the final flush performed by fclose()).
 */
int write_orbital_data(const char *filename,
                       int numatoms,
                       const float *wave_f, int num_wave_f,
                       const float *basis_array, int num_basis,
                       const float *atompos,
                       const int *atom_basis,
                       const int *num_shells_per_atom,
                       const int *num_prim_per_shell,
                       const int *shell_symmetry,
                       int num_shells,
                       const int *numvoxels,
                       float voxelsize,
                       const float *origin) {
  int i;
  int failed;

  FILE *ofp=fopen(filename, "w");
  if (ofp==NULL)
    return -1;

  printf("dumping orbital data to '%s'\n", filename);

  // dump atom count
  fprintf(ofp, "%d\n", numatoms);

  // dump atom coordinates
  for (i=0; i<(numatoms*3); i++)
    fprintf(ofp, "%.9g ", atompos[i]);
  fprintf(ofp, "\n");

  // dump atom_basis array
  for (i=0; i<numatoms; i++)
    fprintf(ofp, "%d ", atom_basis[i]);
  fprintf(ofp, "\n");

  // dump num_shells_per_atom array
  for (i=0; i<numatoms; i++)
    fprintf(ofp, "%d ", num_shells_per_atom[i]);
  fprintf(ofp, "\n");

  // dump shell count
  fprintf(ofp, "%d\n", num_shells);

  // dump num_prim_per_shell array
  for (i=0; i<num_shells; i++)
    fprintf(ofp, "%d ", num_prim_per_shell[i]);
  fprintf(ofp, "\n");

  // dump shell_symmetry array
  for (i=0; i<num_shells; i++)
    fprintf(ofp, "%d ", shell_symmetry[i]);
  fprintf(ofp, "\n");

  // dump basis array (2 floats per basis entry)
  fprintf(ofp, "%d\n", num_basis);
  for (i=0; i<(num_basis*2); i++)
    fprintf(ofp, "%.9g ", basis_array[i]);
  fprintf(ofp, "\n");

  // dump grid description: origin, grid size, voxel size
  fprintf(ofp, "%.9g %.9g %.9g\n", origin[0], origin[1], origin[2]);
  fprintf(ofp, "%d %d %d\n", numvoxels[0], numvoxels[1], numvoxels[2]);
  fprintf(ofp, "%.9g\n", voxelsize);

  // dump wavefunction coefficient array
  fprintf(ofp, "%d\n", num_wave_f);
  for (i=0; i<num_wave_f; i++)
    fprintf(ofp, "%.9g ", wave_f[i]);
  fprintf(ofp, "\n");

  // Stream errors are sticky, so one ferror() covers every fprintf() above;
  // fclose() flushes buffered output and can fail independently.
  failed = ferror(ofp);
  if (fclose(ofp) != 0)
    failed = 1;

  return failed ? -1 : 0;
}

int read_calc_orbitals(wkf_threadpool_t *devpool, const char *filename, int verbose) {
  int numatoms;
  float *wave_f;
  int num_wave_f;
  float *basis_array;
  int num_basis;
  float *atompos;
  int *atom_basis;
  int *num_shells_per_atom;
  int *num_prim_per_shell;
  int *shell_symmetry;
  int num_shells;
  int numvoxels[3];
  float voxelsize;
  float origin[3];

  int i;

  FILE *ifp=fopen(filename, "r");
  if (ifp==NULL)
    return -1;

  printf("reading orbital data from '%s'\n", filename);

  // read atom count
  fscanf(ifp, "%d\n", &numatoms);
  atompos = (float *) malloc(numatoms * 3 * sizeof(float));
  atom_basis = (int *) malloc(numatoms * sizeof(int));
  num_shells_per_atom = (int *) malloc(numatoms * sizeof(int));

  // read atom coordinates
  for (i=0; i<(numatoms*3); i++)
    fscanf(ifp, "%f ", &atompos[i]);  

  // read atom_basis array
  for (i=0; i<numatoms; i++)
    fscanf(ifp, "%d ", &atom_basis[i]);  

  // read num_shells_per_atom array
  for (i=0; i<numatoms; i++)
    fscanf(ifp, "%d ", &num_shells_per_atom[i]);  

  // read shell count
  fscanf(ifp, "%d\n", &num_shells);
  num_prim_per_shell = (int *) malloc(num_shells * sizeof(int));
  shell_symmetry = (int *) malloc(num_shells * sizeof(int));

  // read num_prim_per_shell array
  for (i=0; i<num_shells; i++)
    fscanf(ifp, "%d ", &num_prim_per_shell[i]);  

  // read shell_symmetry array
  for (i=0; i<num_shells; i++)
    fscanf(ifp, "%d ", &shell_symmetry[i]);  

  // read basis array
  fscanf(ifp, "%d\n", &num_basis);
  basis_array = (float *) malloc(num_basis * 2 * sizeof(float)); 
  for (i=0; i<(num_basis*2); i++)
    fscanf(ifp, "%f ", &basis_array[i]);  

  // read misc info...
  // origin
  fscanf(ifp, "%f %f %f\n", &origin[0], &origin[1], &origin[2]);
  // grid size
  fscanf(ifp, "%d %d %d\n", &numvoxels[0], &numvoxels[1], &numvoxels[2]);
  // voxel size
  fscanf(ifp, "%f\n", &voxelsize);
 
  // read wavefunction coefficient array
  fscanf(ifp, "%d\n", &num_wave_f);
  wave_f = (float *) malloc(num_wave_f * sizeof(float));
  for (i=0; i<num_wave_f; i++)
    fscanf(ifp, "%f ", &wave_f[i]);  

  fclose(ifp);

  // compute orbitals
#if 1
  float *orbitalgrid = (float *) malloc(numvoxels[0]*numvoxels[1]*numvoxels[2]*sizeof(float));
  wkf_timerhandle timer;
  timer=wkf_timer_create();
  wkf_timer_start(timer);
  int rc=-1;

  if (!getenv("VMDNOOPENCL")) {
    // XXX this would be done during app startup normally...
    static vmd_opencl_orbital_handle *orbh = NULL;
    static cl_context clctx = NULL;
    static cl_command_queue clcmdq = NULL;
    static cl_device_id *cldevs = NULL;
    if (orbh == NULL) {
      printf("Attaching OpenCL device:\n");
      vmd_cl_print_platform_info();
       
      wkf_timer_start(timer);
      cl_int clerr = CL_SUCCESS;

      cl_platform_id clplatid = vmd_cl_get_platform_index(0);
      cl_context_properties clctxprops[] = {(cl_context_properties) CL_CONTEXT_PLATFORM, (cl_context_properties) clplatid, (cl_context_properties) 0};
      clctx = clCreateContextFromType(clctxprops, CL_DEVICE_TYPE_GPU, NULL, NULL, &clerr);

      if (clerr != CL_SUCCESS) { printf("opencl error %d, %s line %d\n", clerr, __FILE__, __LINE__); return -1; }
      size_t parmsz;
      clerr |= clGetContextInfo(clctx, CL_CONTEXT_DEVICES, 0, NULL, &parmsz);
      cldevs = (cl_device_id *) malloc(parmsz);
      if (clerr != CL_SUCCESS) { printf("opencl error %d, %s line %d\n", clerr, __FILE__, __LINE__); return -1; }
      clerr |= clGetContextInfo(clctx, CL_CONTEXT_DEVICES, parmsz, cldevs, NULL);
      if (clerr != CL_SUCCESS) { printf("opencl error %d, %s line %d\n", clerr, __FILE__, __LINE__); return -1; }
      clcmdq = clCreateCommandQueue(clctx, cldevs[0], 0, &clerr);
      if (clerr != CL_SUCCESS) { printf("opencl error %d, %s line %d\n", clerr, __FILE__, __LINE__); return -1; }
      wkf_timer_stop(timer);
      printf("  OpenCL context creation time: %.3f sec\n", wkf_timer_time(timer));

      wkf_timer_start(timer);
      orbh = vmd_opencl_create_orbital_handle(clctx, clcmdq, cldevs);
      wkf_timer_stop(timer);
      printf("  OpenCL kernel compilation time: %.3f sec\n", wkf_timer_time(timer));

      wkf_timer_start(timer);
    }

    // run the kernel
    rc = vmd_opencl_evaluate_orbital_grid(
               devpool, orbh, numatoms,
               wave_f, num_wave_f, basis_array, num_basis,
               atompos, atom_basis, num_shells_per_atom,
               num_prim_per_shell, shell_symmetry,
               num_shells, numvoxels, voxelsize, origin,
               0,
               orbitalgrid);

#if 0
    // XXX this would normally be done at shutdown
    vmd_opencl_destroy_orbital_handle(parms.orbh);
    clReleaseCommandQueue(clcmdq);
    clReleaseContext(clctx);
    free(cldevs);
#endif
  } 


#if defined(CUDA)
  if (rc) {
    rc = vmd_cuda_evaluate_orbital_grid(
               devpool, numatoms,
               wave_f, num_wave_f, basis_array, num_basis,
               atompos, atom_basis, num_shells_per_atom,
               num_prim_per_shell, shell_symmetry,
               num_shells, numvoxels, voxelsize, origin,
               0,
               orbitalgrid);
  }
#endif

  wkf_timer_stop(timer);
  printf("Orbital runtime: %f\n", wkf_timer_time(timer));
  wkf_timer_destroy(timer);

#if 1
  if (verbose) {
    int i, j;
    int starti=0;
    int startj=0;
    int endi=starti+6;
    int endj=startj+8;
    printf("Corner of output orbital grid for correctness checking:\n");
    for (j=startj; j<endj; j++) {
      for (i=starti; i<endi; i++) {
        int addr = j*numvoxels[0] + i;
        printf("%g ", orbitalgrid[addr]);
      }
      printf("\n");
    }
  }
#endif

  free(orbitalgrid);

#else
  // dump again to diff against input data for correctness checking
  int rc = write_orbital_data("/tmp/comporbs.txt", numatoms,
             wave_f, num_wave_f, basis_array, num_basis,
             atompos, atom_basis, num_shells_per_atom,
             num_prim_per_shell, shell_symmetry,
             num_shells, numvoxels, voxelsize, origin);
#endif

  // XXX need to free all that stuff still...
  free(wave_f);
  free(basis_array);
  free(atompos);
  free(atom_basis);
  free(num_shells_per_atom);
  free(num_prim_per_shell);
  free(shell_symmetry);

  return rc;
}

