Commit becade5

HIP: implement FlashAttention via rocWMMA for CDNA and RDNA3+ (ggml-org#12032)

Authored by hjc4869 with Johannes Gäßler and bjj.

Adds GGML_HIP_ROCWMMA_FATTN and a rocwmma header check.
Adds rocWMMA support to fattn-wmma-f16.

---

Signed-off-by: Carl Klemm <carl@uvos.xyz>
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
Co-authored-by: Ben Jackson <ben@ben.com>

1 parent dfd6b2c · commit becade5

File tree: 6 files changed, 145 additions and 95 deletions

‎ggml/CMakeLists.txt‎

Lines changed: 1 addition & 0 deletions

@@ -162,6 +162,7 @@ set_property(CACHE GGML_CUDA_COMPRESSION_MODE PROPERTY STRINGS "none;speed;balan
 option(GGML_HIP "ggml: use HIP" OFF)
 option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF)
 option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON)
+option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF)
 option(GGML_HIP_UMA "ggml: use HIP unified memory architecture" OFF)
 option(GGML_VULKAN "ggml: use Vulkan" OFF)
 option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks" OFF)
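The CMake change only declares the knob. The commit message mentions a rocwmma header check, so presumably the HIP backend's build scripts verify that the rocWMMA header exists and then forward GGML_HIP_ROCWMMA_FATTN to the compiler as a define. The excerpt below is a minimal, hypothetical sketch of how such a guard could look on the device-code side; it is not the commit's actual wiring of fattn-wmma-f16, and the namespace alias is an assumption made for the example.

// Hypothetical sketch, not the commit's code: assume the CMake option arrives in
// device code as a preprocessor define, so rocWMMA is only pulled in when the user
// opted in and the build system found the header.
#if defined(GGML_HIP_ROCWMMA_FATTN)
#include <rocwmma/rocwmma.hpp>   // rocWMMA's main header; its presence is what the build check looks for
namespace wmma = rocwmma;        // assumption: alias so WMMA-style fragment code can compile unchanged
#endif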

‎ggml/src/ggml-cuda/common.cuh‎

Lines changed: 13 additions & 2 deletions

@@ -62,6 +62,7 @@
 #define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
 #define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
 
+#define GGML_CUDA_CC_IS_AMD(cc) (cc >= GGML_CUDA_CC_OFFSET_AMD)
 #define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1)
 #define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
 #define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
@@ -196,6 +197,10 @@ typedef float2 dfloat2;
 #define FP16_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 
+#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3))
+#define FP16_MMA_AVAILABLE
+#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3))
+
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 #define NEW_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
@@ -223,12 +228,18 @@ static bool fast_fp16_hardware_available(const int cc) {
 
 // Any FP16 tensor core instructions are available for ggml code.
 static bool fp16_mma_available(const int cc) {
-    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA;
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
+    return false;
+#else
+    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ||
+        GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3;
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
 }
 
 // To be used for feature selection of external libraries, e.g. cuBLAS.
 static bool fp16_mma_hardware_available(const int cc) {
-    return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA;
+    return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA ||
+        GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3;
 }
 
 // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
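To make the new availability checks concrete, here is a small host-only sketch of the same gating: a compile-time switch short-circuits to false on HIP builds without rocWMMA, otherwise NVIDIA needs Volta or newer and AMD needs CDNA or RDNA3+. Every constant below is a stand-in invented for the example (not ggml's real values), and ggml_cuda_highest_compiled_arch() is simplified to a plain comparison.

#include <cstdio>

// Stand-in compute-capability encoding: AMD devices are tagged by a large offset,
// in the spirit of GGML_CUDA_CC_OFFSET_AMD (the values here are made up).
constexpr int CC_OFFSET_AMD = 1 << 24;
constexpr int CC_VOLTA      = 700;                     // NVIDIA Volta
constexpr int CC_CDNA       = CC_OFFSET_AMD + 0x908;   // placeholder start of the CDNA range
constexpr int CC_RDNA1      = CC_OFFSET_AMD + 0x1010;  // placeholder
constexpr int CC_RDNA3      = CC_OFFSET_AMD + 0x1100;  // RX 7000, minimum for WMMA

static bool is_cdna(int cc) { return cc >= CC_CDNA && cc < CC_RDNA1; }

// Mirrors the shape of fp16_mma_available() after this commit.
static bool fp16_mma_available_sketch(int cc) {
#if defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)
    return false;                                       // HIP build without rocWMMA: no FP16 MMA path
#else
    return (cc < CC_OFFSET_AMD && cc >= CC_VOLTA)       // NVIDIA: Volta or newer
        || is_cdna(cc)                                  // AMD: CDNA ...
        || cc >= CC_RDNA3;                              // ... or RDNA3 and newer
#endif
}

int main() {
    printf("sm_86 -> %d\n", fp16_mma_available_sketch(860));
    printf("RDNA3 -> %d\n", fp16_mma_available_sketch(CC_RDNA3));
    printf("RDNA1 -> %d\n", fp16_mma_available_sketch(CC_RDNA1));
}

Note that fp16_mma_hardware_available() keeps the same AMD clauses but without the compile-time guard, since it describes the hardware rather than the compiled code paths.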

‎ggml/src/ggml-cuda/fattn-common.cuh‎

Lines changed: 38 additions & 28 deletions

@@ -57,35 +57,36 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
         const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
 
 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;
 
         const int ib = k_KQ / QI8_1;
         const int iqs4 = k_KQ % QI4_0;
         const int shift = k_KQ & (QI8_1/2);
 
         const int v = (get_int_b2(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
-        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+        const int u = Q_q8[k_KQ_0/warp_size];
 
         const int sumi = ggml_cuda_dp4a(v, u, 0);
 
 #ifdef FP16_AVAILABLE
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;
 
-            const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE];
+            const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/warp_size];
             sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2) /* *8/QI8_1 == 1 */);
         } else
 #endif // FP16_AVAILABLE
         {
             const float2 * Q_ds = (const float2 *) Q_ds_v;
 
-            sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (8/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y));
+            sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (8/QI8_1)*Q_ds[k_KQ_0/warp_size].y));
         }
     }
 
@@ -97,37 +98,38 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
         const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
 
 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;
 
         const int ib = k_KQ / QI8_1;
         const int iqs4 = k_KQ % QI4_1;
         const int shift = k_KQ & (QI8_1/2);
 
         const int v = (get_int_b4(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
-        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+        const int u = Q_q8[k_KQ_0/warp_size];
 
         const int sumi = ggml_cuda_dp4a(v, u, 0);
 
 #ifdef FP16_AVAILABLE
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;
 
-            const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE];
+            const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/warp_size];
             const half2 sumid4d8_m4s8scaled = d4d8_m4s8 * make_half2(sumi, 1.0f/QI8_1);
             sum += (T) (__low2half(sumid4d8_m4s8scaled) + __high2half(sumid4d8_m4s8scaled));
         } else
 #endif // FP16_AVAILABLE
         {
             const float2 * Q_ds = (const float2 *) Q_ds_v;
 
-            const float sumid4d8 = __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi;
-            const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1;
+            const float sumid4d8 = __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi;
+            const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1;
 
             sum += (T) (sumid4d8 + m4s8scaled);
         }
@@ -141,12 +143,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
         const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
 
 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;
 
         const int ib = k_KQ / QI8_1;
@@ -161,22 +164,22 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
         v |= (vh << 18) & 0x00100000; // 2 -> 20
         v |= (vh << 25) & 0x10000000; // 3 -> 28
 
-        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+        const int u = Q_q8[k_KQ_0/warp_size];
 
         const int sumi = ggml_cuda_dp4a(v, u, 0);
 
 #ifdef FP16_AVAILABLE
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;
 
-            const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE];
+            const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/warp_size];
             sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(2.0f)) /* *16/QI8_1 == 2 */;
         } else
 #endif // FP16_AVAILABLE
         {
             const float2 * Q_ds = (const float2 *) Q_ds_v;
 
-            sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (16/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y));
+            sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (16/QI8_1)*Q_ds[k_KQ_0/warp_size].y));
         }
     }
 
@@ -188,12 +191,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
         const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
 
 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;
 
         const int ib = k_KQ / QI8_1;
@@ -208,24 +212,24 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
         v |= (vh << 18) & 0x00100000; // 2 -> 20
         v |= (vh << 25) & 0x10000000; // 3 -> 28
 
-        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+        const int u = Q_q8[k_KQ_0/warp_size];
 
         const int sumi = ggml_cuda_dp4a(v, u, 0);
 
 #ifdef FP16_AVAILABLE
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;
 
-            const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE];
+            const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/warp_size];
             const half2 sumid5d8_m5s8scaled = d5d8_m5s8 * make_half2(sumi, 1.0f/QI8_1);
             sum += (T) (__low2half(sumid5d8_m5s8scaled) + __high2half(sumid5d8_m5s8scaled));
         } else
 #endif // FP16_AVAILABLE
         {
             const float2 * Q_ds = (const float2 *) Q_ds_v;
 
-            const float sumid5d8 = __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi;
-            const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1;
+            const float sumid5d8 = __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi;
+            const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1;
 
             sum += (T) (sumid5d8 + m5s8scaled);
         }
@@ -239,12 +243,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
         const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
    GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
 
 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;
 
         const int ib = k_KQ / QI8_0;
@@ -255,13 +260,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
         T Q_d;
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;
-            Q_d = __low2half(Q_ds[k_KQ_0/WARP_SIZE]);
+            Q_d = __low2half(Q_ds[k_KQ_0/warp_size]);
         } else {
             const float2 * Q_ds = (const float2 *) Q_ds_v;
-            Q_d = Q_ds[k_KQ_0/WARP_SIZE].x;
+            Q_d = Q_ds[k_KQ_0/warp_size].x;
         }
 
-        sum += vec_dot_q8_0_q8_1_impl<T, 1>(&v, &Q_q8[k_KQ_0/WARP_SIZE], K_q8_0[ib].d, Q_d);
+        sum += vec_dot_q8_0_q8_1_impl<T, 1>(&v, &Q_q8[k_KQ_0/warp_size], K_q8_0[ib].d, Q_d);
     }
 
     return sum;
@@ -272,6 +277,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
         const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
 
     const half2 * K_h2 = (const half2 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);
 
@@ -282,11 +288,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
         half2 sum2 = make_half2(0.0f, 0.0f);
 
 #pragma unroll
-        for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
+        for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) {
             const int k_KQ = k_KQ_0 + threadIdx.x;
 
             const half2 K_ik = K_h2[k_KQ];
-            sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE];
+            sum2 += K_ik * Q_h2[k_KQ_0/warp_size];
         }
 
         return __low2half(sum2) + __high2half(sum2);
@@ -298,12 +304,12 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
     float sum = 0.0f;
 
 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;
 
         const half2 K_ik = K_h2[k_KQ];
-        sum += __low2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].x;
-        sum += __high2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].y;
+        sum += __low2float(K_ik) * Q_f2[k_KQ_0/warp_size].x;
+        sum += __high2float(K_ik) * Q_f2[k_KQ_0/warp_size].y;
     }
 
     return sum;
@@ -698,6 +704,8 @@ void launch_fattn(
 
     GGML_ASSERT(Q->ne[3] == 1);
 
+    const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
+
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
     const int id = ggml_cuda_get_device();
@@ -750,7 +758,7 @@ void launch_fattn(
     const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
     const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
 
-    const dim3 block_dim(WARP_SIZE, nwarps, 1);
+    const dim3 block_dim(warp_size, nwarps, 1);
     dim3 blocks_num;
     if (parallel_blocks == 0) {
         // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
@@ -796,6 +804,8 @@ void launch_fattn(
     const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
+    GGML_ASSERT(block_dim.x % warp_size == 0);
+    GGML_ASSERT(!GGML_CUDA_CC_IS_AMD(cc) || block_dim.x * block_dim.y <= 4 * (unsigned int)warp_size);
     fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
         (const char *) Q->data,
         K_data,
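The bulk of these hunks swap the hard-coded WARP_SIZE for a warp_size taken from the compile target in device code and from the device info in launch_fattn, because CDNA and older GCN GPUs execute 64-wide wavefronts while NVIDIA and RDNA use 32. The sketch below illustrates the pattern with a stand-in for ggml_cuda_get_physical_warp_size(); the helper name, the __GFX9__ check, and the values are assumptions made for this example, not the library's definition.

#include <cuda_fp16.h>   // on a HIP build this would be <hip/hip_fp16.h>

// Stand-in for ggml_cuda_get_physical_warp_size(): assume 64-wide wavefronts on
// CDNA/GCN targets and 32 lanes everywhere else (assumption for this sketch).
static constexpr __device__ int physical_warp_size() {
#if defined(__HIP_PLATFORM_AMD__) && defined(__GFX9__)
    return 64;
#else
    return 32;
#endif
}

// Same access pattern as the FP16 K*Q dot product above: each lane handles one half2
// per iteration and the loop strides by the physical warp width, so the indexing stays
// correct whether a warp is 32 or 64 lanes wide. The function returns this lane's
// partial sum; combining the partial sums across the warp is left to the caller.
template <int D>
static __device__ float vec_dot_f16_sketch(const half2 * K_h2, const half2 * Q_h2) {
    constexpr int warp_size = physical_warp_size();

    half2 sum2 = make_half2(0.0f, 0.0f);
#pragma unroll
    for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) {
        const int k_KQ = k_KQ_0 + threadIdx.x;          // one element per lane
        sum2 += K_h2[k_KQ] * Q_h2[k_KQ_0/warp_size];    // Q_h2 holds this lane's pre-loaded values, one per iteration
    }
    return __low2float(sum2) + __high2float(sum2);
}

On the host side, launch_fattn now reads the same width from ggml_cuda_info().devices[ctx.device].warp_size, so the launch dimensions and the two new asserts agree with what the kernels assume about warp width.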
