00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023 #include <stdio.h>
00024 #include <stdlib.h>
00025 #include <string.h>
00026 #include <math.h>
00027
00028 #include "WKFThreads.h"
00029 #include "OrbitalJIT.h"
00030
00031
00032
00033
00034
/* Conversion factor from Angstroms to Bohr (atomic units of length). */
#define ANGS_TO_BOHR 1.8897259877218677f

/*
 * Thread-block geometry.  BLOCKSIZEX/BLOCKSIZEY are emitted into the
 * OpenCL reqd_work_group_size attribute by orbital_jit_generate().
 * UNROLLX/UNROLLY are per-thread unroll factors used to derive tile sizes.
 *
 * All compound expansions are fully parenthesized so that uses such as
 * "n / BLOCKSIZE" or "x % TILESIZEX" expand to the intended value.
 */
#define UNROLLX 1
#define UNROLLY 1
#define BLOCKSIZEX 8
#define BLOCKSIZEY 8
#define BLOCKSIZE (BLOCKSIZEX * BLOCKSIZEY)

/* Grid points covered by one block in each dimension. */
#define TILESIZEX (BLOCKSIZEX * UNROLLX)
#define TILESIZEY (BLOCKSIZEY * UNROLLY)
/* NOTE(review): presumably used to pad grid dimensions to a tile multiple;
 * these are only valid masks while TILESIZEX/TILESIZEY are powers of two. */
#define GPU_X_ALIGNMASK (TILESIZEX - 1)
#define GPU_Y_ALIGNMASK (TILESIZEY - 1)

/* NOTE(review): meaning not visible in this file (memory-coalescing
 * granularity?) -- confirm against its users. */
#define MEMCOALESCE 384

/* Shell angular-momentum type codes as stored in shell_types[]. */
#define S_SHELL 0
#define P_SHELL 1
#define D_SHELL 2
#define F_SHELL 3
#define G_SHELL 4
#define H_SHELL 5

/*
 * Capacity limits for per-atom, per-basis, per-shell and wavefunction
 * arrays.  NOTE(review): presumably the sizes of fixed device/constant
 * memory buffers used by non-JIT kernels -- confirm against their users.
 */
#define MAX_ATOM_SZ 256

#define MAX_ATOMPOS_SZ (MAX_ATOM_SZ)

#define MAX_ATOM_BASIS_SZ (MAX_ATOM_SZ)

#define MAX_ATOMSHELL_SZ (MAX_ATOM_SZ)

#define MAX_BASIS_SZ 6144

#define MAX_SHELL_SZ 1024

#define MAX_WAVEF_SZ 6144
00082
00083
00084
00085
00086
00087
00088 int orbital_jit_generate(int jitlanguage,
00089 const char * srcfilename, int numatoms,
00090 const float *wave_f, const float *basis_array,
00091 const int *atom_basis,
00092 const int *num_shells_per_atom,
00093 const int *num_prim_per_shell,
00094 const int *shell_types) {
00095 FILE *ofp=NULL;
00096 if (srcfilename)
00097 ofp=fopen(srcfilename, "w");
00098
00099 if (ofp == NULL)
00100 ofp=stdout;
00101
00102
00103
00104 int at;
00105 int prim, shell;
00106
00107
00108 int shell_counter = 0;
00109
00110 if (jitlanguage == ORBITAL_JIT_CUDA) {
00111 fprintf(ofp,
00112 "__global__ static void cuorbitalconstmem_jit(int numatoms,\n"
00113 " float voxelsize,\n"
00114 " float originx,\n"
00115 " float originy,\n"
00116 " float grid_z, \n"
00117 " int density, \n"
00118 " float * orbitalgrid) {\n"
00119 " unsigned int xindex = __umul24(blockIdx.x, blockDim.x)\n"
00120 " + threadIdx.x;\n"
00121 " unsigned int yindex = __umul24(blockIdx.y, blockDim.y)\n"
00122 " + threadIdx.y;\n"
00123 " unsigned int outaddr = __umul24(gridDim.x, blockDim.x) * yindex\n"
00124 " + xindex;\n"
00125 );
00126 } else if (jitlanguage == ORBITAL_JIT_OPENCL) {
00127 fprintf(ofp,
00128 "// unit conversion \n"
00129 "#define ANGS_TO_BOHR 1.8897259877218677f \n"
00130 );
00131
00132 fprintf(ofp, "__kernel __attribute__((reqd_work_group_size(%d, %d, 1)))\n",
00133 BLOCKSIZEX, BLOCKSIZEY);
00134
00135 fprintf(ofp,
00136 "void clorbitalconstmem_jit(int numatoms, \n"
00137 " __constant float *const_atompos, \n"
00138 " __constant float *const_wave_f, \n"
00139 " float voxelsize, \n"
00140 " float originx, \n"
00141 " float originy, \n"
00142 " float grid_z, \n"
00143 " int density, \n"
00144 " __global float * orbitalgrid) { \n"
00145 " unsigned int xindex = get_global_id(0); \n"
00146 " unsigned int yindex = get_global_id(1); \n"
00147 " unsigned int outaddr = get_global_size(0) * yindex + xindex; \n"
00148 );
00149 }
00150
00151 fprintf(ofp,
00152 " float grid_x = originx + voxelsize * xindex;\n"
00153 " float grid_y = originy + voxelsize * yindex;\n"
00154
00155 " // similar to C version\n"
00156 " int at;\n"
00157 " // initialize value of orbital at gridpoint\n"
00158 " float value = 0.0f;\n"
00159 " // initialize the wavefunction and shell counters\n"
00160 " int ifunc = 0;\n"
00161 " // loop over all the QM atoms\n"
00162 " for (at = 0; at < numatoms; at++) {\n"
00163 " // calculate distance between grid point and center of atom\n"
00164
00165
00166 " float xdist = (grid_x - const_atompos[3*at ])*ANGS_TO_BOHR;\n"
00167 " float ydist = (grid_y - const_atompos[3*at+1])*ANGS_TO_BOHR;\n"
00168 " float zdist = (grid_z - const_atompos[3*at+2])*ANGS_TO_BOHR;\n"
00169 " float xdist2 = xdist*xdist;\n"
00170 " float ydist2 = ydist*ydist;\n"
00171 " float zdist2 = zdist*zdist;\n"
00172 " float dist2 = xdist2 + ydist2 + zdist2;\n"
00173 " float contracted_gto=0.0f;\n"
00174 " float tmpshell=0.0f;\n"
00175 "\n"
00176 );
00177
00178 #if 0
00179
00180 for (at=0; at<numatoms; at++) {
00181 #else
00182
00183 for (at=0; at<1; at++) {
00184 #endif
00185 int maxshell = num_shells_per_atom[at];
00186 int prim_counter = atom_basis[at];
00187
00188
00189 for (shell=0; shell < maxshell; shell++) {
00190
00191
00192 int maxprim = num_prim_per_shell[shell_counter];
00193 int shelltype = shell_types[shell_counter];
00194 for (prim=0; prim<maxprim; prim++) {
00195 float exponent = basis_array[prim_counter ];
00196 float contract_coeff = basis_array[prim_counter + 1];
00197 #if 1
00198 if (jitlanguage == ORBITAL_JIT_CUDA) {
00199 if (prim == 0) {
00200 fprintf(ofp," contracted_gto = %ff * exp2f(-%ff*dist2);\n",
00201 contract_coeff, exponent);
00202 } else {
00203 fprintf(ofp," contracted_gto += %ff * exp2f(-%ff*dist2);\n",
00204 contract_coeff, exponent);
00205 }
00206 } else if (jitlanguage == ORBITAL_JIT_OPENCL) {
00207 if (prim == 0) {
00208 fprintf(ofp," contracted_gto = %ff * native_exp2(-%ff*dist2);\n",
00209 contract_coeff, exponent);
00210 } else {
00211 fprintf(ofp," contracted_gto += %ff * native_exp2(-%ff*dist2);\n",
00212 contract_coeff, exponent);
00213 }
00214 }
00215 #else
00216 if (jitlanguage == ORBITAL_JIT_CUDA) {
00217 if (prim == 0) {
00218 fprintf(ofp," contracted_gto = %ff * expf(-%ff*dist2);\n",
00219 contract_coeff, exponent);
00220 } else {
00221 fprintf(ofp," contracted_gto += %ff * expf(-%ff*dist2);\n",
00222 contract_coeff, exponent);
00223 }
00224 } else if (jitlanguage == ORBITAL_JIT_OPENCL) {
00225 if (prim == 0) {
00226 fprintf(ofp," contracted_gto = %ff * native_exp(-%ff*dist2);\n",
00227 contract_coeff, exponent);
00228 } else {
00229 fprintf(ofp," contracted_gto += %ff * native_exp(-%ff*dist2);\n",
00230 contract_coeff, exponent);
00231 }
00232 }
00233 #endif
00234 prim_counter += 2;
00235 }
00236
00237
00238 switch (shelltype) {
00239 case S_SHELL:
00240 fprintf(ofp,
00241 " // S_SHELL\n"
00242 " value += const_wave_f[ifunc++] * contracted_gto;\n");
00243 break;
00244
00245 case P_SHELL:
00246 fprintf(ofp,
00247 " // P_SHELL\n"
00248 " tmpshell = const_wave_f[ifunc++] * xdist;\n"
00249 " tmpshell += const_wave_f[ifunc++] * ydist;\n"
00250 " tmpshell += const_wave_f[ifunc++] * zdist;\n"
00251 " value += tmpshell * contracted_gto;\n"
00252 );
00253 break;
00254
00255 case D_SHELL:
00256 fprintf(ofp,
00257 " // D_SHELL\n"
00258 " tmpshell = const_wave_f[ifunc++] * xdist2;\n"
00259 " tmpshell += const_wave_f[ifunc++] * xdist * ydist;\n"
00260 " tmpshell += const_wave_f[ifunc++] * ydist2;\n"
00261 " tmpshell += const_wave_f[ifunc++] * xdist * zdist;\n"
00262 " tmpshell += const_wave_f[ifunc++] * ydist * zdist;\n"
00263 " tmpshell += const_wave_f[ifunc++] * zdist2;\n"
00264 " value += tmpshell * contracted_gto;\n"
00265 );
00266 break;
00267
00268 case F_SHELL:
00269 fprintf(ofp,
00270 " // F_SHELL\n"
00271 " tmpshell = const_wave_f[ifunc++] * xdist2 * xdist;\n"
00272 " tmpshell += const_wave_f[ifunc++] * xdist2 * ydist;\n"
00273 " tmpshell += const_wave_f[ifunc++] * ydist2 * xdist;\n"
00274 " tmpshell += const_wave_f[ifunc++] * ydist2 * ydist;\n"
00275 " tmpshell += const_wave_f[ifunc++] * xdist2 * zdist;\n"
00276 " tmpshell += const_wave_f[ifunc++] * xdist * ydist * zdist;\n"
00277 " tmpshell += const_wave_f[ifunc++] * ydist2 * zdist;\n"
00278 " tmpshell += const_wave_f[ifunc++] * zdist2 * xdist;\n"
00279 " tmpshell += const_wave_f[ifunc++] * zdist2 * ydist;\n"
00280 " tmpshell += const_wave_f[ifunc++] * zdist2 * zdist;\n"
00281 " value += tmpshell * contracted_gto;\n"
00282 );
00283 break;
00284
00285 case G_SHELL:
00286 fprintf(ofp,
00287 " // G_SHELL\n"
00288 " tmpshell = const_wave_f[ifunc++] * xdist2 * xdist2;\n"
00289 " tmpshell += const_wave_f[ifunc++] * xdist2 * xdist * ydist;\n"
00290 " tmpshell += const_wave_f[ifunc++] * xdist2 * ydist2;\n"
00291 " tmpshell += const_wave_f[ifunc++] * ydist2 * ydist * xdist;\n"
00292 " tmpshell += const_wave_f[ifunc++] * ydist2 * ydist2;\n"
00293 " tmpshell += const_wave_f[ifunc++] * xdist2 * xdist * zdist;\n"
00294 " tmpshell += const_wave_f[ifunc++] * xdist2 * ydist * zdist;\n"
00295 " tmpshell += const_wave_f[ifunc++] * ydist2 * xdist * zdist;\n"
00296 " tmpshell += const_wave_f[ifunc++] * ydist2 * ydist * zdist;\n"
00297 " tmpshell += const_wave_f[ifunc++] * xdist2 * zdist2;\n"
00298 " tmpshell += const_wave_f[ifunc++] * zdist2 * xdist * ydist;\n"
00299 " tmpshell += const_wave_f[ifunc++] * ydist2 * zdist2;\n"
00300 " tmpshell += const_wave_f[ifunc++] * zdist2 * zdist * xdist;\n"
00301 " tmpshell += const_wave_f[ifunc++] * zdist2 * zdist * ydist;\n"
00302 " tmpshell += const_wave_f[ifunc++] * zdist2 * zdist2;\n"
00303 " value += tmpshell * contracted_gto;\n"
00304 );
00305 break;
00306
00307 }
00308 fprintf(ofp, "\n");
00309
00310 shell_counter++;
00311 }
00312 }
00313
00314 fprintf(ofp,
00315 " }\n"
00316 "\n"
00317 " // return either orbital density or orbital wavefunction amplitude \n"
00318 " if (density) { \n"
00319 );
00320
00321 if (jitlanguage == ORBITAL_JIT_CUDA) {
00322 fprintf(ofp, " orbitalgrid[outaddr] = copysignf(value*value, value);\n");
00323 } else if (jitlanguage == ORBITAL_JIT_OPENCL) {
00324 fprintf(ofp, " orbitalgrid[outaddr] = copysign(value*value, value);\n");
00325 }
00326
00327 fprintf(ofp,
00328 " } else { \n"
00329 " orbitalgrid[outaddr] = value; \n"
00330 " }\n"
00331 "}\n"
00332 );
00333
00334 if (ofp != stdout)
00335 fclose(ofp);
00336
00337 return 0;
00338 }
00339
00340
00341