@@ -57,35 +57,36 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
         const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);

     T sum = 0.0f;

 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;

         const int ib    = k_KQ / QI8_1;
         const int iqs4  = k_KQ % QI4_0;
         const int shift = k_KQ & (QI8_1/2);

         const int v = (get_int_b2(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
-        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+        const int u = Q_q8[k_KQ_0/warp_size];

         const int sumi = ggml_cuda_dp4a(v, u, 0);

 #ifdef FP16_AVAILABLE
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;

-            const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE];
+            const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/warp_size];
             sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2) /* *8/QI8_1 == 1 */);
         } else
 #endif // FP16_AVAILABLE
         {
             const float2 * Q_ds = (const float2 *) Q_ds_v;

-            sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (8/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y));
+            sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (8/QI8_1)*Q_ds[k_KQ_0/warp_size].y));
         }
     }

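The substantive change in each of these dot-product helpers is the same: the fixed macro WARP_SIZE (always 32) is replaced by ggml_cuda_get_physical_warp_size(), so the loop stride and the Q_q8/Q_ds indexing follow the physical warp width of the compile target (32 on NVIDIA, 64 on wave64 AMD hardware). Because warp_size is constexpr in device code, indexing such as Q_q8[k_KQ_0/warp_size] still folds at compile time and the #pragma unroll loops keep a compile-time trip count. The sketch below is only a plausible shape for that helper, not the patch's actual definition (which lives in ggml's CUDA common header); the architecture guards are assumptions.

    // Hedged sketch of a compile-time physical-warp-size helper.
    static constexpr __device__ int physical_warp_size_sketch() {
    #if defined(GGML_USE_HIP) && defined(__GFX9__)
        return 64; // assumption: wave64 targets (GCN/CDNA) built via HIP
    #else
        return 32; // NVIDIA GPUs and wave32 targets
    #endif
    }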
@@ -97,37 +98,38 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
         const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);

     T sum = 0.0f;

 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;

         const int ib    = k_KQ / QI8_1;
         const int iqs4  = k_KQ % QI4_1;
         const int shift = k_KQ & (QI8_1/2);

         const int v = (get_int_b4(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
-        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+        const int u = Q_q8[k_KQ_0/warp_size];

         const int sumi = ggml_cuda_dp4a(v, u, 0);

 #ifdef FP16_AVAILABLE
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;

-            const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE];
+            const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/warp_size];
             const half2 sumid4d8_m4s8scaled = d4d8_m4s8 * make_half2(sumi, 1.0f/QI8_1);
             sum += (T) (__low2half(sumid4d8_m4s8scaled) + __high2half(sumid4d8_m4s8scaled));
         } else
 #endif // FP16_AVAILABLE
         {
             const float2 * Q_ds = (const float2 *) Q_ds_v;

-            const float sumid4d8   =  __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi;
-            const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1;
+            const float sumid4d8   =  __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi;
+            const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1;

             sum += (T) (sumid4d8 + m4s8scaled);
         }
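For q4_1 the kernel folds the scale/min pair K_q4_1[ib].dm = (d4, m4) against Q_ds = (d8, s8), where s8 is the per-block sum of the Q values. The m4*s8 term is split evenly over the QI8_1 chunks of a block, hence the 1.0f/QI8_1 factor above. A plain scalar sketch of the identity being evaluated (names are illustrative, not ggml API):

    // Scalar sketch of the q4_1 x q8_1 expansion, assuming x_i = d4*q4_i + m4,
    // y_i = d8*q8_i and s8 = sum_i y_i:
    //   sum_i x_i*y_i = d4*d8*sum_i(q4_i*q8_i) + m4*s8
    // The kernel adds d4*d8*sumi per chunk plus m4*s8/QI8_1 per chunk, so the
    // QI8_1 chunks covering one block contribute m4*s8 in total.
    static float dot_q4_1_q8_1_ref(float d4, float m4, float d8,
                                   const int * q4, const int * q8, int n) {
        float acc = 0.0f;
        for (int i = 0; i < n; ++i) {
            acc += (d4*q4[i] + m4) * (d8*q8[i]);
        }
        return acc;
    }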
@@ -141,12 +143,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
         const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);

     T sum = 0.0f;

 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;

         const int ib = k_KQ / QI8_1;
@@ -161,22 +164,22 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
         v |= (vh << 18) & 0x00100000; // 2 -> 20
         v |= (vh << 25) & 0x10000000; // 3 -> 28

-        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+        const int u = Q_q8[k_KQ_0/warp_size];

         const int sumi = ggml_cuda_dp4a(v, u, 0);

 #ifdef FP16_AVAILABLE
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;

-            const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE];
+            const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/warp_size];
             sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(2.0f)) /* *16/QI8_1 == 2 */;
         } else
 #endif // FP16_AVAILABLE
         {
             const float2 * Q_ds = (const float2 *) Q_ds_v;

-            sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (16/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y));
+            sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (16/QI8_1)*Q_ds[k_KQ_0/warp_size].y));
         }
     }

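The q5_0 path (unchanged here apart from the indexing) rebuilds 5-bit values before the dp4a: the low nibbles sit in v, and the already-shifted high-bit word vh contributes one bit per byte lane, with bit k of vh landing at bit 8*k + 4 of v. That is exactly what the shift-and-mask lines above ("2 -> 20", "3 -> 28") do. An equivalent loop-form sketch, for illustration only:

    // Move bit k of vh to bit 8*k + 4 of the nibble-packed word v, completing
    // a 5-bit magnitude in each byte lane.
    static __device__ __forceinline__ int scatter_q5_high_bits(int v, int vh) {
    #pragma unroll
        for (int k = 0; k < 4; ++k) {
            v |= ((vh >> k) & 1) << (8*k + 4);
        }
        return v;
    }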
@@ -188,12 +191,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
         const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);

     T sum = 0.0f;

 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;

         const int ib = k_KQ / QI8_1;
@@ -208,24 +212,24 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
         v |= (vh << 18) & 0x00100000; // 2 -> 20
         v |= (vh << 25) & 0x10000000; // 3 -> 28

-        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+        const int u = Q_q8[k_KQ_0/warp_size];

         const int sumi = ggml_cuda_dp4a(v, u, 0);

 #ifdef FP16_AVAILABLE
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;

-            const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE];
+            const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/warp_size];
             const half2 sumid5d8_m5s8scaled = d5d8_m5s8 * make_half2(sumi, 1.0f/QI8_1);
             sum += (T) (__low2half(sumid5d8_m5s8scaled) + __high2half(sumid5d8_m5s8scaled));
         } else
 #endif // FP16_AVAILABLE
         {
             const float2 * Q_ds = (const float2 *) Q_ds_v;

-            const float sumid5d8   =  __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi;
-            const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1;
+            const float sumid5d8   =  __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi;
+            const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1;

             sum += (T) (sumid5d8 + m5s8scaled);
         }
@@ -239,12 +243,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
         const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);

     T sum = 0.0f;

 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;

         const int ib = k_KQ / QI8_0;
@@ -255,13 +260,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
         T Q_d;
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;
-            Q_d = __low2half(Q_ds[k_KQ_0/WARP_SIZE]);
+            Q_d = __low2half(Q_ds[k_KQ_0/warp_size]);
         } else {
             const float2 * Q_ds = (const float2 *) Q_ds_v;
-            Q_d = Q_ds[k_KQ_0/WARP_SIZE].x;
+            Q_d = Q_ds[k_KQ_0/warp_size].x;
         }

-        sum += vec_dot_q8_0_q8_1_impl<T, 1>(&v, &Q_q8[k_KQ_0/WARP_SIZE], K_q8_0[ib].d, Q_d);
+        sum += vec_dot_q8_0_q8_1_impl<T, 1>(&v, &Q_q8[k_KQ_0/warp_size], K_q8_0[ib].d, Q_d);
     }

     return sum;
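In the q8_0 path only the scale of Q is needed (there is no minimum), and the per-chunk product is delegated to vec_dot_q8_0_q8_1_impl. A hedged scalar reference of what that call is assumed to compute for vdr == 1 (one 4-byte chunk), for orientation only:

    // Per-chunk q8_0 x q8_1 dot product: scale the int8 dot (what dp4a
    // accumulates) by both block scales.
    static float dot_q8_0_q8_1_chunk_ref(float d_K, float d_Q,
                                         const signed char * k, const signed char * q) {
        int sumi = 0;
        for (int i = 0; i < 4; ++i) {
            sumi += (int) k[i] * (int) q[i];
        }
        return d_K * d_Q * sumi;
    }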
@@ -272,6 +277,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
         const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const half2 * K_h2 = (const half2 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);

@@ -282,11 +288,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
         half2 sum2 = make_half2(0.0f, 0.0f);

 #pragma unroll
-        for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
+        for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) {
             const int k_KQ = k_KQ_0 + threadIdx.x;

             const half2 K_ik = K_h2[k_KQ];
-            sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE];
+            sum2 += K_ik * Q_h2[k_KQ_0/warp_size];
         }

         return __low2half(sum2) + __high2half(sum2);
@@ -298,12 +304,12 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
     float sum = 0.0f;

 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;

         const half2 K_ik = K_h2[k_KQ];
-        sum += __low2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].x;
-        sum += __high2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].y;
+        sum += __low2float(K_ik) * Q_f2[k_KQ_0/warp_size].x;
+        sum += __high2float(K_ik) * Q_f2[k_KQ_0/warp_size].y;
     }

     return sum;
@@ -698,6 +704,8 @@ void launch_fattn(

     GGML_ASSERT(Q->ne[3] == 1);

+    const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
+
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
     const int id = ggml_cuda_get_device();
@@ -750,7 +758,7 @@ void launch_fattn(
     const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
     const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];

-    const dim3 block_dim(WARP_SIZE, nwarps, 1);
+    const dim3 block_dim(warp_size, nwarps, 1);
     dim3 blocks_num;
     if (parallel_blocks == 0) {
         // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
@@ -796,6 +804,8 @@ void launch_fattn(
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

+    GGML_ASSERT(block_dim.x % warp_size == 0);
+    GGML_ASSERT(!GGML_CUDA_CC_IS_AMD(cc) || block_dim.x * block_dim.y <= 4 * (unsigned int)warp_size);
     fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
         (const char *) Q->data,
         K_data,
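On the host side, launch_fattn cannot use a compile-time constant, so the warp size is read from the cached device info and feeds both the block dimensions and the new launch-time asserts (the block width must be a multiple of the warp size, and on AMD the total block size must not exceed four physical warps). For reference, the same value could also be queried directly from the CUDA/HIP runtime; the helper below is only an illustrative sketch, the patch itself uses ggml_cuda_info().devices[ctx.device].warp_size.

    // Hedged sketch: query the physical warp size from the runtime.
    static int query_warp_size(int device) {
        cudaDeviceProp prop;
        cudaGetDeviceProperties(&prop, device); // error handling omitted in this sketch
        return prop.warpSize;                   // 32 on NVIDIA, 64 on wave64 AMD GPUs under HIP
    }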