Commit 76881ac

CUDA: Improve flash decoding kernel occupancy for BS=1 case
Addresses issue #12182. This PR adds the following optimizations to the CUDA flash decoding code:

- Find the number of active blocks per SM using the cudaOccupancyMaxActiveBlocksPerMultiprocessor API and use this value to determine the optimal parallel_blocks value.
- Prefer the vector flash attention kernels over the MMA kernel for BS=1.

This results in up to 15% performance improvement in generation-phase throughput for large sequence lengths.
1 parent 20a9b8f commit 76881ac
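
For context on the occupancy API this change relies on, here is a minimal, standalone CUDA sketch of the same selection idea. It is not code from this commit: dummy_fattn_kernel, the head size D = 128, the head count, and the KV length are placeholder assumptions standing in for flash_attn_vec_ext_f32 and the real tensor shapes.

#include <cuda_runtime.h>
#include <algorithm>
#include <cstdio>

// Placeholder kernel standing in for flash_attn_vec_ext_f32; only its launch
// configuration (D threads per block, no dynamic shared memory) matters for the query.
template <int D>
__global__ void dummy_fattn_kernel() {}

int main() {
    constexpr int D   = 128;    // assumed head size == threads per block
    const int n_heads = 32;     // assumed Q->ne[2]*Q->ne[3] for a single-token batch
    const int kv_len  = 8192;   // assumed K->ne[1]

    const int total_blocks = n_heads;                // one block per head when cols_per_block == 1
    const int seqlen_tiles = (kv_len + D - 1) / D;   // number of D-wide KV tiles available to split over

    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);
    const int nsm = prop.multiProcessorCount;

    // How many blocks of this kernel can be resident on one SM at the same time.
    int active_per_sm = 1;
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&active_per_sm, dummy_fattn_kernel<D>, D, 0);

    // Split the KV sequence just enough to fill the GPU (nsm * active_per_sm blocks),
    // but never into more pieces than there are KV tiles.
    const int parallel_blocks =
        std::max(1, std::min((nsm * active_per_sm) / total_blocks, seqlen_tiles));

    std::printf("nsm=%d, active blocks/SM=%d -> parallel_blocks=%d\n", nsm, active_per_sm, parallel_blocks);
    return 0;
}

The committed code then rounds this value down to one of the compiled template instantiations (24, 16, 12, 8, or 4), since parallel_blocks is a template parameter of the kernel.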

File tree

2 files changed: +64 -7 lines changed

ggml/src/ggml-cuda/fattn-vec-f32.cuh

Lines changed: 63 additions & 4 deletions
@@ -308,13 +308,72 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
 
     if (Q->ne[1] == 1) {
         constexpr int cols_per_block = 1;
-        constexpr int parallel_blocks = 4;
+        const int total_blocks = (((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3]);
+        const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
+        const int seqlen_tiles = (K->ne[1] + D - 1) / D;
+
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
-        } else {
+
+            // Determine the number of active blocks per SM.
+            // The parallel_blocks template parameter has no effect on the number of active blocks, so a constant 4 is used for this query.
+            int numActiveBlocks = 1;
+            CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numActiveBlocks, flash_attn_vec_ext_f32<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>, D, 0));
+
+            // We want to keep at least `numActiveBlocks` blocks per SM to improve occupancy.
+            // This kernel operates on `D`-wide tiles of the sequence length, so consider how many `D` tiles can be processed in parallel.
+            // If there are not enough tiles to process, the number of blocks can be reduced.
+            const int parallel_blocks = std::min((nsm * numActiveBlocks) / total_blocks, seqlen_tiles);
+
+            if (parallel_blocks >= 24)
+            {
+                ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 24, type_K, type_V, use_logit_softcap>(ctx, dst);
+            }
+            else if (parallel_blocks >= 16)
+            {
+                ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 16, type_K, type_V, use_logit_softcap>(ctx, dst);
+            }
+            else if (parallel_blocks >= 12)
+            {
+                ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 12, type_K, type_V, use_logit_softcap>(ctx, dst);
+            }
+            else if (parallel_blocks >= 8)
+            {
+                ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 8, type_K, type_V, use_logit_softcap>(ctx, dst);
+            }
+            else
+            {
+                ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>(ctx, dst);
+            }
+        }
+        else
+        {
             constexpr bool use_logit_softcap = true;
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            int numActiveBlocks = 1;
+            CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numActiveBlocks, flash_attn_vec_ext_f32<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>, D, 0));
+
+            const int parallel_blocks = std::min((nsm * numActiveBlocks) / total_blocks, seqlen_tiles);
+
+            if (parallel_blocks >= 24)
+            {
+                ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 24, type_K, type_V, use_logit_softcap>(ctx, dst);
+            }
+            else if (parallel_blocks >= 16)
+            {
+                ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 16, type_K, type_V, use_logit_softcap>(ctx, dst);
+            }
+            else if (parallel_blocks >= 12)
+            {
+                ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 12, type_K, type_V, use_logit_softcap>(ctx, dst);
+            }
+            else if (parallel_blocks >= 8)
+            {
+                ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 8, type_K, type_V, use_logit_softcap>(ctx, dst);
+            }
+            else
+            {
+                ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>(ctx, dst);
+            }
         }
         return;
     }
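
As a worked example with purely hypothetical numbers: on a GPU with nsm = 132 where the occupancy query reports numActiveBlocks = 3, a single-token decode over 32 heads gives total_blocks = 32, and a KV length of K->ne[1] = 4096 with D = 128 gives seqlen_tiles = 32; then parallel_blocks = min((132*3)/32, 32) = min(12, 32) = 12, so the `parallel_blocks >= 12` branch is taken and the kernel is instantiated with 12 parallel blocks per attention column. Previously this case would always have used the fixed value of 4.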

ggml/src/ggml-cuda/fattn.cu

Lines changed: 1 addition & 3 deletions
@@ -297,9 +297,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     }
 
     const int gqa_ratio = Q->ne[2] / K->ne[2];
-    const bool mma_fast_for_bs1 = fp16_mma_available(cc) && gqa_ratio % 2 == 0 &&
-        K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && mask;
-    if (Q->ne[1] == 1 && Q->ne[0] % (2*warp_size) == 0 && !mma_fast_for_bs1) {
+    if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
         if (prec == GGML_PREC_DEFAULT) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
             return;
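
With the ggml-CUDA WARP_SIZE of 32, the simplified condition Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0 matches single-token batches whose head size is a multiple of 64 (e.g. 64, 128, 256); removing the mma_fast_for_bs1 shortcut means these cases now always dispatch to the vector flash attention kernels, implementing the "prefer vector kernels over the MMA kernel for BS=1" point from the commit message.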
