#include <blas_lapack.h>
#include <eigen_helper.h>

//#define _DEBUG

namespace quda
{
  namespace blas_lapack
  {

    // whether we are using the native blas-lapack library
    static bool native_blas_lapack = true;

    void set_native(bool native) { native_blas_lapack = native; }
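
    // When the native (e.g. vendor BLAS) path is disabled via set_native(false),
    // QUDA falls back to the CPU Eigen routines in the generic namespace below.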

    namespace generic
    {

      void init() {}

      void destroy() {}

      // Batched inversion checking
      //---------------------------------------------------
      template <typename EigenMatrix, typename Float>
      void invertEigen(std::complex<Float> *A_eig, std::complex<Float> *Ainv_eig, int n, uint64_t batch)
      {
        EigenMatrix res = EigenMatrix::Zero(n, n);
        EigenMatrix inv = EigenMatrix::Zero(n, n);
        // Each n x n input matrix is stored contiguously in column-major order,
        // with matrix `batch` starting at offset batch * n * n.
        for (int j = 0; j < n; j++) {
          for (int k = 0; k < n; k++) { res(k, j) = A_eig[batch * n * n + j * n + k]; }
        }

        inv = res.inverse();

        for (int j = 0; j < n; j++) {
          for (int k = 0; k < n; k++) { Ainv_eig[batch * n * n + j * n + k] = inv(k, j); }
        }

        // Check result:
#ifdef _DEBUG
        EigenMatrix unit = EigenMatrix::Identity(n, n);
        EigenMatrix prod = res * inv;
        Float L2norm = ((prod - unit).norm() / (n * n));
        printfQuda("Eigen: Norm of (A * Ainv - I) batch %lu = %e\n", batch, L2norm);
#endif
      }
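      // Note: for dynamic-size matrices Eigen's .inverse() goes through a
      // partial-pivot LU decomposition, which is what the GETRF-based flop
      // counts in BatchInvertMatrix below assume.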
      //---------------------------------------------------

      // Batched Inversions
      //---------------------------------------------------
      long long BatchInvertMatrix(void *Ainv, void *A, const int n, const uint64_t batch, QudaPrecision prec,
                                  QudaFieldLocation location)
      {
        if (getVerbosity() >= QUDA_VERBOSE)
          printfQuda("BatchInvertMatrix (generic - Eigen): Nc = %d, batch = %lu\n", n, batch);

        // 2 reals per complex element, prec bytes per real
        size_t size = 2 * n * n * batch * prec;
        void *A_h = location == QUDA_CUDA_FIELD_LOCATION ? pool_pinned_malloc(size) : A;
        void *Ainv_h = location == QUDA_CUDA_FIELD_LOCATION ? pool_pinned_malloc(size) : Ainv;
        if (location == QUDA_CUDA_FIELD_LOCATION) qudaMemcpy(A_h, A, size, cudaMemcpyDeviceToHost);

        long long flops = 0;
        timeval start, stop;
        gettimeofday(&start, NULL);

        if (prec == QUDA_SINGLE_PRECISION) {
          std::complex<float> *A_eig = (std::complex<float> *)A_h;
          std::complex<float> *Ainv_eig = (std::complex<float> *)Ainv_h;

#ifdef _OPENMP
#pragma omp parallel for
#endif
          for (uint64_t i = 0; i < batch; i++) { invertEigen<MatrixXcf, float>(A_eig, Ainv_eig, n, i); }
          flops += batch * FLOPS_CGETRF(n, n);
        } else if (prec == QUDA_DOUBLE_PRECISION) {
          std::complex<double> *A_eig = (std::complex<double> *)A_h;
          std::complex<double> *Ainv_eig = (std::complex<double> *)Ainv_h;

#ifdef _OPENMP
#pragma omp parallel for
#endif
          for (uint64_t i = 0; i < batch; i++) { invertEigen<MatrixXcd, double>(A_eig, Ainv_eig, n, i); }
          flops += batch * FLOPS_ZGETRF(n, n);
        } else {
          errorQuda("%s not implemented for precision = %d", __func__, prec);
        }

        gettimeofday(&stop, NULL);
        long dsh = stop.tv_sec - start.tv_sec;
        long dush = stop.tv_usec - start.tv_usec;
        double timeh = dsh + 0.000001 * dush;

        if (getVerbosity() >= QUDA_VERBOSE) {
          int threads = 1;
#ifdef _OPENMP
          // maximum team size; omp_get_num_threads() returns 1 outside a parallel region
          threads = omp_get_max_threads();
#endif
          printfQuda("CPU: Batched matrix inversion completed in %f seconds using %d threads with GFLOPS = %f\n",
                     timeh, threads, 1e-9 * flops / timeh);
        }

        if (location == QUDA_CUDA_FIELD_LOCATION) {
          qudaMemcpy((void *)Ainv, Ainv_h, size, cudaMemcpyHostToDevice);
          pool_pinned_free(Ainv_h);
          pool_pinned_free(A_h);
        }

        return flops;
      }

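      // Hypothetical usage sketch (not part of this file): inverting 256
      // host-resident 3x3 matrices in double precision would look like
      //
      //   long long flops = BatchInvertMatrix(Ainv, A, 3, 256, QUDA_DOUBLE_PRECISION,
      //                                       QUDA_CPU_FIELD_LOCATION);
      //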
      // Strided Batched GEMM helpers
      //--------------------------------------------------------------------------
      template <typename EigenMat, typename T>
      void fillArray(EigenMat &EigenArr, T *arr, int rows, int cols, int ld, int offset, bool fill_eigen)
      {
        int counter = offset;
        for (int i = 0; i < rows; i++) {
          for (int j = 0; j < cols; j++) {
            if (fill_eigen)
              EigenArr(i, j) = arr[counter];
            else
              arr[counter] = EigenArr(i, j);
            counter++;
          }
          counter += (ld - cols);
        }
      }
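
      // For example: rows = 2, cols = 3, ld = 4, offset = 0 and fill_eigen = true
      // reads arr[0], arr[1], arr[2] into row 0 and arr[4], arr[5], arr[6] into
      // row 1; the (ld - cols) increment skips the padding at the end of each
      // stored row.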

      template <typename EigenMat, typename T>
      void GEMM(void *A_h, void *B_h, void *C_h, T alpha, T beta, int max_stride, QudaBLASParam &blas_param)
      {
        // Problem parameters
        int m = blas_param.m;
        int n = blas_param.n;
        int k = blas_param.k;
        int lda = blas_param.lda;
        int ldb = blas_param.ldb;
        int ldc = blas_param.ldc;

        // If the user did not set any stride values, we default them to 1
        // as batch size 0 is an option.
        int a_stride = blas_param.a_stride == 0 ? 1 : blas_param.a_stride;
        int b_stride = blas_param.b_stride == 0 ? 1 : blas_param.b_stride;
        int c_stride = blas_param.c_stride == 0 ? 1 : blas_param.c_stride;
        int a_offset = blas_param.a_offset;
        int b_offset = blas_param.b_offset;
        int c_offset = blas_param.c_offset;
        int batches = blas_param.batch_count;

        // Number of data between batches
        unsigned int A_batch_size = blas_param.lda * blas_param.k;
        if (blas_param.trans_a != QUDA_BLAS_OP_N) A_batch_size = blas_param.lda * blas_param.m;
        unsigned int B_batch_size = blas_param.ldb * blas_param.n;
        if (blas_param.trans_b != QUDA_BLAS_OP_N) B_batch_size = blas_param.ldb * blas_param.k;
        unsigned int C_batch_size = blas_param.ldc * blas_param.n;

        T *A_ptr = static_cast<T *>(A_h);
        T *B_ptr = static_cast<T *>(B_h);
        T *C_ptr = static_cast<T *>(C_h);

        // Eigen objects to store data
        EigenMat Amat = EigenMat::Zero(m, k);
        EigenMat Bmat = EigenMat::Zero(k, n);
        EigenMat Cmat = EigenMat::Zero(m, n);
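        // The same three Eigen work matrices are reused for every batch; the
        // fillArray calls below simply repopulate them each iteration.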

        for (int batch = 0; batch < batches; batch += max_stride) {

          // Populate Eigen objects
          fillArray<EigenMat, T>(Amat, A_ptr, m, k, lda, a_offset, true);
          fillArray<EigenMat, T>(Bmat, B_ptr, k, n, ldb, b_offset, true);
          fillArray<EigenMat, T>(Cmat, C_ptr, m, n, ldc, c_offset, true);

          // Apply op(A) and op(B)
          switch (blas_param.trans_a) {
          case QUDA_BLAS_OP_T: Amat.transposeInPlace(); break;
          case QUDA_BLAS_OP_C: Amat.adjointInPlace(); break;
          case QUDA_BLAS_OP_N: break;
          default: errorQuda("Unknown blas op type %d", blas_param.trans_a);
          }

          switch (blas_param.trans_b) {
          case QUDA_BLAS_OP_T: Bmat.transposeInPlace(); break;
          case QUDA_BLAS_OP_C: Bmat.adjointInPlace(); break;
          case QUDA_BLAS_OP_N: break;
          default: errorQuda("Unknown blas op type %d", blas_param.trans_b);
          }

          // Perform GEMM using Eigen
          Cmat = alpha * Amat * Bmat + beta * Cmat;

          // Write back to the C array
          fillArray<EigenMat, T>(Cmat, C_ptr, m, n, ldc, c_offset, false);

          a_offset += A_batch_size * a_stride;
          b_offset += B_batch_size * b_stride;
          c_offset += C_batch_size * c_stride;
        }
      }
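
      // Note: the batch loop above advances the counter by max_stride per GEMM
      // while each array offset advances by its own (batch size x stride) step,
      // which is how A, B and C can carry different effective strides.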
      //---------------------------------------------------

      // Strided Batched GEMM
      //---------------------------------------------------
      long long stridedBatchGEMM(void *A_data, void *B_data, void *C_data, QudaBLASParam blas_param,
                                 QudaFieldLocation location)
      {
        long long flops = 0;
        timeval start, stop;
        gettimeofday(&start, NULL);

        // Sanity checks on parameters
        //-------------------------------------------------------------------------
        // If the user passes non-positive M, N, or K, we error out
        int min_dim = std::min(blas_param.m, std::min(blas_param.n, blas_param.k));
        if (min_dim <= 0) {
          errorQuda("BLAS dims must be positive: m=%d, n=%d, k=%d", blas_param.m, blas_param.n, blas_param.k);
        }

        // If the user passes a negative stride, we error out as this has no meaning.
        int min_stride = std::min(std::min(blas_param.a_stride, blas_param.b_stride), blas_param.c_stride);
        if (min_stride < 0) {
          errorQuda("BLAS strides must be positive or zero: a_stride=%d, b_stride=%d, c_stride=%d",
                    blas_param.a_stride, blas_param.b_stride, blas_param.c_stride);
        }

        // If the user passes a negative offset, we error out as this has no meaning.
        int min_offset = std::min(std::min(blas_param.a_offset, blas_param.b_offset), blas_param.c_offset);
        if (min_offset < 0) {
          errorQuda("BLAS offsets must be positive or zero: a_offset=%d, b_offset=%d, c_offset=%d",
                    blas_param.a_offset, blas_param.b_offset, blas_param.c_offset);
        }

        // If the batch value is non-positive, we error out
        if (blas_param.batch_count <= 0) { errorQuda("Batches must be positive: batches=%d", blas_param.batch_count); }

        // Leading dims are dependent on the matrix op type.
        if (blas_param.data_order == QUDA_BLAS_DATAORDER_COL) {
          if (blas_param.trans_a == QUDA_BLAS_OP_N) {
            if (blas_param.lda < std::max(1, blas_param.m))
              errorQuda("lda=%d must be >= max(1,m=%d)", blas_param.lda, blas_param.m);
          } else {
            if (blas_param.lda < std::max(1, blas_param.k))
              errorQuda("lda=%d must be >= max(1,k=%d)", blas_param.lda, blas_param.k);
          }

          if (blas_param.trans_b == QUDA_BLAS_OP_N) {
            if (blas_param.ldb < std::max(1, blas_param.k))
              errorQuda("ldb=%d must be >= max(1,k=%d)", blas_param.ldb, blas_param.k);
          } else {
            if (blas_param.ldb < std::max(1, blas_param.n))
              errorQuda("ldb=%d must be >= max(1,n=%d)", blas_param.ldb, blas_param.n);
          }
          if (blas_param.ldc < std::max(1, blas_param.m))
            errorQuda("ldc=%d must be >= max(1,m=%d)", blas_param.ldc, blas_param.m);
        } else {
          if (blas_param.trans_a == QUDA_BLAS_OP_N) {
            if (blas_param.lda < std::max(1, blas_param.k))
              errorQuda("lda=%d must be >= max(1,k=%d)", blas_param.lda, blas_param.k);
          } else {
            if (blas_param.lda < std::max(1, blas_param.m))
              errorQuda("lda=%d must be >= max(1,m=%d)", blas_param.lda, blas_param.m);
          }
          if (blas_param.trans_b == QUDA_BLAS_OP_N) {
            if (blas_param.ldb < std::max(1, blas_param.n))
              errorQuda("ldb=%d must be >= max(1,n=%d)", blas_param.ldb, blas_param.n);
          } else {
            if (blas_param.ldb < std::max(1, blas_param.k))
              errorQuda("ldb=%d must be >= max(1,k=%d)", blas_param.ldb, blas_param.k);
          }
          if (blas_param.ldc < std::max(1, blas_param.n))
            errorQuda("ldc=%d must be >= max(1,n=%d)", blas_param.ldc, blas_param.n);
        }
        //-------------------------------------------------------------------------

        // Parse parameters for Eigen
        //-------------------------------------------------------------------------
        // Swap A and B if in column order
        if (blas_param.data_order == QUDA_BLAS_DATAORDER_COL) {
          std::swap(blas_param.m, blas_param.n);
          std::swap(blas_param.lda, blas_param.ldb);
          std::swap(blas_param.trans_a, blas_param.trans_b);
          std::swap(blas_param.a_offset, blas_param.b_offset);
          std::swap(blas_param.a_stride, blas_param.b_stride);
          std::swap(A_data, B_data);
        }
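        // The fillArray helper indexes the raw arrays in row-major order, so a
        // column-major GEMM C = op(A) op(B) is evaluated as the row-major
        // product C^T = op(B)^T op(A)^T via the swaps above.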

        // Get maximum stride length to deduce the number of batches in the
        // computation
        int max_stride = std::max(std::max(blas_param.a_stride, blas_param.b_stride), blas_param.c_stride);

        // If the user gives strides of 0 for all arrays, we are essentially performing
        // a GEMM on the first matrices in the array N_{batch} times.
        // Give them what they ask for, YMMV...
        // If this evaluates to -1, the user did not set any strides.
        if (max_stride <= 0) max_stride = 1;

        // The number of GEMMs to compute
        const uint64_t batch = blas_param.batch_count / max_stride;
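        // e.g. batch_count = 8 with a_stride = b_stride = c_stride = 2 gives
        // max_stride = 2 and batch = 4: four GEMMs, touching every second
        // matrix in each array.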

        uint64_t data_size
          = (blas_param.data_type == QUDA_BLAS_DATATYPE_S || blas_param.data_type == QUDA_BLAS_DATATYPE_C) ? 4 : 8;

        if (blas_param.data_type == QUDA_BLAS_DATATYPE_C || blas_param.data_type == QUDA_BLAS_DATATYPE_Z) {
          data_size *= 2;
        }
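        // data_size now holds the bytes per matrix element: 4 (S) or 8 (D) for
        // the real types, 8 (C) or 16 (Z) for the complex types.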

        // Number of data between batches
        unsigned int A_batch_size = blas_param.lda * blas_param.k;
        if (blas_param.trans_a != QUDA_BLAS_OP_N) A_batch_size = blas_param.lda * blas_param.m;
        unsigned int B_batch_size = blas_param.ldb * blas_param.n;
        if (blas_param.trans_b != QUDA_BLAS_OP_N) B_batch_size = blas_param.ldb * blas_param.k;
        unsigned int C_batch_size = blas_param.ldc * blas_param.n;

        // Data size of the entire array
        size_t sizeAarr = A_batch_size * data_size * batch;
        size_t sizeBarr = B_batch_size * data_size * batch;
        size_t sizeCarr = C_batch_size * data_size * batch;

        // If already on the host, just use the given pointer. If the data is on
        // the device, allocate host memory and transfer
        void *A_h = location == QUDA_CPU_FIELD_LOCATION ? A_data : pool_pinned_malloc(sizeAarr);
        void *B_h = location == QUDA_CPU_FIELD_LOCATION ? B_data : pool_pinned_malloc(sizeBarr);
        void *C_h = location == QUDA_CPU_FIELD_LOCATION ? C_data : pool_pinned_malloc(sizeCarr);
        if (location == QUDA_CUDA_FIELD_LOCATION) {
          qudaMemcpy(A_h, A_data, sizeAarr, cudaMemcpyDeviceToHost);
          qudaMemcpy(B_h, B_data, sizeBarr, cudaMemcpyDeviceToHost);
          qudaMemcpy(C_h, C_data, sizeCarr, cudaMemcpyDeviceToHost);
        }

        if (blas_param.data_type == QUDA_BLAS_DATATYPE_Z) {

          typedef std::complex<double> Z;
          const Z alpha = blas_param.alpha;
          const Z beta = blas_param.beta;
          GEMM<MatrixXcd, Z>(A_h, B_h, C_h, alpha, beta, max_stride, blas_param);

        } else if (blas_param.data_type == QUDA_BLAS_DATATYPE_C) {

          typedef std::complex<float> C;
          const C alpha = static_cast<C>(blas_param.alpha);
          const C beta = static_cast<C>(blas_param.beta);
          GEMM<MatrixXcf, C>(A_h, B_h, C_h, alpha, beta, max_stride, blas_param);

        } else if (blas_param.data_type == QUDA_BLAS_DATATYPE_D) {

          typedef double D;
          const D alpha = (D)(static_cast<std::complex<double>>(blas_param.alpha).real());
          const D beta = (D)(static_cast<std::complex<double>>(blas_param.beta).real());
          GEMM<MatrixXd, D>(A_h, B_h, C_h, alpha, beta, max_stride, blas_param);

        } else if (blas_param.data_type == QUDA_BLAS_DATATYPE_S) {

          typedef float S;
          const S alpha = (S)(static_cast<std::complex<float>>(blas_param.alpha).real());
          const S beta = (S)(static_cast<std::complex<float>>(blas_param.beta).real());
          GEMM<MatrixXf, S>(A_h, B_h, C_h, alpha, beta, max_stride, blas_param);

        } else {
          errorQuda("%s not implemented for data type %d", __func__, blas_param.data_type);
        }

        // Restore the blas parameters to their original values
        if (blas_param.data_order == QUDA_BLAS_DATAORDER_COL) {
          std::swap(blas_param.m, blas_param.n);
          std::swap(blas_param.lda, blas_param.ldb);
          std::swap(blas_param.trans_a, blas_param.trans_b);
          std::swap(blas_param.a_offset, blas_param.b_offset);
          std::swap(blas_param.a_stride, blas_param.b_stride);
          std::swap(A_data, B_data);
        }

        // Transfer data
        if (location == QUDA_CUDA_FIELD_LOCATION) {
          qudaMemcpy(C_data, C_h, sizeCarr, cudaMemcpyHostToDevice);
          pool_pinned_free(A_h);
          pool_pinned_free(B_h);
          pool_pinned_free(C_h);
        }

        qudaDeviceSynchronize();
        gettimeofday(&stop, NULL);
        long ds = stop.tv_sec - start.tv_sec;
        long dus = stop.tv_usec - start.tv_usec;
        double time = ds + 0.000001 * dus;
        if (getVerbosity() >= QUDA_VERBOSE)
          printfQuda("Batched matrix GEMM completed in %f seconds with GFLOPS = %f\n", time, 1e-9 * flops / time);

        return flops;
      }
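
      // Hypothetical usage sketch (not part of this file): a single
      // complex-double GEMM would set blas_param.data_type = QUDA_BLAS_DATATYPE_Z,
      // the op types, dimensions, leading dims and batch_count = 1, then call
      //
      //   stridedBatchGEMM(A, B, C, blas_param, QUDA_CPU_FIELD_LOCATION);
      //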
    } // namespace generic
  }   // namespace blas_lapack
} // namespace quda