Main Page   Namespace List   Class Hierarchy   Alphabetical List   Compound List   File List   Namespace Members   Compound Members   File Members   Related Pages  

Orbital_AVX512ER.C

Go to the documentation of this file.
00001 /***************************************************************************
00002  *cr                                                                       
00003  *cr            (C) Copyright 1995-2019 The Board of Trustees of the           
00004  *cr                        University of Illinois                       
00005  *cr                         All Rights Reserved                        
00006  *cr                                                                   
00007  ***************************************************************************/
00008 /***************************************************************************
00009  * RCS INFORMATION:
00010  *
00011  *      $RCSfile: Orbital_AVX512ER.C,v $
00012  *      $Author: johns $        $Locker:  $             $State: Exp $
00013  *      $Revision: 1.3 $        $Date: 2020/10/27 04:18:28 $
00014  *
00015  ***************************************************************************/
00021 // Due to differences in code generation between gcc/intelc/clang/msvc, we
00022 // don't have to check for a (defined(__AVX512F__) && defined(__AVX512ER__))
00023 #if defined(VMDCPUDISPATCH) && defined(VMDUSEAVX512) 
00024 
00025 #include <immintrin.h>
00026 
00027 #include <math.h>
00028 #include <stdio.h>
00029 #include "Orbital.h"
00030 #include "DrawMolecule.h"
00031 #include "utilities.h"
00032 #include "Inform.h"
00033 #include "WKFThreads.h"
00034 #include "WKFUtils.h"
00035 #include "ProfileHooks.h"
00036 
// Conversion factor from Angstroms to Bohr (atomic units of length).
#define ANGS_TO_BOHR 1.88972612478289694072f

// Alignment specifier for variables/arrays, X in bytes.  GCC-compatible
// compilers (gcc, clang) take the attribute form; anything else, including
// the Intel compiler, takes the __declspec form.
#if !defined(__GNUC__) || defined(__INTEL_COMPILER)
#define __align(X) __declspec(align(X) )
#else
#define __align(X)  __attribute__((aligned(X) ))
#endif

// Negated log2(e): exp(-y) == exp2(y * MLOG2EF).
#define MLOG2EF (-1.44269504088896f)
00046 
#if 0
// Debugging aid (compiled out): print the 16 float lanes of an AVX-512
// register to stdout on one line, lane 0 first, prefixed by "mm512: ".
static void print_mm512_ps(__m512 v) {
  float lane[16] __attribute__((aligned(64)));
  _mm512_storeu_ps(lane, v);

  printf("mm512: ");
  for (int idx = 0; idx < 16; ++idx) {
    printf("%g ", lane[idx]);
  }
  printf("\n");
}
#endif
00059 
00060 
00061 
00062 //
00063 // AVX-512ER implementation for Xeon Phi w/ special fctn units
00064 //
00065 int evaluate_grid_avx512er(int numatoms,
00066                            const float *wave_f, const float *basis_array,
00067                            const float *atompos,
00068                            const int *atom_basis,
00069                            const int *num_shells_per_atom,
00070                            const int *num_prim_per_shell,
00071                            const int *shell_types,
00072                            const int *numvoxels,
00073                            float voxelsize,
00074                            const float *origin,
00075                            int density,
00076                            float * orbitalgrid) {
00077   if (!orbitalgrid)
00078     return -1;
00079 
00080   int nx, ny, nz;
00081   __attribute__((aligned(64))) float sxdelta[16]; // 64-byte aligned for AVX512
00082   for (nx=0; nx<16; nx++) 
00083     sxdelta[nx] = ((float) nx) * voxelsize * ANGS_TO_BOHR;
00084 
00085   // Calculate the value of the orbital at each gridpoint and store in 
00086   // the current oribtalgrid array
00087   int numgridxy = numvoxels[0]*numvoxels[1];
00088   for (nz=0; nz<numvoxels[2]; nz++) {
00089     float grid_x, grid_y, grid_z;
00090     grid_z = origin[2] + nz * voxelsize;
00091     for (ny=0; ny<numvoxels[1]; ny++) {
00092       grid_y = origin[1] + ny * voxelsize;
00093       int gaddrzy = ny*numvoxels[0] + nz*numgridxy;
00094       for (nx=0; nx<numvoxels[0]; nx+=16) {
00095         grid_x = origin[0] + nx * voxelsize;
00096 
00097         // calculate the value of the wavefunction of the
00098         // selected orbital at the current grid point
00099         int at;
00100         int prim, shell;
00101 
00102         // initialize value of orbital at gridpoint
00103         __m512 value = _mm512_set1_ps(0.0f);
00104 
00105         // initialize the wavefunction and shell counters
00106         int ifunc = 0; 
00107         int shell_counter = 0;
00108 
00109         // loop over all the QM atoms
00110         for (at=0; at<numatoms; at++) {
00111           int maxshell = num_shells_per_atom[at];
00112           int prim_counter = atom_basis[at];
00113 
00114           // calculate distance between grid point and center of atom
00115           float sxdist = (grid_x - atompos[3*at  ])*ANGS_TO_BOHR;
00116           float sydist = (grid_y - atompos[3*at+1])*ANGS_TO_BOHR;
00117           float szdist = (grid_z - atompos[3*at+2])*ANGS_TO_BOHR;
00118 
00119           float sydist2 = sydist*sydist;
00120           float szdist2 = szdist*szdist;
00121           float yzdist2 = sydist2 + szdist2;
00122 
00123           __m512 xdelta = _mm512_load_ps(&sxdelta[0]); // aligned load
00124           __m512 xdist  = _mm512_set1_ps(sxdist);
00125           xdist = _mm512_add_ps(xdist, xdelta);
00126           __m512 ydist  = _mm512_set1_ps(sydist);
00127           __m512 zdist  = _mm512_set1_ps(szdist);
00128           __m512 xdist2 = _mm512_mul_ps(xdist, xdist);
00129           __m512 ydist2 = _mm512_mul_ps(ydist, ydist);
00130           __m512 zdist2 = _mm512_mul_ps(zdist, zdist);
00131           __m512 dist2  = _mm512_set1_ps(yzdist2); 
00132           dist2 = _mm512_add_ps(dist2, xdist2);
00133  
00134           // loop over the shells belonging to this atom
00135           // XXX this is maybe a misnomer because in split valence
00136           //     basis sets like 6-31G we have more than one basis
00137           //     function per (valence-)shell and we are actually
00138           //     looping over the individual contracted GTOs
00139           for (shell=0; shell < maxshell; shell++) {
00140             __m512 contracted_gto = _mm512_set1_ps(0.0f);
00141 
00142             // Loop over the Gaussian primitives of this contracted 
00143             // basis function to build the atomic orbital
00144             // 
00145             // XXX there's a significant opportunity here for further
00146             //     speedup if we replace the entire set of primitives
00147             //     with the single gaussian that they are attempting 
00148             //     to model.  This could give us another 6x speedup in 
00149             //     some of the common/simple cases.
00150             int maxprim = num_prim_per_shell[shell_counter];
00151             int shelltype = shell_types[shell_counter];
00152             for (prim=0; prim<maxprim; prim++) {
00153               // XXX pre-negate exponent value
00154               float exponent       = -basis_array[prim_counter    ];
00155               float contract_coeff =  basis_array[prim_counter + 1];
00156 
00157               // contracted_gto += contract_coeff * exp(-exponent*dist2);
00158 #if 1
00159               __m512 expval = _mm512_mul_ps(_mm512_set1_ps(-exponent * MLOG2EF), dist2);
00160               // expf() equivalent required, use base-2 AVX-512ER instructions
00161               __m512 retval = _mm512_exp2a23_ps(expval);
00162               contracted_gto = _mm512_fmadd_ps(_mm512_set1_ps(contract_coeff), retval, contracted_gto);
00163 #else
00164               __m512 expval = _mm512_mul_ps(_mm512_set1_ps(-exponent), dist2);
00165               // expf() equivalent required, use base-2 AVX-512ER instructions
00166               expval = _mm512_mul_ps(expval, _mm512_set1_ps(MLOG2EF));
00167               __m512 retval = _mm512_exp2a23_ps(expval);
00168               __m512 ctmp = _mm512_mul_ps(_mm512_set1_ps(contract_coeff), retval);
00169               contracted_gto = _mm512_add_ps(contracted_gto, ctmp);
00170 #endif
00171 
00172               prim_counter += 2;
00173             }
00174 
00175             /* multiply with the appropriate wavefunction coefficient */
00176             __m512 tmpshell = _mm512_set1_ps(0.0f);
00177             switch (shelltype) {
00178               // use FMADD instructions
00179               case S_SHELL:
00180                 value = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), contracted_gto, value);
00181                 break;
00182 
00183               case P_SHELL:
00184                 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), xdist, tmpshell);
00185                 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), ydist, tmpshell);
00186                 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), zdist, tmpshell);
00187                 value = _mm512_fmadd_ps(tmpshell, contracted_gto, value);
00188                 break;
00189 
00190               case D_SHELL:
00191                 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), xdist2, tmpshell);
00192                 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(xdist, ydist), tmpshell);
00193                 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), ydist2, tmpshell);
00194                 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(xdist, zdist), tmpshell);
00195                 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(ydist, zdist), tmpshell);
00196                 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), zdist2, tmpshell);
00197                 value = _mm512_fmadd_ps(tmpshell, contracted_gto, value);
00198                 break;
00199 
00200               case F_SHELL:
00201                 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(xdist2, xdist), tmpshell);
00202                 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(xdist2, ydist), tmpshell);
00203                 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(ydist2, xdist), tmpshell);
00204                 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(ydist2, ydist), tmpshell);
00205                 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(xdist2, zdist), tmpshell);
00206                 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(_mm512_mul_ps(xdist, ydist), zdist), tmpshell);
00207                 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(ydist2, zdist), tmpshell);
00208                 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(zdist2, xdist), tmpshell);
00209                 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(zdist2, ydist), tmpshell);
00210                 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(zdist2, zdist), tmpshell);
00211                 value = _mm512_fmadd_ps(tmpshell, contracted_gto, value);
00212                 break;
00213 
00214  
00215 #if 0
00216               default:
00217                 // avoid unnecessary branching and minimize use of pow()
00218                 int i, j; 
00219                 float xdp, ydp, zdp;
00220                 float xdiv = 1.0f / xdist;
00221                 for (j=0, zdp=1.0f; j<=shelltype; j++, zdp*=zdist) {
00222                   int imax = shelltype - j; 
00223                   for (i=0, ydp=1.0f, xdp=pow(xdist, imax); i<=imax; i++, ydp*=ydist, xdp*=xdiv) {
00224                     tmpshell += wave_f[ifunc++] * xdp * ydp * zdp;
00225                   }
00226                 }
00227                 value += tmpshell * contracted_gto;
00228 #endif
00229             } // end switch
00230 
00231             shell_counter++;
00232           } // end shell
00233         } // end atom
00234 
00235         // return either orbital density or orbital wavefunction amplitude
00236         if (density) {
00237           __mmask16 mask = _mm512_cmplt_ps_mask(value, _mm512_set1_ps(0.0f));
00238           __m512 sqdensity = _mm512_mul_ps(value, value);
00239           __m512 orbdensity = _mm512_mask_mul_ps(sqdensity, mask, sqdensity,
00240                                                  _mm512_set1_ps(-1.0f));
00241           _mm512_storeu_ps(&orbitalgrid[gaddrzy + nx], orbdensity);
00242         } else {
00243           _mm512_storeu_ps(&orbitalgrid[gaddrzy + nx], value);
00244         }
00245       }
00246     }
00247   }
00248 
00249   // XXX note this is costly on Xeon Phi, but since it's a dead platform,
00250   // we'll write this for the benefit of a someday Xeon that supports the
00251   // Exponential/Reciprocal AVX-512ER instruction subset...
00252   //
00253   // Prevent x86 AVX-512 clock rate limiting performance loss due to 
00254   // false dependence on upper vector register state for scalar or 
00255   // SSE instructions executing after an AVX-512 instruction has written
00256   // an upper register. 
00257   _mm256_zeroupper();
00258 
00259   return 0;
00260 }
00261 
00262 #endif
00263 
00264 

Generated on Mon Apr 15 02:45:43 2024 for VMD (current) by doxygen1.2.14 written by Dimitri van Heesch, © 1997-2002