Skip to content

Commit b6e067b

Browse files
committed
CUDA: Improve flash decoding kernel occupancy for BS=1 case
Adds the following optimizations to the CUDA flash decoding code: - Find out the number of active blocks per SM using the cudaOccupancyMaxActiveBlocksPerMultiprocessor API, and use this value to determine the optimal parallel_blocks value. - Prefer vector flash attention kernels over the MMA kernel for BS=1. This results in up to a 15% perf improvement in gen-phase throughput for large seq lengths. Issue: #12182
1 parent f6711ce commit b6e067b

File tree

4 files changed

+21
-8
lines changed

4 files changed

+21
-8
lines changed

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -703,7 +703,7 @@ void launch_fattn(
703703
GGML_ASSERT(Q->ne[3] == 1);
704704

705705
GGML_ASSERT(stream_k || ncols2 == 1);
706-
const int parallel_blocks = Q->ne[1] <= ncols1 ? 4 : 1;
706+
const bool use_parallel_blocks = !stream_k && (Q->ne[1] <= ncols1) ? true : false;
707707

708708
const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
709709

@@ -756,6 +756,8 @@ void launch_fattn(
756756
nb23 = nb23*bs*sizeof(half)/ts;
757757
}
758758

759+
int parallel_blocks = 1;
760+
759761
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
760762
const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
761763

@@ -777,6 +779,21 @@ void launch_fattn(
777779

778780
dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
779781
} else {
782+
if (use_parallel_blocks) {
783+
const int num_blocks_base = ntiles_x*Q->ne[2]*Q->ne[3];
784+
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
785+
const int seqlen_tiles = (K->ne[1] + D - 1) / D;
786+
787+
// Determine the number of active blocks per SM
788+
int numActiveBlocks = 1;
789+
CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numActiveBlocks, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
790+
791+
// We want to keep at least `numActiveBlocks` blocks per SM to improve occupancy.
792+
// This kernel operates on a tile of `D` elements of the sequence length; we need to consider how many `D`-sized tiles can be processed in parallel.
793+
// If there are not enough tiles to process, we can reduce the number of blocks.
794+
parallel_blocks = std::max(std::min((nsm * numActiveBlocks) / num_blocks_base, seqlen_tiles), 1);
795+
}
796+
780797
blocks_num.x = ntiles_x;
781798
blocks_num.y = parallel_blocks;
782799
blocks_num.z = Q->ne[2]*Q->ne[3];

ggml/src/ggml-cuda/fattn.cu

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -244,9 +244,6 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
244244
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
245245
const ggml_tensor * KQV = dst;
246246
const ggml_tensor * Q = dst->src[0];
247-
const ggml_tensor * K = dst->src[1];
248-
const ggml_tensor * V = dst->src[2];
249-
const ggml_tensor * mask = dst->src[3];
250247

251248
ggml_cuda_set_device(ctx.device);
252249
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
@@ -296,10 +293,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
296293
return;
297294
}
298295

299-
const int gqa_ratio = Q->ne[2] / K->ne[2];
300-
const bool mma_fast_for_bs1 = fp16_mma_available(cc) && gqa_ratio % 2 == 0 &&
301-
K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && mask;
302-
if (Q->ne[1] == 1 && Q->ne[0] % (2*warp_size) == 0 && !mma_fast_for_bs1) {
296+
if (Q->ne[1] == 1 && Q->ne[0] % (2*warp_size) == 0) {
303297
if (prec == GGML_PREC_DEFAULT) {
304298
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
305299
return;

ggml/src/ggml-cuda/vendors/hip.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@
129129
#define cudaGraph_t hipGraph_t
130130
#define cudaStream_t hipStream_t
131131
#define cudaSuccess hipSuccess
132+
#define cudaOccupancyMaxActiveBlocksPerMultiprocessor hipOccupancyMaxActiveBlocksPerMultiprocessor
132133
#define __trap() do { abort(); __builtin_unreachable(); } while(0)
133134
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
134135
#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED

ggml/src/ggml-cuda/vendors/musa.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,5 +133,6 @@
133133
#define cudaKernelNodeParams musaKernelNodeParams
134134
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
135135
#define cudaStreamEndCapture musaStreamEndCapture
136+
#define cudaOccupancyMaxActiveBlocksPerMultiprocessor musaOccupancyMaxActiveBlocksPerMultiprocessor
136137

137138
typedef mt_bfloat16 nv_bfloat16;

0 commit comments

Comments
 (0)