Main Page   Namespace List   Class Hierarchy   Alphabetical List   Compound List   File List   Namespace Members   Compound Members   File Members   Related Pages  

Orbital_NEON.C

Go to the documentation of this file.
00001 /***************************************************************************
00002  *cr                                                                       
00003  *cr            (C) Copyright 1995-2019 The Board of Trustees of the           
00004  *cr                        University of Illinois                       
00005  *cr                         All Rights Reserved                        
00006  *cr                                                                   
00007  ***************************************************************************/
00008 /***************************************************************************
00009  * RCS INFORMATION:
00010  *
00011  *      $RCSfile: Orbital_NEON.C,v $
00012  *      $Author: johns $        $Locker:  $             $State: Exp $
00013  *      $Revision: 1.3 $        $Date: 2022/05/06 03:43:05 $
00014  *
00015  ***************************************************************************/
00021 //
00022 // Notes:
00023 //   ARM intrinsics documentation:
00024 //     https://developer.arm.com/architectures/instruction-sets/intrinsics
00025 //
00026 
00027 
00028 // Due to differences in code generation between gcc/intelc/clang/msvc, we
00029 // don't have to check for a defined(__NEON__)
00030 #if defined(VMDCPUDISPATCH) && defined(VMDUSENEON) 
00031 #include <arm_neon.h>
00032 
00033 #include <math.h>
00034 #include <stdio.h>
00035 #include "Orbital.h"
00036 #include "DrawMolecule.h"
00037 #include "utilities.h"
00038 #include "Inform.h"
00039 #include "WKFThreads.h"
00040 #include "WKFUtils.h"
00041 #include "ProfileHooks.h"
00042 
/* Length conversion factor: one Angstrom expressed in Bohr (atomic units). */
#define ANGS_TO_BOHR 1.88972612478289694072f

/* Negated log2(e): x * MLOG2EF == -(x * log2(e)), the base-2 exponent of exp(-x)... 
 * used by aexpfnxneon() to rewrite exp(x) as a power of two. */
#define MLOG2EF    -1.44269504088896f

/* Shorthand for the GCC/Clang alignment attribute (not guarded by a
 * compiler check).  NOTE(review): names beginning with "__" are reserved
 * for the implementation. */
#define __align(X)  __attribute__((aligned(X) ))
00050 
#if 0
// Debugging helpers: print the four lanes of a NEON register on stdout.
// Disabled by default; switch the #if to 1 locally when debugging.

// Print four float lanes with %g.
static void print_float32x4_t(float32x4_t v) {
  __attribute__((aligned(16))) float tmp[4]; // 16-byte aligned for NEON
  vst1q_f32(&tmp[0], v);

  printf("print_float32x4_t: ");
  for (int i=0; i<4; i++)
    printf("%g ", tmp[i]);
  printf("\n");
}

// Print four 32-bit signed integer lanes in decimal.
static void print_int32x4_t(int32x4_t v) {
  __attribute__((aligned(16))) int tmp[4]; // 16-byte aligned for NEON
  vst1q_s32(&tmp[0], v);

  printf("print_int32x4_t: ");
  for (int i=0; i<4; i++)
    printf("%d ", tmp[i]);
  printf("\n");
}

// Print four 32-bit lanes as zero-padded hex (useful for comparison masks).
static void print_hex32x4_t(int32x4_t v) {
  __attribute__((aligned(16))) int tmp[4]; // 16-byte aligned for NEON
  vst1q_s32(&tmp[0], v);

  printf("print_hex32x4_t: ");
  for (int i=0; i<4; i++)
    printf("%08x ", (unsigned int) tmp[i]); // %x requires an unsigned int;
                                            // mask lanes are negative as int
  printf("\n");
}
#endif
00085 
00086 
//
// John Stone, February 2021
//
// aexpfnxneon() - NEON version of aexpfnx().
//

/*
 * Coefficients of the degree-4 polynomial that aexpfnxneon() evaluates
 * (Horner form) to approximate 2^d over one unit interval of d.  Per the
 * original derivation, it linearly blends the 3rd degree Taylor expansions
 * of 2^x about 0 and -1; note P(0) = 1 and P(-1) = 0.5.
 */
#define SCEXP0     1.0000000000000000f   /* d^0 term */
#define SCEXP1     0.6987082824680118f   /* d^1 term */
#define SCEXP2     0.2633174272827404f   /* d^2 term */
#define SCEXP3     0.0923611991471395f   /* d^3 term */
#define SCEXP4     0.0277520543324108f   /* d^4 term */

/* IEEE-754 single precision: bit position of the exponent field and the
 * exponent bias, used to build 2^N directly from integer bits. */
#define EXPOSHIFT   23
#define EXPOBIAS   127

/* aexpfnxneon() returns 0 for arguments below this value.  The cutoff is
 * optional, but it lets whole vectors of negligible terms skip the work. */
#define ACUTOFF    -10
00109 
// One 128-bit NEON register viewed either as four float lanes or as four
// 32-bit integer lanes.
union NEONreg_t {
  float32x4_t f;  // float lanes
  int32x4_t   i;  // 32-bit integer lanes
};
typedef union NEONreg_t NEONreg;
00114 
//
// aexpfnxneon() - approximate exp(x) for four single-precision lanes.
//
// exp(x) = 2^(x*log2(e)).  With mb = x*MLOG2EF = -x*log2(e), N = trunc(mb)
// and d = N - mb, the result is 2^d * 2^(-N):
//   - 2^d is approximated by the SCEXP0..SCEXP4 polynomial (Horner form);
//   - 2^(-N) is built exactly by writing the biased exponent (EXPOBIAS - N)
//     into the IEEE-754 exponent field.
// vcvtq_s32_f32 truncates toward zero, which equals floor(mb) only when
// mb >= 0, i.e. x <= 0.  Callers in this file pass x = -exponent*dist2
// (presumably <= 0 for positive Gaussian exponents; TODO confirm).
//
// Returns the four approximations.  Lanes with x < ACUTOFF (and NaN lanes)
// are forced to 0.0f.  If every lane is below ACUTOFF, zeros are returned
// without evaluating the polynomial.
//
float32x4_t aexpfnxneon(float32x4_t x) {
  // Early exit.  NEON has no x86-style movemask, so reduce with pairwise
  // maxima.  The maximum is below the cutoff only if every lane is.
  float32x2_t hmax = vpmax_f32(vget_low_f32(x), vget_high_f32(x));
  hmax = vpmax_f32(hmax, hmax);
  if (vget_lane_f32(hmax, 0) < ACUTOFF) {
    return vdupq_n_f32(0.0f);
  }

  // All-ones bits in lanes with x >= ACUTOFF, all-zeros elsewhere.  This is
  // a bit mask and must stay one.  Value-converting it with vcvtq_f32_u32
  // would turn 0xFFFFFFFF into 4294967296.0f (bits 0x4F800000) and corrupt
  // the scale factor it is ANDed with below.
  uint32x4_t inrange = vcgeq_f32(x, vdupq_n_f32(ACUTOFF));

  // Range reduction: mb = N + f, d = N - mb = -f.
  float32x4_t mb = vmulq_f32(x, vdupq_n_f32(MLOG2EF));
  int32x4_t n = vcvtq_s32_f32(mb);                 // truncates toward zero
  float32x4_t d = vsubq_f32(vcvtq_f32_s32(n), mb);

  // Approximate 2^d by Horner's method on the interpolating polynomial.
  float32x4_t y;
#if defined(__ARM_FEATURE_FMA)
  y = vfmaq_f32(vdupq_n_f32(SCEXP3), vdupq_n_f32(SCEXP4), d);
  y = vfmaq_f32(vdupq_n_f32(SCEXP2), d, y);
  y = vfmaq_f32(vdupq_n_f32(SCEXP1), d, y);
  y = vfmaq_f32(vdupq_n_f32(SCEXP0), d, y);
#else
  y = vmulq_f32(d, vdupq_n_f32(SCEXP4));    /* for x^4 term */
  y = vaddq_f32(y, vdupq_n_f32(SCEXP3));    /* for x^3 term */
  y = vmulq_f32(y, d);
  y = vaddq_f32(y, vdupq_n_f32(SCEXP2));    /* for x^2 term */
  y = vmulq_f32(y, d);
  y = vaddq_f32(y, vdupq_n_f32(SCEXP1));    /* for x^1 term */
  y = vmulq_f32(y, d);
  y = vaddq_f32(y, vdupq_n_f32(SCEXP0));    /* for x^0 term */
#endif

  // 2^(-N) exactly: biased exponent (EXPOBIAS - N) shifted into the exponent
  // field.  Then clear out-of-cutoff lanes with the mask, bitwise.
  int32x4_t expbits = vshlq_n_s32(vsubq_s32(vdupq_n_s32(EXPOBIAS), n), EXPOSHIFT);
  uint32x4_t scalebits = vandq_u32(inrange, vreinterpretq_u32_s32(expbits));
  return vmulq_f32(y, vreinterpretq_f32_u32(scalebits));
}
00178 
00179 
//
// NEON (ARM Advanced SIMD) implementation of the orbital grid kernel.
//

// acc + a*b per lane.  Uses a fused multiply-add where the target provides
// one (__ARM_FEATURE_FMA).  Otherwise it uses a separate multiply and add,
// so the kernel also builds on NEON targets without FMA.
static inline float32x4_t neon_madd(float32x4_t acc, float32x4_t a,
                                    float32x4_t b) {
#if defined(__ARM_FEATURE_FMA)
  return vfmaq_f32(acc, a, b);
#else
  return vmlaq_f32(acc, a, b);
#endif
}

// x^p per lane by repeated multiplication; returns 1 for p <= 0.
static inline float32x4_t neon_ipow(float32x4_t x, int p) {
  float32x4_t r = vdupq_n_f32(1.0f);
  for (int k = 0; k < p; k++)
    r = vmulq_f32(r, x);
  return r;
}

// value*|value| per lane: -value^2 where value < 0, +value^2 otherwise.
// The comparison yields a lane bit mask, so it drives a bitwise select.
// Converting it to float (vcvtq_f32_u32) would give 4294967296.0f, not a
// usable mask.
static inline float32x4_t neon_signed_square(float32x4_t value) {
  float32x4_t sq = vmulq_f32(value, value);
  uint32x4_t isneg = vcltq_f32(value, vdupq_n_f32(0.0f));
  return vbslq_f32(isneg, vnegq_f32(sq), sq);
}

// Store the first 'count' lanes of v to dst (all four when count >= 4).
// This keeps the final, partial vector of a grid row from writing past the
// row end.
static inline void neon_store_lanes(float *dst, float32x4_t v, int count) {
  if (count >= 4) {
    vst1q_f32(dst, v);
    return;
  }
  float tmp[4];
  vst1q_f32(tmp, v);
  for (int lane = 0; lane < count; lane++)
    dst[lane] = tmp[lane];
}

//
// evaluate_grid_neon() - evaluate a molecular orbital, or its signed
// density, on a regular grid, four consecutive X grid points at a time.
//
// For each grid point, the contributions of every atom's contracted
// Gaussian basis functions are summed:
//   value += (angular polynomial in x,y,z weighted by wave_f coefficients)
//            * sum over primitives of coeff * exp(-exponent * r^2)
// Distances are converted with ANGS_TO_BOHR.
//
// Parameters (layouts as indexed below):
//   numatoms            - number of atoms
//   wave_f              - wavefunction coefficients, consumed sequentially
//   basis_array         - (exponent, contraction coefficient) pairs
//   atompos             - x,y,z per atom, same length units as origin
//   atom_basis          - per atom, starting index into basis_array
//   num_shells_per_atom - per atom, number of shells
//   num_prim_per_shell  - per shell (running over all atoms), primitive count
//   shell_types         - per shell, S/P/D/F_SHELL.  Other values take the
//                         generic path in the default case.
//   numvoxels           - grid dimensions {x, y, z}
//   voxelsize           - grid spacing
//   origin              - coordinates of grid point (0,0,0)
//   density             - nonzero: store value*|value|; zero: store value
//   orbitalgrid         - output, numvoxels[0]*numvoxels[1]*numvoxels[2]
//                         floats, X fastest, then Y, then Z
// Returns 0 on success, -1 if orbitalgrid is NULL.
//
int evaluate_grid_neon(int numatoms,
                          const float *wave_f, const float *basis_array,
                          const float *atompos,
                          const int *atom_basis,
                          const int *num_shells_per_atom,
                          const int *num_prim_per_shell,
                          const int *shell_types,
                          const int *numvoxels,
                          float voxelsize,
                          const float *origin,
                          int density,
                          float * orbitalgrid) {
  if (!orbitalgrid)
    return -1;

  // X offsets (in Bohr) of the four grid points covered by one vector.
  // Loop invariant, so it is loaded once.
  __attribute__((aligned(16))) float sxdelta[4]; // 16-byte aligned for NEON
  for (int lane = 0; lane < 4; lane++)
    sxdelta[lane] = ((float) lane) * voxelsize * ANGS_TO_BOHR;
  const float32x4_t xdelta = vld1q_f32(&sxdelta[0]);

  // Calculate the value of the orbital at each gridpoint and store it in
  // the orbitalgrid array.
  const int numx = numvoxels[0];
  const int numgridxy = numvoxels[0]*numvoxels[1];
  for (int nz = 0; nz < numvoxels[2]; nz++) {
    float grid_z = origin[2] + nz * voxelsize;
    for (int ny = 0; ny < numvoxels[1]; ny++) {
      float grid_y = origin[1] + ny * voxelsize;
      int gaddrzy = ny*numx + nz*numgridxy;
      for (int nx = 0; nx < numx; nx += 4) {
        float grid_x = origin[0] + nx * voxelsize;

        // value of the selected orbital at the four grid points
        float32x4_t value = vdupq_n_f32(0.0f);

        // wavefunction and shell counters run across all atoms
        int ifunc = 0;
        int shell_counter = 0;

        // loop over all the QM atoms
        for (int at = 0; at < numatoms; at++) {
          int maxshell = num_shells_per_atom[at];
          int prim_counter = atom_basis[at];

          // distance between grid point and center of atom, in Bohr
          float sxdist = (grid_x - atompos[3*at  ])*ANGS_TO_BOHR;
          float sydist = (grid_y - atompos[3*at+1])*ANGS_TO_BOHR;
          float szdist = (grid_z - atompos[3*at+2])*ANGS_TO_BOHR;

          float sydist2 = sydist*sydist;
          float szdist2 = szdist*szdist;
          float yzdist2 = sydist2 + szdist2;

          float32x4_t xdist  = vaddq_f32(vdupq_n_f32(sxdist), xdelta);
          float32x4_t ydist  = vdupq_n_f32(sydist);
          float32x4_t zdist  = vdupq_n_f32(szdist);
          float32x4_t xdist2 = vmulq_f32(xdist, xdist);
          float32x4_t ydist2 = vmulq_f32(ydist, ydist);
          float32x4_t zdist2 = vmulq_f32(zdist, zdist);
          float32x4_t dist2  = vaddq_f32(vdupq_n_f32(yzdist2), xdist2);

          // loop over the shells belonging to this atom
          // XXX this is maybe a misnomer because in split valence
          //     basis sets like 6-31G we have more than one basis
          //     function per (valence-)shell and we are actually
          //     looping over the individual contracted GTOs
          for (int shell = 0; shell < maxshell; shell++) {
            float32x4_t contracted_gto = vdupq_n_f32(0.0f);

            // Sum the Gaussian primitives of this contracted basis function:
            //   contracted_gto += contract_coeff * exp(-exponent*dist2)
            // The exponent is negated on load.
            //
            // XXX there's a significant opportunity here for further
            //     speedup if we replace the entire set of primitives
            //     with the single gaussian that they are attempting
            //     to model.
            int maxprim = num_prim_per_shell[shell_counter];
            int shelltype = shell_types[shell_counter];
            for (int prim = 0; prim < maxprim; prim++) {
              float exponent       = -basis_array[prim_counter    ];
              float contract_coeff =  basis_array[prim_counter + 1];

              float32x4_t expval = vmulq_f32(vdupq_n_f32(exponent), dist2);
              // natural exp() evaluated via a base-2 approximation
              float32x4_t retval = aexpfnxneon(expval);
              contracted_gto = neon_madd(contracted_gto, retval,
                                         vdupq_n_f32(contract_coeff));

              prim_counter += 2;
            }

            // multiply with the angular part and wavefunction coefficients
            float32x4_t tmpshell = vdupq_n_f32(0.0f);
            switch (shelltype) {
              case S_SHELL:
                value = neon_madd(value, contracted_gto, vdupq_n_f32(wave_f[ifunc++]));
                break;

              case P_SHELL:
                tmpshell = neon_madd(tmpshell, xdist, vdupq_n_f32(wave_f[ifunc++]));
                tmpshell = neon_madd(tmpshell, ydist, vdupq_n_f32(wave_f[ifunc++]));
                tmpshell = neon_madd(tmpshell, zdist, vdupq_n_f32(wave_f[ifunc++]));
                value = neon_madd(value, contracted_gto, tmpshell);
                break;

              case D_SHELL:
                tmpshell = neon_madd(tmpshell, xdist2, vdupq_n_f32(wave_f[ifunc++]));
                tmpshell = neon_madd(tmpshell, vmulq_f32(xdist, ydist), vdupq_n_f32(wave_f[ifunc++]));
                tmpshell = neon_madd(tmpshell, ydist2, vdupq_n_f32(wave_f[ifunc++]));
                tmpshell = neon_madd(tmpshell, vmulq_f32(xdist, zdist), vdupq_n_f32(wave_f[ifunc++]));
                tmpshell = neon_madd(tmpshell, vmulq_f32(ydist, zdist), vdupq_n_f32(wave_f[ifunc++]));
                tmpshell = neon_madd(tmpshell, zdist2, vdupq_n_f32(wave_f[ifunc++]));
                value = neon_madd(value, contracted_gto, tmpshell);
                break;

              case F_SHELL:
                tmpshell = neon_madd(tmpshell, vmulq_f32(xdist2, xdist), vdupq_n_f32(wave_f[ifunc++]));
                tmpshell = neon_madd(tmpshell, vmulq_f32(xdist2, ydist), vdupq_n_f32(wave_f[ifunc++]));
                tmpshell = neon_madd(tmpshell, vmulq_f32(ydist2, xdist), vdupq_n_f32(wave_f[ifunc++]));
                tmpshell = neon_madd(tmpshell, vmulq_f32(ydist2, ydist), vdupq_n_f32(wave_f[ifunc++]));
                tmpshell = neon_madd(tmpshell, vmulq_f32(xdist2, zdist), vdupq_n_f32(wave_f[ifunc++]));
                tmpshell = neon_madd(tmpshell, vmulq_f32(vmulq_f32(xdist, ydist), zdist), vdupq_n_f32(wave_f[ifunc++]));
                tmpshell = neon_madd(tmpshell, vmulq_f32(ydist2, zdist), vdupq_n_f32(wave_f[ifunc++]));
                tmpshell = neon_madd(tmpshell, vmulq_f32(zdist2, xdist), vdupq_n_f32(wave_f[ifunc++]));
                tmpshell = neon_madd(tmpshell, vmulq_f32(zdist2, ydist), vdupq_n_f32(wave_f[ifunc++]));
                tmpshell = neon_madd(tmpshell, vmulq_f32(zdist2, zdist), vdupq_n_f32(wave_f[ifunc++]));
                value = neon_madd(value, contracted_gto, tmpshell);
                break;

              default: {
                // Generic Cartesian shell, with shelltype taken as the angular
                // momentum L.  NOTE(review): assumes S_SHELL..F_SHELL are 0..3,
                // as the previously disabled scalar code implied; confirm.
                // Term order follows that scalar code: z power j = 0..L, y power
                // i = 0..L-j, x power L-j-i.  Skipping these terms would
                // misalign wave_f for every later shell.  Powers are built by
                // multiplication, so xdist == 0 needs no division.
                float32x4_t zdp = vdupq_n_f32(1.0f);
                for (int j = 0; j <= shelltype; j++) {
                  int imax = shelltype - j;
                  float32x4_t ydp = vdupq_n_f32(1.0f);
                  for (int i = 0; i <= imax; i++) {
                    float32x4_t term = vmulq_f32(vmulq_f32(neon_ipow(xdist, imax - i), ydp), zdp);
                    tmpshell = neon_madd(tmpshell, term, vdupq_n_f32(wave_f[ifunc++]));
                    ydp = vmulq_f32(ydp, ydist);
                  }
                  zdp = vmulq_f32(zdp, zdist);
                }
                value = neon_madd(value, contracted_gto, tmpshell);
                break;
              }
            } // end switch

            shell_counter++;
          } // end shell
        } // end atom

        // store either the signed orbital density or the amplitude; only the
        // lanes that fall inside this grid row are written
        float32x4_t result = density ? neon_signed_square(value) : value;
        neon_store_lanes(&orbitalgrid[gaddrzy + nx], result, numx - nx);
      }
    }
  }

  return 0;
}
00361 
00362 #endif
00363 
00364 

Generated on Tue Apr 30 02:44:35 2024 for VMD (current) by doxygen1.2.14 written by Dimitri van Heesch, © 1997-2002