00001 /*************************************************************************** 00002 *cr 00003 *cr (C) Copyright 1995-2019 The Board of Trustees of the 00004 *cr University of Illinois 00005 *cr All Rights Reserved 00006 *cr 00007 ***************************************************************************/ 00008 /*************************************************************************** 00009 * RCS INFORMATION: 00010 * 00011 * $RCSfile: Orbital_AVX2.C,v $ 00012 * $Author: johns $ $Locker: $ $State: Exp $ 00013 * $Revision: 1.1 $ $Date: 2021年02月16日 17:24:26 $ 00014 * 00015 ***************************************************************************/ 00021 // Due to differences in code generation between gcc/intelc/clang/msvc, we 00022 // don't have to check for a defined(__AVX2__) 00023 #if defined(VMDCPUDISPATCH) && defined(VMDUSEAVX2) 00024 00025 #include <immintrin.h> 00026 00027 #include <math.h> 00028 #include <stdio.h> 00029 #include "Orbital.h" 00030 #include "DrawMolecule.h" 00031 #include "utilities.h" 00032 #include "Inform.h" 00033 #include "WKFThreads.h" 00034 #include "WKFUtils.h" 00035 #include "ProfileHooks.h" 00036 00037 #define ANGS_TO_BOHR 1.88972612478289694072f 00038 00039 #if defined(__GNUC__) && ! defined(__INTEL_COMPILER) 00040 #define __align(X) __attribute__((aligned(X) )) 00041 #else 00042 #define __align(X) __declspec(align(X) ) 00043 #endif 00044 00045 #define MLOG2EF -1.44269504088896f 00046 00047 #if 0 00048 static void print_mm256_ps(__m256 v) { 00049 __attribute__((aligned(32))) float tmp[8]; // 32-byte aligned for AVX2 00050 _mm256_storeu_ps(&tmp[0], v); 00051 00052 printf("mm256: "); 00053 int i; 00054 for (i=0; i<8; i++) 00055 printf("%g ", tmp[i]); 00056 printf("\n"); 00057 } 00058 #endif 00059 00060 00061 // 00062 // John Stone, January 2021 00063 // 00064 // aexpfnxavx2() - AVX2 version of aexpfnx(). 00065 // 00066 00067 /* 00068 * Interpolating coefficients for linear blending of the 00069 * 3rd degree Taylor expansion of 2^x about 0 and -1. 00070 */ 00071 #define SCEXP0 1.0000000000000000f 00072 #define SCEXP1 0.6987082824680118f 00073 #define SCEXP2 0.2633174272827404f 00074 #define SCEXP3 0.0923611991471395f 00075 #define SCEXP4 0.0277520543324108f 00076 00077 /* for single precision float */ 00078 #define EXPOBIAS 127 00079 #define EXPOSHIFT 23 00080 00081 /* cutoff is optional, but can help avoid unnecessary work */ 00082 #define ACUTOFF -10 00083 00084 typedef union AVX2reg_t { 00085 __m256 f; // 8x float (AVX) 00086 __m256i i; // 8x 32-bit int (AVX2) 00087 } AVX2reg; 00088 00089 __m256 aexpfnxavx2(__m256 x) { 00090 __align(32) AVX2reg scal; 00091 scal.f = _mm256_cmp_ps(x, _mm256_set1_ps(ACUTOFF), _CMP_GE_OQ); // Is x within cutoff? 00092 // If all x are outside of cutoff, return 0s. 00093 if (_mm256_movemask_ps(scal.f) == 0) { 00094 return _mm256_set1_ps(0.0f); 00095 } 00096 // Otherwise, scal.f contains mask to be ANDed with the scale factor 00097 00098 /* 00099 * Convert base: exp(x) = 2^(N-d) where N is integer and 0 <= d < 1. 00100 * 00101 * Below we calculate n=N and x=-d, with "y" for temp storage, 00102 * calculate floor of x*log2(e) and subtract to get -d. 00103 */ 00104 __align(32) AVX2reg n; 00105 __m256 mb = _mm256_mul_ps(x, _mm256_set1_ps(MLOG2EF)); 00106 n.i = _mm256_cvttps_epi32(mb); 00107 __m256 mbflr = _mm256_cvtepi32_ps(n.i); 00108 __m256 d = _mm256_sub_ps(mbflr, mb); 00109 00110 // Approximate 2^{-d}, 0 <= d < 1, by interpolation. 00111 // Perform Horner's method to evaluate interpolating polynomial. 00112 __m256 y; 00113 y = _mm256_fmadd_ps(d, _mm256_set1_ps(SCEXP4), _mm256_set1_ps(SCEXP3)); 00114 y = _mm256_fmadd_ps(y, d, _mm256_set1_ps(SCEXP2)); 00115 y = _mm256_fmadd_ps(y, d, _mm256_set1_ps(SCEXP1)); 00116 y = _mm256_fmadd_ps(y, d, _mm256_set1_ps(SCEXP0)); 00117 00118 // Calculate 2^N exactly by directly manipulating floating point exponent, 00119 // then use it to scale y for the final result. 00120 n.i = _mm256_sub_epi32(_mm256_set1_epi32(EXPOBIAS), n.i); 00121 n.i = _mm256_slli_epi32(n.i, EXPOSHIFT); 00122 scal.f = _mm256_and_ps(scal.f, n.f); 00123 y = _mm256_mul_ps(y, scal.f); 00124 00125 return y; 00126 } 00127 00128 00129 // 00130 // AVX2 implementation for Xeons that don't have special fctn units 00131 // 00132 int evaluate_grid_avx2(int numatoms, 00133 const float *wave_f, const float *basis_array, 00134 const float *atompos, 00135 const int *atom_basis, 00136 const int *num_shells_per_atom, 00137 const int *num_prim_per_shell, 00138 const int *shell_types, 00139 const int *numvoxels, 00140 float voxelsize, 00141 const float *origin, 00142 int density, 00143 float * orbitalgrid) { 00144 if (!orbitalgrid) 00145 return -1; 00146 00147 int nx, ny, nz; 00148 __attribute__((aligned(32))) float sxdelta[8]; // 32-byte aligned for AVX2 00149 for (nx=0; nx<8; nx++) 00150 sxdelta[nx] = ((float) nx) * voxelsize * ANGS_TO_BOHR; 00151 00152 // Calculate the value of the orbital at each gridpoint and store in 00153 // the current oribtalgrid array 00154 int numgridxy = numvoxels[0]*numvoxels[1]; 00155 for (nz=0; nz<numvoxels[2]; nz++) { 00156 float grid_x, grid_y, grid_z; 00157 grid_z = origin[2] + nz * voxelsize; 00158 for (ny=0; ny<numvoxels[1]; ny++) { 00159 grid_y = origin[1] + ny * voxelsize; 00160 int gaddrzy = ny*numvoxels[0] + nz*numgridxy; 00161 for (nx=0; nx<numvoxels[0]; nx+=8) { 00162 grid_x = origin[0] + nx * voxelsize; 00163 00164 // calculate the value of the wavefunction of the 00165 // selected orbital at the current grid point 00166 int at; 00167 int prim, shell; 00168 00169 // initialize value of orbital at gridpoint 00170 __m256 value = _mm256_set1_ps(0.0f); 00171 00172 // initialize the wavefunction and shell counters 00173 int ifunc = 0; 00174 int shell_counter = 0; 00175 00176 // loop over all the QM atoms 00177 for (at=0; at<numatoms; at++) { 00178 int maxshell = num_shells_per_atom[at]; 00179 int prim_counter = atom_basis[at]; 00180 00181 // calculate distance between grid point and center of atom 00182 float sxdist = (grid_x - atompos[3*at ])*ANGS_TO_BOHR; 00183 float sydist = (grid_y - atompos[3*at+1])*ANGS_TO_BOHR; 00184 float szdist = (grid_z - atompos[3*at+2])*ANGS_TO_BOHR; 00185 00186 float sydist2 = sydist*sydist; 00187 float szdist2 = szdist*szdist; 00188 float yzdist2 = sydist2 + szdist2; 00189 00190 __m256 xdelta = _mm256_load_ps(&sxdelta[0]); // aligned load 00191 __m256 xdist = _mm256_set1_ps(sxdist); 00192 xdist = _mm256_add_ps(xdist, xdelta); 00193 __m256 ydist = _mm256_set1_ps(sydist); 00194 __m256 zdist = _mm256_set1_ps(szdist); 00195 __m256 xdist2 = _mm256_mul_ps(xdist, xdist); 00196 __m256 ydist2 = _mm256_mul_ps(ydist, ydist); 00197 __m256 zdist2 = _mm256_mul_ps(zdist, zdist); 00198 __m256 dist2 = _mm256_set1_ps(yzdist2); 00199 dist2 = _mm256_add_ps(dist2, xdist2); 00200 00201 // loop over the shells belonging to this atom 00202 // XXX this is maybe a misnomer because in split valence 00203 // basis sets like 6-31G we have more than one basis 00204 // function per (valence-)shell and we are actually 00205 // looping over the individual contracted GTOs 00206 for (shell=0; shell < maxshell; shell++) { 00207 __m256 contracted_gto = _mm256_set1_ps(0.0f); 00208 00209 // Loop over the Gaussian primitives of this contracted 00210 // basis function to build the atomic orbital 00211 // 00212 // XXX there's a significant opportunity here for further 00213 // speedup if we replace the entire set of primitives 00214 // with the single gaussian that they are attempting 00215 // to model. This could give us another 6x speedup in 00216 // some of the common/simple cases. 00217 int maxprim = num_prim_per_shell[shell_counter]; 00218 int shelltype = shell_types[shell_counter]; 00219 for (prim=0; prim<maxprim; prim++) { 00220 // XXX pre-negate exponent value 00221 float exponent = -basis_array[prim_counter ]; 00222 float contract_coeff = basis_array[prim_counter + 1]; 00223 00224 // contracted_gto += contract_coeff * exp(-exponent*dist2); 00225 __m256 expval = _mm256_mul_ps(_mm256_set1_ps(exponent), dist2); 00226 // exp2f() equivalent required, use base-2 approximation 00227 __m256 retval = aexpfnxavx2(expval); 00228 contracted_gto = _mm256_fmadd_ps(_mm256_set1_ps(contract_coeff), retval, contracted_gto); 00229 00230 prim_counter += 2; 00231 } 00232 00233 /* multiply with the appropriate wavefunction coefficient */ 00234 __m256 tmpshell = _mm256_set1_ps(0.0f); 00235 switch (shelltype) { 00236 // use FMADD instructions 00237 case S_SHELL: 00238 value = _mm256_fmadd_ps(_mm256_set1_ps(wave_f[ifunc++]), contracted_gto, value); 00239 break; 00240 00241 case P_SHELL: 00242 tmpshell = _mm256_fmadd_ps(_mm256_set1_ps(wave_f[ifunc++]), xdist, tmpshell); 00243 tmpshell = _mm256_fmadd_ps(_mm256_set1_ps(wave_f[ifunc++]), ydist, tmpshell); 00244 tmpshell = _mm256_fmadd_ps(_mm256_set1_ps(wave_f[ifunc++]), zdist, tmpshell); 00245 value = _mm256_fmadd_ps(tmpshell, contracted_gto, value); 00246 break; 00247 00248 case D_SHELL: 00249 tmpshell = _mm256_fmadd_ps(_mm256_set1_ps(wave_f[ifunc++]), xdist2, tmpshell); 00250 tmpshell = _mm256_fmadd_ps(_mm256_set1_ps(wave_f[ifunc++]), _mm256_mul_ps(xdist, ydist), tmpshell); 00251 tmpshell = _mm256_fmadd_ps(_mm256_set1_ps(wave_f[ifunc++]), ydist2, tmpshell); 00252 tmpshell = _mm256_fmadd_ps(_mm256_set1_ps(wave_f[ifunc++]), _mm256_mul_ps(xdist, zdist), tmpshell); 00253 tmpshell = _mm256_fmadd_ps(_mm256_set1_ps(wave_f[ifunc++]), _mm256_mul_ps(ydist, zdist), tmpshell); 00254 tmpshell = _mm256_fmadd_ps(_mm256_set1_ps(wave_f[ifunc++]), zdist2, tmpshell); 00255 value = _mm256_fmadd_ps(tmpshell, contracted_gto, value); 00256 break; 00257 00258 case F_SHELL: 00259 tmpshell = _mm256_fmadd_ps(_mm256_set1_ps(wave_f[ifunc++]), _mm256_mul_ps(xdist2, xdist), tmpshell); 00260 tmpshell = _mm256_fmadd_ps(_mm256_set1_ps(wave_f[ifunc++]), _mm256_mul_ps(xdist2, ydist), tmpshell); 00261 tmpshell = _mm256_fmadd_ps(_mm256_set1_ps(wave_f[ifunc++]), _mm256_mul_ps(ydist2, xdist), tmpshell); 00262 tmpshell = _mm256_fmadd_ps(_mm256_set1_ps(wave_f[ifunc++]), _mm256_mul_ps(ydist2, ydist), tmpshell); 00263 tmpshell = _mm256_fmadd_ps(_mm256_set1_ps(wave_f[ifunc++]), _mm256_mul_ps(xdist2, zdist), tmpshell); 00264 tmpshell = _mm256_fmadd_ps(_mm256_set1_ps(wave_f[ifunc++]), _mm256_mul_ps(_mm256_mul_ps(xdist, ydist), zdist), tmpshell); 00265 tmpshell = _mm256_fmadd_ps(_mm256_set1_ps(wave_f[ifunc++]), _mm256_mul_ps(ydist2, zdist), tmpshell); 00266 tmpshell = _mm256_fmadd_ps(_mm256_set1_ps(wave_f[ifunc++]), _mm256_mul_ps(zdist2, xdist), tmpshell); 00267 tmpshell = _mm256_fmadd_ps(_mm256_set1_ps(wave_f[ifunc++]), _mm256_mul_ps(zdist2, ydist), tmpshell); 00268 tmpshell = _mm256_fmadd_ps(_mm256_set1_ps(wave_f[ifunc++]), _mm256_mul_ps(zdist2, zdist), tmpshell); 00269 value = _mm256_fmadd_ps(tmpshell, contracted_gto, value); 00270 break; 00271 00272 #if 0 00273 default: 00274 // avoid unnecessary branching and minimize use of pow() 00275 int i, j; 00276 float xdp, ydp, zdp; 00277 float xdiv = 1.0f / xdist; 00278 for (j=0, zdp=1.0f; j<=shelltype; j++, zdp*=zdist) { 00279 int imax = shelltype - j; 00280 for (i=0, ydp=1.0f, xdp=pow(xdist, imax); i<=imax; i++, ydp*=ydist, xdp*=xdiv) { 00281 tmpshell += wave_f[ifunc++] * xdp * ydp * zdp; 00282 } 00283 } 00284 value += tmpshell * contracted_gto; 00285 #endif 00286 } // end switch 00287 00288 shell_counter++; 00289 } // end shell 00290 } // end atom 00291 00292 // return either orbital density or orbital wavefunction amplitude 00293 if (density) { 00294 __m256 mask = _mm256_cmp_ps(value, _mm256_set1_ps(0.0f), _CMP_LT_OQ); 00295 __m256 sqdensity = _mm256_mul_ps(value, value); 00296 __m256 orbdensity = sqdensity; 00297 __m256 nsqdensity = _mm256_and_ps(sqdensity, mask); 00298 orbdensity = _mm256_sub_ps(orbdensity, nsqdensity); 00299 orbdensity = _mm256_sub_ps(orbdensity, nsqdensity); 00300 _mm256_storeu_ps(&orbitalgrid[gaddrzy + nx], orbdensity); 00301 } else { 00302 _mm256_storeu_ps(&orbitalgrid[gaddrzy + nx], value); 00303 } 00304 } 00305 } 00306 } 00307 00308 // prevent x86 AVX2 clock rate limiting performance loss due to 00309 // false dependence on upper vector register state for scalar or 00310 // SSE instructions executing after an AVX2 instruction has written 00311 // an upper register. 00312 _mm256_zeroupper(); 00313 00314 return 0; 00315 } 00316 00317 #endif 00318 00319