 #define FATTN_KQ_STRIDE_TILE_F16 64
 
-template <int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
+template <int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
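The compile-time parallel_blocks template parameter is gone: the split of the KV sequence across blocks is now expressed purely through the launch grid, so a single kernel instantiation covers every split factor. Below is a minimal sketch of the grid layout the rewritten kernel appears to assume; the helper name is hypothetical and the actual grid is set up by launch_fattn, which is not shown in this diff.

// Sketch only, not part of the diff. Hypothetical helper illustrating how the three grid
// dimensions map onto work in the new kernel: x indexes groups of ncols Q columns
// (ic0 = blockIdx.x*ncols), y is the runtime KV-sequence split that used to be the
// parallel_blocks template parameter, z is the head index (used via nb02*blockIdx.z).
struct GridShape { unsigned x, y, z; };

static GridShape tile_f16_grid(int n_q_cols, int ncols, int kv_splits, int n_heads) {
    return GridShape{
        (unsigned) ((n_q_cols + ncols - 1) / ncols),
        (unsigned) kv_splits,
        (unsigned) n_heads,
    };
}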
@@ -58,18 +58,17 @@ static __global__ void flash_attn_tile_ext_f16(
 
     // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
-    const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
-    const int ip  =  blockIdx.x % parallel_blocks;          // Index in group of blocks running for the same column in parallel.
+    const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
 
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2  = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
-    const half2  * K_h2  = (const half2  *) (K + nb12*(blockIdx.y / gqa_ratio));
-    const half2  * V_h2  = (const half2  *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
+    const float2 * Q_f2  = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
+    const half2  * K_h2  = (const half2  *) (K + nb12*(blockIdx.z / gqa_ratio));
+    const half2  * V_h2  = (const half2  *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
     const half   * maskh = (const half   *)  mask + ne11*ic0;
 
     const int stride_KV2 = nb11 / sizeof(half2);
 
-    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
     const half  slopeh = __float2half(slopef);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
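With the split moved into its own grid dimension, the head index moves from blockIdx.y to blockIdx.z and blockIdx.x now maps one-to-one onto groups of Q columns; the old code had to fold the split factor into blockIdx.x. A small before/after sketch of the decoding, with hypothetical names, purely for illustration:

// Sketch only, not part of the diff. How a block's work assignment is decoded from its
// block indices before and after the change.
struct WorkItem { int ic0; int kv_split; int head; };

// Before: the KV split was folded into blockIdx.x and the head lived in blockIdx.y.
static WorkItem decode_old(int bx, int by, int ncols, int parallel_blocks) {
    return { (bx / parallel_blocks) * ncols, bx % parallel_blocks, by };
}

// After: each grid dimension carries exactly one meaning (x: Q columns, y: KV split, z: head).
static WorkItem decode_new(int bx, int by, int bz, int ncols) {
    return { bx * ncols, by, bz };
}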
@@ -105,8 +104,7 @@ static __global__ void flash_attn_tile_ext_f16(
 
     __syncthreads();
 
-    const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F16;
-    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F16) {
+    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F16) {
         // Calculate KQ tile and keep track of new maximum KQ values:
 
         half kqmax_new[ncols/nwarps];
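Each block now starts at its own blockIdx.y-th tile and strides over the KV sequence in steps of gridDim.y tiles of FATTN_KQ_STRIDE_TILE_F16 rows; with gridDim.y == 1 this degenerates to a single full pass, matching the old parallel_blocks == 1 behaviour. A host-side sketch of which tiles a given block visits (hypothetical helper, not part of the diff):

// Sketch only, not part of the diff. Lists the first KV row of every tile that a block with
// index `by` out of `gy` blocks processes under the new loop; `tile` defaults to
// FATTN_KQ_STRIDE_TILE_F16 == 64.
#include <vector>

static std::vector<int> tiles_for_block(int by, int gy, int ne11, int tile = 64) {
    std::vector<int> starts;
    for (int k_VKQ_0 = by*tile; k_VKQ_0 < ne11; k_VKQ_0 += gy*tile) {
        starts.push_back(k_VKQ_0);
    }
    return starts;
}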
@@ -271,40 +269,40 @@ static __global__ void flash_attn_tile_ext_f16(
             const int i0 = i00 + 2*threadIdx.x;
 
             half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
-            if (parallel_blocks == 1) {
+            if (gridDim.y == 1) {
                 dst_val /= __half2half2(kqsum_j);
             }
-            const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
-            dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = __low2float(dst_val);
-            dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = __high2float(dst_val);
+            const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
+            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = __low2float(dst_val);
+            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = __high2float(dst_val);
         }
 
-        if (parallel_blocks != 1 && threadIdx.x == 0) {
-            dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
+        if (gridDim.y != 1 && threadIdx.x == 0) {
+            dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z)*gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
         }
     }
 #else
     NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }
 
-template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
+template <int cols_per_block, bool use_logit_softcap>
 void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
     switch (Q->ne[0]) {
         case 64: {
             constexpr int    D             = 64;
             constexpr int    nwarps        = 8;
             constexpr size_t nbytes_shared = 0;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
-            launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
+            launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true, false);
         } break;
         case 128: {
             constexpr int    D             = 128;
             constexpr int    nwarps        = 8;
             constexpr size_t nbytes_shared = 0;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
-            launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
+            launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true, false);
         } break;
         default: {
             GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
@@ -324,37 +322,22 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_ten
 
     if (Q->ne[1] <= 16) {
         constexpr int cols_per_block = 16;
-        constexpr int parallel_blocks = 4;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
-            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+            launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
         } else {
             constexpr bool use_logit_softcap = true;
-            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
-        }
-        return;
-    }
-
-    if (Q->ne[1] <= 32) {
-        constexpr int cols_per_block = 32;
-        constexpr int parallel_blocks = 4;
-        if (logit_softcap == 0.0f) {
-            constexpr bool use_logit_softcap = false;
-            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
-        } else {
-            constexpr bool use_logit_softcap = true;
-            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+            launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
         }
         return;
     }
 
     constexpr int cols_per_block = 32;
-    constexpr int parallel_blocks = 1;
     if (logit_softcap == 0.0f) {
         constexpr bool use_logit_softcap = false;
-        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
     } else {
         constexpr bool use_logit_softcap = true;
-        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
     }
 }
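Dropping parallel_blocks also shrinks the compile-time surface: the old Q->ne[1] <= 32 branch collapses into the generic cols_per_block = 32 path, and the dispatch now only varies cols_per_block and use_logit_softcap. A rough count of the kernel variants requested by the dispatch above (simple arithmetic over the visible branches, not part of the diff):

// Sketch only, not part of the diff. Counts the distinct flash_attn_tile_ext_f16
// instantiations requested by the dispatch: before, (cols_per_block, parallel_blocks)
// took the values (16,4), (32,4) and (32,1); after, only cols_per_block in {16, 32}
// remains. Both counts are multiplied by 2 use_logit_softcap values and 2 head sizes (64, 128).
constexpr int n_variants_before = 3 * 2 * 2; // 12
constexpr int n_variants_after  = 2 * 2 * 2; //  8
static_assert(n_variants_before == 12 && n_variants_after == 8,
              "removing parallel_blocks reduces the number of kernels to compile");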