QUDA: quda/lib/interface/blas_interface.cpp Source File

QUDA  v1.1.0
A library for QCD on GPUs
blas_interface.cpp
Go to the documentation of this file.
1 #include <quda.h>
2 #include <blas_lapack.h>
3 #include <tune_quda.h>
4 
5 using namespace quda;
6 
7 // Forward declarations for profiling and parameter checking
8 // The helper functions are defined in interface_quda.cpp
9 TimeProfile &getProfileBLAS();
10 void checkBLASParam(QudaBLASParam &param);
11 
12  void blasGEMMQuda(void *arrayA, void *arrayB, void *arrayC, QudaBoolean use_native, QudaBLASParam *blas_param)
13 {
14  getProfileBLAS().TPSTART(QUDA_PROFILE_TOTAL);
15  checkBLASParam(*blas_param);
16 
17  // cuBLAS works exclusively in column major order. If the input data is in
18  // row major order, we may treat the A and B and C arrays as A^T, B^T, and C^T.
19  // We swap the order of the A * B multiplication and swap the
20  // operation types and other data to recover the the desired result in the
21  // desired order.
22  // E.g: in row major, the operation,
23  // C = a * A^T * B + b * C
24  //
25  // will become the column major operation
26  // C^T = a * B^T * A + b * C^T
27  //
28  // By inspection, one can see that transposition of the above column major
29  // operation will result in the desired row major answer:
30  //
31  // (C^T)^T = a * (B^T * A)^T + b * (C^T)^T
32  // --> C = a * A^T * B + b * C
33  //
34  // We must also swap around some parameters. The Row major indices,
35  // A_{m, lda}, B_{k, ldb}, C_{m, ldc}
36  // become
37  // A^T_{lda, m}, B^T_{ldb, k}, C^T_{ldc, m}.
38  // so the leading dimensions remain the same. However, we must change the actual
39  // matrix dims m,n,k to reflect the change to column major.
40  // m_{col} = n_{row}
41  // n_{col} = m_{row}
42  // k_{col} = k_{row}
43  // And because we are swapping the A and B arrays, we must also swap their
44  // leading dim values and any offsets. All this is done behind the scenes in the
45  // BatchGEMM function, and before function exit all pointers and values are
46  // restored to the values they had on entry.
47 
48  if (use_native == QUDA_BOOLEAN_FALSE) {
49  getProfileBLAS().TPSTART(QUDA_PROFILE_COMPUTE);
50  blas_lapack::generic::stridedBatchGEMM(arrayA, arrayB, arrayC, *blas_param, QUDA_CPU_FIELD_LOCATION);
51  getProfileBLAS().TPSTOP(QUDA_PROFILE_COMPUTE);
52  } else {
53  getProfileBLAS().TPSTART(QUDA_PROFILE_INIT);
54 
55  // The data in the arrays is on the host. We transfer the data to the device here
56  // for timing purposes. One can pass host pointers to the BatchGEMM function
57  // and it will handle the data movement for the user.
58 
59  // Extract data from the param struct for device malloc
60  uint64_t arrayA_size = 0, arrayB_size = 0, arrayC_size = 0;
61  if (blas_param->data_order == QUDA_BLAS_DATAORDER_COL) {
62  // leading dimension is in terms of consecutive data
63  // elements in a column, multiplied by number of rows
64  if (blas_param->trans_a == QUDA_BLAS_OP_N) {
65  arrayA_size = blas_param->lda * blas_param->k; // A_mk
66  if (getVerbosity() >= QUDA_VERBOSE) printfQuda("array A_{%d, %d}\n", blas_param->lda, blas_param->k);
67  } else {
68  arrayA_size = blas_param->lda * blas_param->m; // A_km
69  if (getVerbosity() >= QUDA_VERBOSE) printfQuda("array A_{%d, %d}\n", blas_param->lda, blas_param->m);
70  }
71 
72  if (blas_param->trans_b == QUDA_BLAS_OP_N) {
73  arrayB_size = blas_param->ldb * blas_param->n; // B_kn
74  if (getVerbosity() >= QUDA_VERBOSE) printfQuda("array B_{%d, %d}\n", blas_param->ldb, blas_param->n);
75  } else {
76  arrayB_size = blas_param->ldb * blas_param->k; // B_nk
77  if (getVerbosity() >= QUDA_VERBOSE) printfQuda("array B_{%d, %d}\n", blas_param->ldb, blas_param->k);
78  }
79  arrayC_size = blas_param->ldc * blas_param->n; // C_mn
80  if (getVerbosity() >= QUDA_VERBOSE) printfQuda("array C_{%d, %d}\n", blas_param->ldc, blas_param->n);
81  } else {
82  // leading dimension is in terms of consecutive data
83  // elements in a row, multiplied by number of columns.
84  if (blas_param->trans_a == QUDA_BLAS_OP_N) {
85  arrayA_size = blas_param->lda * blas_param->m; // A_mk
86  if (getVerbosity() >= QUDA_VERBOSE) printfQuda("array A_{%d, %d}\n", blas_param->m, blas_param->lda);
87  } else {
88  arrayA_size = blas_param->lda * blas_param->k; // A_km
89  if (getVerbosity() >= QUDA_VERBOSE) printfQuda("array A_{%d, %d}\n", blas_param->k, blas_param->lda);
90  }
91  if (blas_param->trans_b == QUDA_BLAS_OP_N) {
92  arrayB_size = blas_param->ldb * blas_param->k; // B_nk
93  if (getVerbosity() >= QUDA_VERBOSE) printfQuda("array B_{%d, %d}\n", blas_param->k, blas_param->ldb);
94  } else {
95  arrayB_size = blas_param->ldb * blas_param->n; // B_kn
96  if (getVerbosity() >= QUDA_VERBOSE) printfQuda("array B_{%d, %d}\n", blas_param->n, blas_param->ldb);
97  }
98  arrayC_size = blas_param->ldc * blas_param->m; // C_mn
99  if (getVerbosity() >= QUDA_VERBOSE) printfQuda("array C_{%d, %d}\n", blas_param->m, blas_param->ldc);
100  }
101 
102  size_t data_size = (blas_param->data_type == QUDA_BLAS_DATATYPE_D || blas_param->data_type == QUDA_BLAS_DATATYPE_Z) ?
103  sizeof(double) :
104  sizeof(float);
105  int re_im = 1;
106  if (blas_param->data_type == QUDA_BLAS_DATATYPE_C || blas_param->data_type == QUDA_BLAS_DATATYPE_Z) { re_im *= 2; }
107 
108  // If the user passes non-zero offsets, add one extra
109  // matrix to the device array to accomodate it.
110  int batches_extra = 0;
111  if (blas_param->a_offset + blas_param->b_offset + blas_param->c_offset > 0) { batches_extra++; }
112  int batches = blas_param->batch_count + batches_extra;
113 
114  size_t A_bytes = batches * arrayA_size * re_im * data_size;
115  size_t B_bytes = batches * arrayB_size * re_im * data_size;
116  size_t C_bytes = batches * arrayC_size * re_im * data_size;
117  if (getVerbosity() >= QUDA_VERBOSE)
118  printfQuda("A_Gbtyes = %f, B_Gbtyes = %f, C_Gbtyes = %f\n", 1.0 * A_bytes / std::pow(1024, 3),
119  1.0 * B_bytes / std::pow(1024, 3), 1.0 * C_bytes / std::pow(1024, 3));
120  void *A_d = pool_device_malloc(A_bytes);
121  void *B_d = pool_device_malloc(B_bytes);
122  void *C_d = pool_device_malloc(C_bytes);
123  if (getVerbosity() >= QUDA_VERBOSE) printfQuda("QUDA: arrays allocated sucessfully.\n");
124  getProfileBLAS().TPSTOP(QUDA_PROFILE_INIT);
125 
126  // Transfer host data to device
127  getProfileBLAS().TPSTART(QUDA_PROFILE_H2D);
128  qudaMemcpy(A_d, arrayA, A_bytes, cudaMemcpyHostToDevice);
129  qudaMemcpy(B_d, arrayB, B_bytes, cudaMemcpyHostToDevice);
130  qudaMemcpy(C_d, arrayC, C_bytes, cudaMemcpyHostToDevice);
131  if (getVerbosity() >= QUDA_VERBOSE) printfQuda("QUDA: arrays copied susessfully.\n");
132  getProfileBLAS().TPSTOP(QUDA_PROFILE_H2D);
133 
134  // Compute Batched GEMM
135  getProfileBLAS().TPSTART(QUDA_PROFILE_COMPUTE);
136 
137  blas_lapack::native::stridedBatchGEMM(A_d, B_d, C_d, *blas_param, QUDA_CUDA_FIELD_LOCATION);
138 
139  if (getVerbosity() >= QUDA_VERBOSE) printfQuda("BatchGEMM success!\n");
140  getProfileBLAS().TPSTOP(QUDA_PROFILE_COMPUTE);
141 
142  // Copy device C array back to host
143  getProfileBLAS().TPSTART(QUDA_PROFILE_D2H);
144  qudaMemcpy(arrayC, C_d, C_bytes, cudaMemcpyDeviceToHost);
145  getProfileBLAS().TPSTOP(QUDA_PROFILE_D2H);
146 
147  // Clean up
148  getProfileBLAS().TPSTART(QUDA_PROFILE_FREE);
149  pool_device_free(A_d);
150  pool_device_free(B_d);
151  pool_device_free(C_d);
152  getProfileBLAS().TPSTOP(QUDA_PROFILE_FREE);
153  }
154 
155  getProfileBLAS().TPSTOP(QUDA_PROFILE_TOTAL);
156  saveTuneCache();
157 }
getProfileBLAS
TimeProfile & getProfileBLAS()
Profiler for the BLAS interface (used by blasGEMMQuda to time each phase).
Definition: interface_quda.cpp:227
blasGEMMQuda
void blasGEMMQuda(void *arrayA, void *arrayB, void *arrayC, QudaBoolean use_native, QudaBLASParam *blas_param)
Strided Batched GEMM.
Definition: blas_interface.cpp:12
checkBLASParam
void checkBLASParam(QudaBLASParam &param)
Definition: interface_quda.cpp:56
QUDA_CUDA_FIELD_LOCATION
@ QUDA_CUDA_FIELD_LOCATION
Definition: enum_quda.h:326
QUDA_CPU_FIELD_LOCATION
@ QUDA_CPU_FIELD_LOCATION
Definition: enum_quda.h:325
QUDA_VERBOSE
@ QUDA_VERBOSE
Definition: enum_quda.h:267
QUDA_BLAS_DATATYPE_Z
@ QUDA_BLAS_DATATYPE_Z
Definition: enum_quda.h:480
QUDA_BLAS_DATATYPE_D
@ QUDA_BLAS_DATATYPE_D
Definition: enum_quda.h:478
QUDA_BLAS_DATATYPE_C
@ QUDA_BLAS_DATATYPE_C
Definition: enum_quda.h:479
QUDA_BOOLEAN_FALSE
@ QUDA_BOOLEAN_FALSE
Definition: enum_quda.h:460
QudaBoolean
enum QudaBoolean_s QudaBoolean
QUDA_BLAS_DATAORDER_COL
@ QUDA_BLAS_DATAORDER_COL
Definition: enum_quda.h:486
QUDA_BLAS_OP_N
@ QUDA_BLAS_OP_N
Definition: enum_quda.h:470
pool_device_malloc
#define pool_device_malloc(size)
Definition: malloc_quda.h:170
pool_device_free
#define pool_device_free(ptr)
Definition: malloc_quda.h:171
quda::blas_lapack::generic::stridedBatchGEMM
long long stridedBatchGEMM(void *A, void *B, void *C, QudaBLASParam blas_param, QudaFieldLocation location)
Strided Batch GEMM. This function performs N GEMM type operations in a strided batched fashion....
Definition: blas_lapack_eigen.cpp:204
quda::blas_lapack::native::stridedBatchGEMM
long long stridedBatchGEMM(void *A, void *B, void *C, QudaBLASParam blas_param, QudaFieldLocation location)
Strided Batch GEMM. This function performs N GEMM type operations in a strided batched fashion....
Definition: blas_lapack_cublas.cpp:193
quda
Definition: blas_lapack.h:24
quda::saveTuneCache
void saveTuneCache(bool error=false)
Definition: tune.cpp:439
quda::QUDA_PROFILE_INIT
@ QUDA_PROFILE_INIT
Definition: timer.h:106
quda::QUDA_PROFILE_COMPUTE
@ QUDA_PROFILE_COMPUTE
Definition: timer.h:108
quda::QUDA_PROFILE_TOTAL
@ QUDA_PROFILE_TOTAL
Definition: timer.h:149
quda::QUDA_PROFILE_FREE
@ QUDA_PROFILE_FREE
Definition: timer.h:111
quda::QUDA_PROFILE_H2D
@ QUDA_PROFILE_H2D
Definition: timer.h:104
quda::QUDA_PROFILE_D2H
@ QUDA_PROFILE_D2H
Definition: timer.h:105
quda::pow
__host__ __device__ ValueType pow(ValueType x, ExponentType e)
Definition: complex_quda.h:111
param
QudaGaugeParam param
Definition: pack_test.cpp:18
quda.h
Main header file for the QUDA library.
qudaMemcpy
#define qudaMemcpy(dst, src, count, kind)
Definition: quda_api.h:204
QudaBLASParam_s::c_offset
int c_offset
Definition: quda.h:761
QudaBLASParam_s::ldc
int ldc
Definition: quda.h:758
QudaBLASParam_s::data_order
QudaBLASDataOrder data_order
Definition: quda.h:772
QudaBLASParam_s::b_offset
int b_offset
Definition: quda.h:760
QudaBLASParam_s::trans_a
QudaBLASOperation trans_a
Definition: quda.h:751
QudaBLASParam_s::ldb
int ldb
Definition: quda.h:757
QudaBLASParam_s::data_type
QudaBLASDataType data_type
Definition: quda.h:771
QudaBLASParam_s::a_offset
int a_offset
Definition: quda.h:759
QudaBLASParam_s::lda
int lda
Definition: quda.h:756
QudaBLASParam_s::batch_count
int batch_count
Definition: quda.h:769
QudaBLASParam_s::n
int n
Definition: quda.h:754
QudaBLASParam_s::m
int m
Definition: quda.h:753
QudaBLASParam_s::trans_b
QudaBLASOperation trans_b
Definition: quda.h:752
QudaBLASParam_s::k
int k
Definition: quda.h:755
printfQuda
#define printfQuda(...)
Definition: util_quda.h:114
getVerbosity
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:21

Generated on Thu Oct 28 2021 16:10:27 for QUDA by doxygen 1.9.1

AltStyle によって変換されたページ (->オリジナル) /