#if defined(VMDMOJIT)

//
// If we're testing performance of a JIT kernel, include the code here
//
#if defined(VMDMOJITSRC)
#include VMDMOJITSRC
#endif

// standard headers used by the JIT source generator below
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>

//
// OpenCL JIT code generator for constant memory
//
// Growable, NUL-terminated text buffer used to assemble the kernel source.
typedef struct {
  char *text;   // heap storage, NULL until the first successful append
  size_t len;   // characters stored, not counting the terminating NUL
  size_t cap;   // bytes currently allocated for text
} vmd_opencl_jit_buf;

// Append printf-style formatted text to buf, growing it as required.
// Returns 0 on success, or -1 on a formatting or allocation failure, in
// which case the text already stored in buf is left untouched.
static int vmd_opencl_jit_appendf(vmd_opencl_jit_buf *buf,
                                  const char *fmt, ...) {
  va_list args;
  int need;

  // first pass: measure the formatted length without writing anything
  va_start(args, fmt);
  need = vsnprintf(NULL, 0, fmt, args);
  va_end(args);
  if (need < 0)
    return -1;

  if (buf->len + (size_t) need + 1 > buf->cap) {
    size_t newcap = (buf->cap > 0) ? buf->cap : 4096;
    char *grown;

    while (newcap < buf->len + (size_t) need + 1)
      newcap *= 2;

    // the explicit cast keeps this valid if the file is built as C++
    grown = (char *) realloc(buf->text, newcap);
    if (grown == NULL)
      return -1;

    buf->text = grown;
    buf->cap = newcap;
  }

  // second pass: format directly onto the end of the existing text
  va_start(args, fmt);
  vsnprintf(buf->text + buf->len, buf->cap - buf->len, fmt, args);
  va_end(args);
  buf->len += (size_t) need;

  return 0;
}

//
// Generate the source text of the "clorbitalconstmem_jit" OpenCL kernel,
// with the basis set exponents and contraction coefficients of the first
// atom embedded as literal constants.
//
//   srcstr               out: receives a malloc()'d, NUL-terminated string
//                        holding the generated source; the caller releases
//                        it with free().  Set to NULL on failure.
//   numatoms             must be at least 1; only atom 0 is read, and the
//                        generated kernel applies its code to every atom
//                        (i.e. all atoms are assumed to share one basis).
//   wave_f               not read by the generator; the generated kernel
//                        indexes const_wave_f[] at run time instead.
//   basis_array          (exponent, contraction coefficient) pairs, read
//                        starting at index atom_basis[0].
//   atom_basis           atom_basis[0] is the first basis_array index used.
//   num_shells_per_atom  num_shells_per_atom[0] is the number of shells.
//   num_prim_per_shell   primitive count of each shell of atom 0.
//   shell_types          S_SHELL..G_SHELL type code of each shell of atom 0.
//
// Returns 0 on success, or -1 on invalid arguments, an unrecognized shell
// type, or an allocation failure.
//
// NOTE(review): the generated text refers to const_atompos, const_wave_f
// and ANGS_TO_BOHR without declaring them, and calls expf()/copysignf(),
// which are not OpenCL C built-ins (exp()/copysign() are).  Presumably
// whatever is compiled together with this text supplies them -- confirm.
//
int vmd_opencl_jit_generate(char **srcstr, int numatoms,
                            const float *wave_f, const float *basis_array,
                            const int *atom_basis,
                            const int *num_shells_per_atom,
                            const int *num_prim_per_shell,
                            const int *shell_types) {
  vmd_opencl_jit_buf src = { NULL, 0, 0 };
  int shell_counter = 0;  // index into the per-shell arrays
  int prim_counter;       // index into basis_array
  int maxshell;
  int shell;

  (void) wave_f;  // coefficients are fetched by the kernel, not embedded

  if (srcstr == NULL)
    return -1;
  *srcstr = NULL;

  // atom 0 is read unconditionally below, so it has to exist
  if (numatoms < 1 || basis_array == NULL || atom_basis == NULL ||
      num_shells_per_atom == NULL || num_prim_per_shell == NULL ||
      shell_types == NULL)
    return -1;

  // kernel prologue: grid coordinates and the start of the per-atom loop.
  // The kernel is declared "__kernel void" and the output pointer is
  // "__global" as OpenCL C requires for kernel entry points and their
  // pointer arguments.
  if (vmd_opencl_jit_appendf(&src, "%s",
    "__kernel void clorbitalconstmem_jit(int numatoms,\n"
    "                          float voxelsize,\n"
    "                          float originx,\n"
    "                          float originy,\n"
    "                          float grid_z, \n"
    "                          int density, \n"
    "                          __global float * orbitalgrid) {\n"
    "  unsigned int xindex  = get_global_id(0);\n"
    "  unsigned int yindex  = get_global_id(1);\n"
    "  unsigned int outaddr = get_global_size(0) * yindex + xindex;\n"
    "  float grid_x = originx + voxelsize * xindex;\n"
    "  float grid_y = originy + voxelsize * yindex;\n"
    "  // similar to C version\n"
    "  int at;\n"
    "  // initialize value of orbital at gridpoint\n"
    "  float value = 0.0f;\n"
    "  // initialize the wavefunction and shell counters\n"
    "  int ifunc = 0;\n"
    "  // loop over all the QM atoms\n"
    "  for (at = 0; at < numatoms; at++) {\n"
    "    // calculate distance between grid point and center of atom\n"
    "    float xdist = (grid_x - const_atompos[3*at  ])*ANGS_TO_BOHR;\n"
    "    float ydist = (grid_y - const_atompos[3*at+1])*ANGS_TO_BOHR;\n"
    "    float zdist = (grid_z - const_atompos[3*at+2])*ANGS_TO_BOHR;\n"
    "    float xdist2 = xdist*xdist;\n"
    "    float ydist2 = ydist*ydist;\n"
    "    float zdist2 = zdist*zdist;\n"
    "    float dist2 = xdist2 + ydist2 + zdist2;\n"
    "    float contracted_gto=0.0f;\n"
    "    float tmpshell=0.0f;\n"
    "\n") != 0)
    goto fail;

  // generate JIT code for one atom type and assume they are all the same
  maxshell = num_shells_per_atom[0];
  prim_counter = atom_basis[0];

  // loop over the shells belonging to this atom
  for (shell = 0; shell < maxshell; shell++) {
    const char *shell_src;
    int maxprim = num_prim_per_shell[shell_counter];
    int shelltype = shell_types[shell_counter];
    int prim;

    // Loop over the Gaussian primitives of this contracted basis function
    // to build the atomic orbital.  Constants are written with "%.8e"
    // (nine significant digits) so that small values keep their precision
    // and the text is always a valid float literal once "f" is appended.
    for (prim = 0; prim < maxprim; prim++) {
      float exponent       = basis_array[prim_counter    ];
      float contract_coeff = basis_array[prim_counter + 1];

      // the first primitive assigns, the remaining ones accumulate
      if (vmd_opencl_jit_appendf(&src,
            "    contracted_gto %s %.8ef * expf(-%.8ef*dist2);\n",
            (prim == 0) ? "=" : "+=",
            (double) contract_coeff, (double) exponent) != 0)
        goto fail;

      prim_counter += 2;
    }

    /* multiply with the appropriate wavefunction coefficient */
    switch (shelltype) {
      case S_SHELL:
        shell_src =
          "    // S_SHELL\n"
          "    value += const_wave_f[ifunc++] * contracted_gto;\n";
        break;

      case P_SHELL:
        shell_src =
          "    // P_SHELL\n"
          "    tmpshell = const_wave_f[ifunc++] * xdist;\n"
          "    tmpshell += const_wave_f[ifunc++] * ydist;\n"
          "    tmpshell += const_wave_f[ifunc++] * zdist;\n"
          "    value += tmpshell * contracted_gto;\n";
        break;

      case D_SHELL:
        shell_src =
          "    // D_SHELL\n"
          "    tmpshell = const_wave_f[ifunc++] * xdist2;\n"
          "    tmpshell += const_wave_f[ifunc++] * xdist * ydist;\n"
          "    tmpshell += const_wave_f[ifunc++] * ydist2;\n"
          "    tmpshell += const_wave_f[ifunc++] * xdist * zdist;\n"
          "    tmpshell += const_wave_f[ifunc++] * ydist * zdist;\n"
          "    tmpshell += const_wave_f[ifunc++] * zdist2;\n"
          "    value += tmpshell * contracted_gto;\n";
        break;

      case F_SHELL:
        shell_src =
          "    // F_SHELL\n"
          "    tmpshell = const_wave_f[ifunc++] * xdist2 * xdist;\n"
          "    tmpshell += const_wave_f[ifunc++] * xdist2 * ydist;\n"
          "    tmpshell += const_wave_f[ifunc++] * ydist2 * xdist;\n"
          "    tmpshell += const_wave_f[ifunc++] * ydist2 * ydist;\n"
          "    tmpshell += const_wave_f[ifunc++] * xdist2 * zdist;\n"
          "    tmpshell += const_wave_f[ifunc++] * xdist * ydist * zdist;\n"
          "    tmpshell += const_wave_f[ifunc++] * ydist2 * zdist;\n"
          "    tmpshell += const_wave_f[ifunc++] * zdist2 * xdist;\n"
          "    tmpshell += const_wave_f[ifunc++] * zdist2 * ydist;\n"
          "    tmpshell += const_wave_f[ifunc++] * zdist2 * zdist;\n"
          "    value += tmpshell * contracted_gto;\n";
        break;

      case G_SHELL:
        shell_src =
          "    // G_SHELL\n"
          "    tmpshell = const_wave_f[ifunc++] * xdist2 * xdist2;\n"
          "    tmpshell += const_wave_f[ifunc++] * xdist2 * xdist * ydist;\n"
          "    tmpshell += const_wave_f[ifunc++] * xdist2 * ydist2;\n"
          "    tmpshell += const_wave_f[ifunc++] * ydist2 * ydist * xdist;\n"
          "    tmpshell += const_wave_f[ifunc++] * ydist2 * ydist2;\n"
          "    tmpshell += const_wave_f[ifunc++] * xdist2 * xdist * zdist;\n"
          "    tmpshell += const_wave_f[ifunc++] * xdist2 * ydist * zdist;\n"
          "    tmpshell += const_wave_f[ifunc++] * ydist2 * xdist * zdist;\n"
          "    tmpshell += const_wave_f[ifunc++] * ydist2 * ydist * zdist;\n"
          "    tmpshell += const_wave_f[ifunc++] * xdist2 * zdist2;\n"
          "    tmpshell += const_wave_f[ifunc++] * zdist2 * xdist * ydist;\n"
          "    tmpshell += const_wave_f[ifunc++] * ydist2 * zdist2;\n"
          "    tmpshell += const_wave_f[ifunc++] * zdist2 * zdist * xdist;\n"
          "    tmpshell += const_wave_f[ifunc++] * zdist2 * zdist * ydist;\n"
          "    tmpshell += const_wave_f[ifunc++] * zdist2 * zdist2;\n"
          "    value += tmpshell * contracted_gto;\n";
        break;

      default:
        // an unknown shell type would leave the kernel's ifunc counter out
        // of step with the coefficient array, so refuse to generate code
        shell_src = NULL;
        break;
    } // end switch

    if (shell_src == NULL)
      goto fail;

    if (vmd_opencl_jit_appendf(&src, "%s\n", shell_src) != 0)
      goto fail;

    shell_counter++;
  } // end shell

  // close the per-atom loop and store the result for this grid point
  if (vmd_opencl_jit_appendf(&src, "%s",
    "  }\n"
    "\n"
    "  // return either orbital density or orbital wavefunction amplitude \n"
    "  if (density) { \n"
    "    orbitalgrid[outaddr] = copysignf(value*value, value); \n"
    "  } else { \n"
    "    orbitalgrid[outaddr] = value; \n"
    "  }\n"
    "}\n") != 0)
    goto fail;

  // hand ownership of the finished source text to the caller
  *srcstr = src.text;
  return 0;

fail:
  free(src.text);
  return -1;
}

#endif // VMDMOJIT


