CUDA: Improve flash decoding kernel GPU occupancy for BS=1 case #12183

Merged
merged 3 commits on Mar 19, 2025
88 changes: 61 additions & 27 deletions ggml/src/ggml-cuda/fattn-common.cuh
@@ -606,48 +606,47 @@ static __global__ void flash_attn_stream_k_fixup(
*dst = dst_val / rowsum;
}

template<int D, int parallel_blocks> // D == head size
template<int D> // D == head size
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(D, 1)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
static __global__ void flash_attn_combine_results(
const float * __restrict__ VKQ_parts,
const float2 * __restrict__ VKQ_meta,
float * __restrict__ dst) {
VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x;
dst += D * gridDim.y*blockIdx.x;
float * __restrict__ dst,
const int parallel_blocks) {
VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x;
VKQ_meta += parallel_blocks * gridDim.z*blockIdx.x;
dst += D * gridDim.z*blockIdx.x;

const int tid = threadIdx.x;
__builtin_assume(tid < D);

__shared__ float2 meta[parallel_blocks];
extern __shared__ float2 meta[];
if (tid < 2*parallel_blocks) {
((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid];
((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + tid];
}

__syncthreads();

float kqmax = meta[0].x;
#pragma unroll
for (int l = 1; l < parallel_blocks; ++l) {
kqmax = max(kqmax, meta[l].x);
}

float VKQ_numerator = 0.0f;
float VKQ_denominator = 0.0f;
#pragma unroll
for (int l = 0; l < parallel_blocks; ++l) {
const float diff = meta[l].x - kqmax;
const float KQ_max_scale = expf(diff);
const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
*((uint32_t *) &KQ_max_scale) &= ftz_mask;

VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid];
VKQ_denominator += KQ_max_scale * meta[l].y;
}

dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
}
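
The combination kernel above implements the standard numerically stable merge of per-block partial softmax results: each parallel block contributes its running KQ maximum, its softmax denominator (row sum), and a partial VKQ numerator, and every contribution is rescaled to the global maximum before the final division. Because parallel_blocks is now a run-time value, the meta[] array also switches from a fixed-size to a dynamically sized extern __shared__ buffer, sized at launch as parallel_blocks*sizeof(float2). A minimal host-side sketch of the same math, with illustrative names that are not part of the PR:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Host-side sketch of what flash_attn_combine_results computes for one output element,
// assuming each of the parallel_blocks partial results carries (kqmax_l, rowsum_l)
// metadata plus a partial numerator part_l. The real kernel additionally flushes
// scale factors below SOFTMAX_FTZ_THRESHOLD to zero.
float combine_partials(const std::vector<float> & part,     // partial VKQ numerators
                       const std::vector<float> & kqmax,    // per-block KQ maxima
                       const std::vector<float> & rowsum) { // per-block softmax denominators
    const int parallel_blocks = (int) part.size();

    // Global maximum over all partial maxima (first loop in the kernel).
    float kqmax_global = kqmax[0];
    for (int l = 1; l < parallel_blocks; ++l) {
        kqmax_global = std::max(kqmax_global, kqmax[l]);
    }

    // Rescale each partial numerator/denominator to the global maximum and accumulate.
    float numerator   = 0.0f;
    float denominator = 0.0f;
    for (int l = 0; l < parallel_blocks; ++l) {
        const float scale = std::exp(kqmax[l] - kqmax_global); // <= 1, cannot overflow
        numerator   += scale * part[l];
        denominator += scale * rowsum[l];
    }
    return numerator / denominator;
}
```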

static void on_no_fattn_vec_case(const int D) {
@@ -671,12 +670,10 @@ static void on_no_fattn_vec_case(const int D) {
}
}

// parallel_blocks == 0 is stream-k decomposition
template <int D, int ncols1, int ncols2, int parallel_blocks, int KQ_stride>
template <int D, int ncols1, int ncols2, int KQ_stride>
void launch_fattn(
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V,
const int warp_size = WARP_SIZE
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
) {
constexpr int ncols = ncols1 * ncols2;

@@ -748,12 +745,14 @@ void launch_fattn(
nb23 = nb23*bs*sizeof(half)/ts;
}

int parallel_blocks = 1;

const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];

const dim3 block_dim(warp_size, nwarps, 1);
dim3 blocks_num;
if (parallel_blocks == 0) {
if (stream_k) {
// For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
const int max_blocks = 2*nsm;
const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
@@ -769,9 +768,43 @@

dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
} else {
blocks_num.x = parallel_blocks*ntiles_x;
blocks_num.y = Q->ne[2];
blocks_num.z = Q->ne[3];
GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.

int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));

// parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);

// parallel_blocks must not be larger than what the tensor size allows:
parallel_blocks = std::min(parallel_blocks, ntiles_KQ);

// If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects.
// Test whether parallel_blocks can be set to a higher value for better efficiency.
const int blocks_per_wave = nsm * max_blocks_per_sm;
int nwaves_best = 0;
int efficiency_percent_best = 0;
for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) {
const int nblocks_total = ntiles_total * parallel_blocks_test;
const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave;
const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave);

// Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead.
if (efficiency_percent_best >= 90 && nwaves > nwaves_best) {
break;
}

if (efficiency_percent > efficiency_percent_best) {
nwaves_best = nwaves;
efficiency_percent_best = efficiency_percent;
parallel_blocks = parallel_blocks_test;
}
}

blocks_num.x = ntiles_x;
blocks_num.y = parallel_blocks;
blocks_num.z = Q->ne[2]*Q->ne[3];

if (parallel_blocks > 1) {
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
@@ -803,7 +836,7 @@ void launch_fattn(
K_data,
V_data,
mask ? ((const char *) mask->data) : nullptr,
(parallel_blocks) > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
Expand All @@ -815,7 +848,7 @@ void launch_fattn(
);
CUDA_CHECK(cudaGetLastError());

if constexpr (parallel_blocks == 0) {
if (stream_k) {
if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
const dim3 block_dim_combine(D, 1, 1);
const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
@@ -824,13 +857,14 @@
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
}
} else if constexpr (parallel_blocks > 1) {
} else if (parallel_blocks > 1) {
const dim3 block_dim_combine(D, 1, 1);
const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);

flash_attn_combine_results<D, parallel_blocks>
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
flash_attn_combine_results<D>
<<<blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream>>>
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);
}
CUDA_CHECK(cudaGetLastError());
}
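
The non-stream-k launch path now chooses parallel_blocks at run time instead of via a template parameter: it starts from the smallest count that fills one wave of the GPU (queried with cudaOccupancyMaxActiveBlocksPerMultiprocessor), caps it at the number of KQ tiles the K tensor allows, then searches upward for the value that minimizes tail effects across waves. A condensed host-side sketch of that selection logic, lifted out of launch_fattn with illustrative names and without the surrounding launch code:

```cpp
#include <algorithm>

// Sketch of the run-time parallel_blocks selection used in launch_fattn.
// nsm: number of SMs on the device.
// max_blocks_per_sm: result of cudaOccupancyMaxActiveBlocksPerMultiprocessor for the kernel.
// ntiles_total: Q tiles * heads * batch; ntiles_KQ: K->ne[1] / KQ_row_granularity.
static int choose_parallel_blocks(int nsm, int max_blocks_per_sm, int ntiles_total, int ntiles_KQ) {
    // Large enough to reach full occupancy in a single wave, but no larger than the tensor allows.
    int parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
    parallel_blocks     = std::min(parallel_blocks, ntiles_KQ);

    // If ntiles_total % blocks_per_wave != 0, some efficiency is lost to tail effects;
    // test whether a larger parallel_blocks gives a fuller final wave.
    const int blocks_per_wave       = nsm * max_blocks_per_sm;
    int nwaves_best                 = 0;
    int efficiency_percent_best     = 0;
    for (int pb = parallel_blocks; pb <= ntiles_KQ; ++pb) {
        const int nblocks_total      = ntiles_total * pb;
        const int nwaves             = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave;
        const int efficiency_percent = 100 * nblocks_total / (nwaves * blocks_per_wave);

        // Once efficiency is already good, don't pay for additional waves.
        if (efficiency_percent_best >= 90 && nwaves > nwaves_best) {
            break;
        }
        if (efficiency_percent > efficiency_percent_best) {
            nwaves_best             = nwaves;
            efficiency_percent_best = efficiency_percent;
            parallel_blocks         = pb;
        }
    }
    return parallel_blocks;
}
```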
3 changes: 2 additions & 1 deletion ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -970,7 +970,8 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
fattn_kernel = flash_attn_ext_f16<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap>;
}

launch_fattn<D, ncols1, ncols2, 0, KQ_per_iter>(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, true, true);
launch_fattn<D, ncols1, ncols2, KQ_per_iter>
(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, FATTN_KQ_STRIDE, true, true, true);
}


63 changes: 24 additions & 39 deletions ggml/src/ggml-cuda/fattn-tile-f16.cu
@@ -4,7 +4,7 @@

#define FATTN_KQ_STRIDE_TILE_F16 64

template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(nwarps*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
@@ -58,18 +58,17 @@ static __global__ void flash_attn_tile_ext_f16(

//In this kernel Q, K, V are matrices while i, j, k are matrix indices.

const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.

const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) mask + ne11*ic0;

const int stride_KV2 = nb11 / sizeof(half2);

const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef);

static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -105,8 +104,7 @@ static __global__ void flash_attn_tile_ext_f16(

__syncthreads();

const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F16;
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F16) {
for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F16) {
// Calculate KQ tile and keep track of new maximum KQ values:

half kqmax_new[ncols/nwarps];
@@ -271,40 +269,42 @@ static __global__ void flash_attn_tile_ext_f16(
const int i0 = i00 + 2*threadIdx.x;

half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
if (parallel_blocks == 1) {
if (gridDim.y == 1) {
dst_val /= __half2half2(kqsum_j);
}
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = __low2float(dst_val);
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = __high2float(dst_val);
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = __low2float(dst_val);
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = __high2float(dst_val);
}

if (parallel_blocks != 1 && threadIdx.x == 0) {
dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
if (gridDim.y != 1 && threadIdx.x == 0) {
dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
}
}
#else
NO_DEVICE_CODE;
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
}
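
The tile kernel's loop over the K/V sequence shows the new partitioning scheme: rather than deriving a block's slice from a compile-time parallel_blocks constant, each block starts at its own blockIdx.y tile and strides by gridDim.y tiles, so a launch with gridDim.y == 1 degenerates to the old single-block loop over the whole sequence. A minimal standalone CUDA sketch of the same loop shape (illustrative names, not the PR's kernel):

```cpp
#include <cuda_runtime.h>

#define KQ_STRIDE_TILE 64 // mirrors FATTN_KQ_STRIDE_TILE_F16

// Each (x, z) block accumulates a partial result over every gridDim.y-th tile of the
// KV sequence; partials are combined afterwards whenever gridDim.y > 1.
// Assumes blockDim.x == KQ_STRIDE_TILE and ne11 is a multiple of KQ_STRIDE_TILE.
__global__ void kv_grid_stride_sketch(const float * __restrict__ K, float * __restrict__ partial, const int ne11) {
    float acc = 0.0f;
    for (int k_VKQ_0 = blockIdx.y*KQ_STRIDE_TILE; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*KQ_STRIDE_TILE) {
        acc += K[k_VKQ_0 + threadIdx.x]; // stand-in for the real KQ/VKQ tile computation
    }
    partial[blockIdx.y*blockDim.x + threadIdx.x] = acc; // partial result for later combination
}
```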

template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
template <int cols_per_block, bool use_logit_softcap>
void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
switch (Q->ne[0]) {
case 64: {
constexpr int D = 64;
constexpr int nwarps = 8;
constexpr size_t nbytes_shared = 0;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
launch_fattn<D, cols_per_block, 1, -1>
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
} break;
case 128: {
constexpr int D = 128;
constexpr int nwarps = 8;
constexpr size_t nbytes_shared = 0;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
launch_fattn<D, cols_per_block, 1, -1>
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
} break;
default: {
GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
@@ -324,37 +324,22 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_ten

if (Q->ne[1] <= 16) {
constexpr int cols_per_block = 16;
constexpr int parallel_blocks = 4;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
}
return;
}

if (Q->ne[1] <= 32) {
constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 4;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
}
return;
}

constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 1;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
}
}