4
6
7 // Forward declarations for profiling and parameter checking
8 // The helper functions are defined in interface_quda.cpp
11
13 {
16
17 // cuBLAS works exclusively in column major order. If the input data is in
18 // row major order, we may treat the A and B and C arrays as A^T, B^T, and C^T.
19 // We swap the order of the A * B multiplication and swap the
20 // operation types and other data to recover the desired result in the
21 // desired order.
22 // E.g: in row major, the operation,
23 // C = a * A^T * B + b * C
24 //
25 // will become the column major operation
26 // C^T = a * B^T * A + b * C^T
27 //
28 // By inspection, one can see that transposition of the above column major
29 // operation will result in the desired row major answer:
30 //
31 // (C^T)^T = a * (B^T * A)^T + b * (C^T)^T
32 // --> C = a * A^T * B + b * C
33 //
34 // We must also swap around some parameters. The Row major indices,
35 // A_{m, lda}, B_{k, ldb}, C_{m, ldc}
36 // become
37 // A^T_{lda, m}, B^T_{ldb, k}, C^T_{ldc, m}.
38 // so the leading dimensions remain the same. However, we must change the actual
39 // matrix dims m,n,k to reflect the change to column major.
40 // m_{col} = n_{row}
41 // n_{col} = m_{row}
42 // k_{col} = k_{row}
43 // And because we are swapping the A and B arrays, we must also swap their
44 // leading dim values and any offsets. All this is done behind the scenes in the
45 // BatchGEMM function, and before function exit all pointers and values are
46 // restored to the values they had on entry.
47
52 } else {
54
55 // The data in the arrays is on the host. We transfer the data to the device here
56 // for timing purposes. One can pass host pointers to the BatchGEMM function
57 // and it will handle the data movement for the user.
58
59 // Extract data from the param struct for device malloc
60 uint64_t arrayA_size = 0, arrayB_size = 0, arrayC_size = 0;
62 // leading dimension is in terms of consecutive data
63 // elements in a column, multiplied by number of rows
65 arrayA_size = blas_param->
lda * blas_param->
k;
// A_mk
67 } else {
68 arrayA_size = blas_param->
lda * blas_param->
m;
// A_km
70 }
71
73 arrayB_size = blas_param->
ldb * blas_param->
n;
// B_kn
75 } else {
76 arrayB_size = blas_param->
ldb * blas_param->
k;
// B_nk
78 }
79 arrayC_size = blas_param->
ldc * blas_param->
n;
// C_mn
81 } else {
82 // leading dimension is in terms of consecutive data
83 // elements in a row, multiplied by number of columns.
85 arrayA_size = blas_param->
lda * blas_param->
m;
// A_mk
87 } else {
88 arrayA_size = blas_param->
lda * blas_param->
k;
// A_km
90 }
92 arrayB_size = blas_param->
ldb * blas_param->
k;
// B_nk
94 } else {
95 arrayB_size = blas_param->
ldb * blas_param->
n;
// B_kn
97 }
98 arrayC_size = blas_param->
ldc * blas_param->
m;
// C_mn
100 }
101
103 sizeof(double) :
104 sizeof(float);
105 int re_im = 1;
107
108 // If the user passes non-zero offsets, add one extra
109 // matrix to the device array to accommodate it.
110 int batches_extra = 0;
112 int batches = blas_param->
batch_count + batches_extra;
113
114 size_t A_bytes = batches * arrayA_size * re_im * data_size;
115 size_t B_bytes = batches * arrayB_size * re_im * data_size;
116 size_t C_bytes = batches * arrayC_size * re_im * data_size;
118 printfQuda(
"A_Gbtyes = %f, B_Gbtyes = %f, C_Gbtyes = %f\n", 1.0 * A_bytes /
std::pow(1024, 3),
125
126 // Transfer host data to device
128 qudaMemcpy(A_d, arrayA, A_bytes, cudaMemcpyHostToDevice);
129 qudaMemcpy(B_d, arrayB, B_bytes, cudaMemcpyHostToDevice);
130 qudaMemcpy(C_d, arrayC, C_bytes, cudaMemcpyHostToDevice);
133
134 // Compute Batched GEMM
136
138
141
142 // Copy device C array back to host
144 qudaMemcpy(arrayC, C_d, C_bytes, cudaMemcpyDeviceToHost);
146
147 // Clean up
153 }
154
157 }
TimeProfile & getProfileBLAS()
Profiler for covariant derivative.
void blasGEMMQuda(void *arrayA, void *arrayB, void *arrayC, QudaBoolean use_native, QudaBLASParam *blas_param)
Strided Batched GEMM.
void checkBLASParam(QudaBLASParam ¶m)
@ QUDA_CUDA_FIELD_LOCATION
@ QUDA_CPU_FIELD_LOCATION
enum QudaBoolean_s QudaBoolean
@ QUDA_BLAS_DATAORDER_COL
#define pool_device_malloc(size)
#define pool_device_free(ptr)
long long stridedBatchGEMM(void *A, void *B, void *C, QudaBLASParam blas_param, QudaFieldLocation location)
Strided Batch GEMM. This function performs N GEMM type operations in a strided batched fashion....
long long stridedBatchGEMM(void *A, void *B, void *C, QudaBLASParam blas_param, QudaFieldLocation location)
Strided Batch GEMM. This function performs N GEMM type operations in a strided batched fashion....
void saveTuneCache(bool error=false)
__host__ __device__ ValueType pow(ValueType x, ExponentType e)
Main header file for the QUDA library.
#define qudaMemcpy(dst, src, count, kind)
QudaBLASDataOrder data_order
QudaBLASOperation trans_a
QudaBLASDataType data_type
QudaBLASOperation trans_b
QudaVerbosity getVerbosity()