Main Page   Namespace List   Class Hierarchy   Alphabetical List   Compound List   File List   Namespace Members   Compound Members   File Members   Related Pages  

OrbitalJIT.C

Go to the documentation of this file.
00001 /***************************************************************************
00002  *cr
00003  *cr            (C) Copyright 1995-2019 The Board of Trustees of the
00004  *cr                        University of Illinois
00005  *cr                         All Rights Reserved
00006  *cr
00007  ***************************************************************************/
00008 /***************************************************************************
00009  * RCS INFORMATION:
00010  *
00011  *      $RCSfile: OrbitalJIT.C,v $
00012  *      $Author: johns $        $Locker:  $             $State: Exp $
00013  *      $Revision: 1.6 $      $Date: 2020/02/26 06:00:57 $
00014  *
00015  ***************************************************************************/
00022 #include <stdio.h>
00023 #include <stdlib.h>
00024 #include <string.h>
00025 #include <math.h>
00026 #include "WKFThreads.h"
00027 #include "OrbitalJIT.h"
00028 
00029 // Only enabled for testing an emulation of JIT code generation approach
00030 // #define VMDMOJIT 1
00031 // #define VMDMOJITSRC "/home/johns/mojit.cu"
00032 
// unit conversion factor (Angstrom -> Bohr); the OpenCL path of the JIT
// generator emits the same constant into the generated source
#define ANGS_TO_BOHR 1.8897259877218677f

// GPU block layout
#define UNROLLX       1
#define UNROLLY       1
#define BLOCKSIZEX    8
#define BLOCKSIZEY    8
// Parenthesized so the expansion is safe inside larger expressions,
// e.g. "n / BLOCKSIZE" (unparenthesized it would expand to n / 8 * 8).
#define BLOCKSIZE     (BLOCKSIZEX * BLOCKSIZEY)

// required GPU array padding to match thread block size
#define TILESIZEX (BLOCKSIZEX*UNROLLX)
#define TILESIZEY (BLOCKSIZEY*UNROLLY)
// Alignment masks; only meaningful while TILESIZEX/TILESIZEY are powers of two.
#define GPU_X_ALIGNMASK (TILESIZEX - 1)
#define GPU_Y_ALIGNMASK (TILESIZEY - 1)

#define MEMCOALESCE  384

// orbital shell types
#define S_SHELL 0
#define P_SHELL 1
#define D_SHELL 2
#define F_SHELL 3
#define G_SHELL 4
#define H_SHELL 5

//
// CUDA constant arrays
// Size limits for the device constant-memory arrays; the corresponding
// __constant__ declarations are kept below only as commented-out reference.
//
#define MAX_ATOM_SZ 256

#define MAX_ATOMPOS_SZ (MAX_ATOM_SZ)
// __constant__ static float const_atompos[MAX_ATOMPOS_SZ * 3];

#define MAX_ATOM_BASIS_SZ (MAX_ATOM_SZ)
// __constant__ static int const_atom_basis[MAX_ATOM_BASIS_SZ];

#define MAX_ATOMSHELL_SZ (MAX_ATOM_SZ)
// __constant__ static int const_num_shells_per_atom[MAX_ATOMSHELL_SZ];

#define MAX_BASIS_SZ 6144
// __constant__ static float const_basis_array[MAX_BASIS_SZ];

#define MAX_SHELL_SZ 1024
// __constant__ static int const_num_prim_per_shell[MAX_SHELL_SZ];
// __constant__ static int const_shell_types[MAX_SHELL_SZ];

#define MAX_WAVEF_SZ 6144
// __constant__ static float const_wave_f[MAX_WAVEF_SZ];
00081 
00082 
00083 //
00084 // CUDA JIT code generator for constant memory
00085 //
00086 int orbital_jit_generate(int jitlanguage,
00087                          const char * srcfilename, int numatoms,
00088                          const float *wave_f, const float *basis_array,
00089                          const int *atom_basis,
00090                          const int *num_shells_per_atom,
00091                          const int *num_prim_per_shell,
00092                          const int *shell_types) {
00093   FILE *ofp=NULL;
00094   if (srcfilename) 
00095     ofp=fopen(srcfilename, "w");
00096 
00097   if (ofp == NULL)
00098     ofp=stdout; 
00099 
00100   // calculate the value of the wavefunction of the
00101   // selected orbital at the current grid point
00102   int at;
00103   int prim, shell;
00104 
00105   // initialize the wavefunction and shell counters
00106   int shell_counter = 0;
00107 
00108   if (jitlanguage == ORBITAL_JIT_CUDA) {
00109     fprintf(ofp, 
00110       "__global__ static void cuorbitalconstmem_jit(int numatoms,\n"
00111       "                          float voxelsize,\n"
00112       "                          float originx,\n"
00113       "                          float originy,\n"
00114       "                          float grid_z, \n"
00115       "                          int density, \n"
00116       "                          float * orbitalgrid) {\n"
00117       "  unsigned int xindex  = blockIdx.x * blockDim.x + threadIdx.x;\n"
00118       "  unsigned int yindex  = blockIdx.y * blockDim.y + threadIdx.y;\n"
00119       "  unsigned int outaddr = gridDim.x * blockDim.x * yindex + xindex;\n"
00120     );
00121   } else if (jitlanguage == ORBITAL_JIT_OPENCL) {
00122     fprintf(ofp, 
00123       "// unit conversion                                                  \n"
00124       "#define ANGS_TO_BOHR 1.8897259877218677f                            \n"
00125     );
00126 
00127     fprintf(ofp, "__kernel __attribute__((reqd_work_group_size(%d, %d, 1)))\n",
00128             BLOCKSIZEX, BLOCKSIZEY);
00129 
00130     fprintf(ofp, 
00131       "void clorbitalconstmem_jit(int numatoms,                            \n"
00132       "                       __constant float *const_atompos,             \n"
00133       "                       __constant float *const_wave_f,              \n"
00134       "                       float voxelsize,                             \n"
00135       "                       float originx,                               \n"
00136       "                       float originy,                               \n"
00137       "                       float grid_z,                                \n"
00138       "                       int density,                                 \n"
00139       "                       __global float * orbitalgrid) {              \n"
00140       "  unsigned int xindex  = get_global_id(0);                          \n"
00141       "  unsigned int yindex  = get_global_id(1);                          \n"
00142       "  unsigned int outaddr = get_global_size(0) * yindex + xindex;      \n"
00143     );
00144   }
00145 
00146   fprintf(ofp, 
00147     "  float grid_x = originx + voxelsize * xindex;\n"
00148     "  float grid_y = originy + voxelsize * yindex;\n"
00149  
00150     "  // similar to C version\n"
00151     "  int at;\n"
00152     "  // initialize value of orbital at gridpoint\n"
00153     "  float value = 0.0f;\n"
00154     "  // initialize the wavefunction and shell counters\n"
00155     "  int ifunc = 0;\n"
00156     "  // loop over all the QM atoms\n"
00157     "  for (at = 0; at < numatoms; at++) {\n"
00158     "    // calculate distance between grid point and center of atom\n"
00159 //    "    int maxshell = const_num_shells_per_atom[at];\n"
00160 //    "    int prim_counter = const_atom_basis[at];\n"
00161     "    float xdist = (grid_x - const_atompos[3*at  ])*ANGS_TO_BOHR;\n"
00162     "    float ydist = (grid_y - const_atompos[3*at+1])*ANGS_TO_BOHR;\n"
00163     "    float zdist = (grid_z - const_atompos[3*at+2])*ANGS_TO_BOHR;\n"
00164     "    float xdist2 = xdist*xdist;\n"
00165     "    float ydist2 = ydist*ydist;\n"
00166     "    float zdist2 = zdist*zdist;\n"
00167     "    float dist2 = xdist2 + ydist2 + zdist2;\n"
00168     "    float contracted_gto=0.0f;\n"
00169     "    float tmpshell=0.0f;\n"
00170     "\n"
00171   );
00172 
00173 #if 0
00174   // loop over all the QM atoms generating JIT code for each type
00175   for (at=0; at<numatoms; at++) {
00176 #else
00177   // generate JIT code for one atom type and assume they are all the same
00178   for (at=0; at<1; at++) {
00179 #endif
00180     int maxshell = num_shells_per_atom[at];
00181     int prim_counter = atom_basis[at];
00182 
00183     // loop over the shells belonging to this atom
00184     for (shell=0; shell < maxshell; shell++) {
00185       // Loop over the Gaussian primitives of this contracted
00186       // basis function to build the atomic orbital
00187       int maxprim = num_prim_per_shell[shell_counter];
00188       int shelltype = shell_types[shell_counter];
00189       for (prim=0; prim<maxprim; prim++) {
00190         float exponent       = basis_array[prim_counter    ];
00191         float contract_coeff = basis_array[prim_counter + 1];
00192 #if 1
00193         if (jitlanguage == ORBITAL_JIT_CUDA) {
00194           if (prim == 0) {
00195             fprintf(ofp,"    contracted_gto = %ff * exp2f(-%ff*dist2);\n",
00196                     contract_coeff, exponent);
00197           } else {
00198             fprintf(ofp,"    contracted_gto += %ff * exp2f(-%ff*dist2);\n",
00199                     contract_coeff, exponent);
00200           }
00201         } else if (jitlanguage == ORBITAL_JIT_OPENCL) {
00202           if (prim == 0) {
00203             fprintf(ofp,"    contracted_gto = %ff * native_exp2(-%ff*dist2);\n",
00204                     contract_coeff, exponent);
00205           } else {
00206             fprintf(ofp,"    contracted_gto += %ff * native_exp2(-%ff*dist2);\n",
00207                     contract_coeff, exponent);
00208           }
00209         }
00210 #else
00211         if (jitlanguage == ORBITAL_JIT_CUDA) {
00212           if (prim == 0) {
00213             fprintf(ofp,"    contracted_gto = %ff * expf(-%ff*dist2);\n",
00214                     contract_coeff, exponent);
00215           } else {
00216             fprintf(ofp,"    contracted_gto += %ff * expf(-%ff*dist2);\n",
00217                     contract_coeff, exponent);
00218           }
00219         } else if (jitlanguage == ORBITAL_JIT_OPENCL) {
00220           if (prim == 0) {
00221             fprintf(ofp,"    contracted_gto = %ff * native_exp(-%ff*dist2);\n",
00222                     contract_coeff, exponent);
00223           } else {
00224             fprintf(ofp,"    contracted_gto += %ff * native_exp(-%ff*dist2);\n",
00225                     contract_coeff, exponent);
00226           }
00227         }
00228 #endif
00229         prim_counter += 2;
00230       }
00231 
00232       /* multiply with the appropriate wavefunction coefficient */
00233       switch (shelltype) {
00234         case S_SHELL:
00235           fprintf(ofp, 
00236             "    // S_SHELL\n"
00237             "    value += const_wave_f[ifunc++] * contracted_gto;\n");
00238           break;
00239 
00240         case P_SHELL:
00241           fprintf(ofp,
00242             "    // P_SHELL\n"
00243             "    tmpshell = const_wave_f[ifunc++] * xdist;\n"
00244             "    tmpshell += const_wave_f[ifunc++] * ydist;\n"
00245             "    tmpshell += const_wave_f[ifunc++] * zdist;\n"
00246             "    value += tmpshell * contracted_gto;\n"
00247           );
00248           break;
00249 
00250         case D_SHELL:
00251           fprintf(ofp,
00252             "    // D_SHELL\n"
00253             "    tmpshell = const_wave_f[ifunc++] * xdist2;\n"
00254             "    tmpshell += const_wave_f[ifunc++] * xdist * ydist;\n"
00255             "    tmpshell += const_wave_f[ifunc++] * ydist2;\n"
00256             "    tmpshell += const_wave_f[ifunc++] * xdist * zdist;\n"
00257             "    tmpshell += const_wave_f[ifunc++] * ydist * zdist;\n"
00258             "    tmpshell += const_wave_f[ifunc++] * zdist2;\n"
00259             "    value += tmpshell * contracted_gto;\n"
00260           );
00261           break;
00262 
00263         case F_SHELL:
00264           fprintf(ofp,
00265             "    // F_SHELL\n"
00266             "    tmpshell = const_wave_f[ifunc++] * xdist2 * xdist;\n"
00267             "    tmpshell += const_wave_f[ifunc++] * xdist2 * ydist;\n"
00268             "    tmpshell += const_wave_f[ifunc++] * ydist2 * xdist;\n"
00269             "    tmpshell += const_wave_f[ifunc++] * ydist2 * ydist;\n"
00270             "    tmpshell += const_wave_f[ifunc++] * xdist2 * zdist;\n"
00271             "    tmpshell += const_wave_f[ifunc++] * xdist * ydist * zdist;\n"
00272             "    tmpshell += const_wave_f[ifunc++] * ydist2 * zdist;\n"
00273             "    tmpshell += const_wave_f[ifunc++] * zdist2 * xdist;\n"
00274             "    tmpshell += const_wave_f[ifunc++] * zdist2 * ydist;\n"
00275             "    tmpshell += const_wave_f[ifunc++] * zdist2 * zdist;\n"
00276             "    value += tmpshell * contracted_gto;\n"
00277           );
00278           break;
00279 
00280         case G_SHELL:
00281           fprintf(ofp,
00282             "    // G_SHELL\n"
00283             "    tmpshell = const_wave_f[ifunc++] * xdist2 * xdist2;\n"
00284             "    tmpshell += const_wave_f[ifunc++] * xdist2 * xdist * ydist;\n"
00285             "    tmpshell += const_wave_f[ifunc++] * xdist2 * ydist2;\n"
00286             "    tmpshell += const_wave_f[ifunc++] * ydist2 * ydist * xdist;\n"
00287             "    tmpshell += const_wave_f[ifunc++] * ydist2 * ydist2;\n"
00288             "    tmpshell += const_wave_f[ifunc++] * xdist2 * xdist * zdist;\n"
00289             "    tmpshell += const_wave_f[ifunc++] * xdist2 * ydist * zdist;\n"
00290             "    tmpshell += const_wave_f[ifunc++] * ydist2 * xdist * zdist;\n"
00291             "    tmpshell += const_wave_f[ifunc++] * ydist2 * ydist * zdist;\n"
00292             "    tmpshell += const_wave_f[ifunc++] * xdist2 * zdist2;\n"
00293             "    tmpshell += const_wave_f[ifunc++] * zdist2 * xdist * ydist;\n"
00294             "    tmpshell += const_wave_f[ifunc++] * ydist2 * zdist2;\n"
00295             "    tmpshell += const_wave_f[ifunc++] * zdist2 * zdist * xdist;\n"
00296             "    tmpshell += const_wave_f[ifunc++] * zdist2 * zdist * ydist;\n"
00297             "    tmpshell += const_wave_f[ifunc++] * zdist2 * zdist2;\n"
00298             "    value += tmpshell * contracted_gto;\n"
00299           );
00300           break;
00301 
00302       } // end switch
00303       fprintf(ofp, "\n");
00304 
00305       shell_counter++;
00306     } // end shell
00307   } // end atom
00308 
00309   fprintf(ofp, 
00310     "  }\n"
00311     "\n"
00312     "  // return either orbital density or orbital wavefunction amplitude \n"
00313     "  if (density) { \n"
00314   );
00315 
00316   if (jitlanguage == ORBITAL_JIT_CUDA) {
00317     fprintf(ofp, "    orbitalgrid[outaddr] = copysignf(value*value, value);\n");
00318   } else if (jitlanguage == ORBITAL_JIT_OPENCL) {
00319     fprintf(ofp, "    orbitalgrid[outaddr] = copysign(value*value, value);\n");
00320   }
00321 
00322   fprintf(ofp, 
00323     "  } else { \n"
00324     "    orbitalgrid[outaddr] = value; \n"
00325     "  }\n"
00326     "}\n"
00327   );
00328 
00329   if (ofp != stdout)
00330     fclose(ofp);
00331 
00332   return 0;
00333 }
00334 
00335 
00336 

Generated on Thu Apr 18 02:45:18 2024 for VMD (current) by doxygen1.2.14 written by Dimitri van Heesch, © 1997-2002