CUDA: optimize FA for GQA + large batches #12014

Merged: 1 commit, merged on Feb 22, 2025
2 changes: 1 addition & 1 deletion ggml/src/ggml-cuda/cp-async.cuh
@@ -24,7 +24,7 @@ static __device__ __forceinline__ void cp_async_cg_16(const unsigned int dst, co
} else
#endif // CUDART_VERSION >= 11040
{
asm volatile("cp.async.cg.shared.global.L2 [%0], [%1], 16;"
asm volatile("cp.async.cg.shared.global [%0], [%1], 16;"
: : "r"(dst), "l"(src));
}
#else
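Note on the cp-async.cuh change: the bare `.L2` qualifier on the old line does not appear to be a valid spelling of the PTX prefetch hint (the PTX ISA defines only the sized forms `.L2::64B`, `.L2::128B`, `.L2::256B`), so dropping it leaves the plain `cp.async.cg.shared.global` form for the pre-CUDA-11.4 fallback path. As a reference point, a minimal sketch of how such a copy is typically issued and fenced; the kernel and all names below are illustrative, not part of this PR, and assume sm_80+ and a 32-thread block:

```cuda
#include <cuda_runtime.h>

// Illustrative only: issue a 16-byte async copy into shared memory,
// then wait for it before any thread reads the buffer.
__global__ void copy_tile(const float4 * __restrict__ src, float4 * __restrict__ out) {
    __shared__ float4 buf[32];

    // cp.async takes a 32-bit shared-memory address as its destination.
    const unsigned int dst = (unsigned int) __cvta_generic_to_shared(&buf[threadIdx.x]);

    // .cg: cache globally (bypass L1); no L2 prefetch hint, matching the fixed line.
    asm volatile("cp.async.cg.shared.global [%0], [%1], 16;"
        : : "r"(dst), "l"(src + threadIdx.x));

    // Wait for all outstanding async copies issued by this thread,
    // then synchronize so every thread sees the filled buffer.
    asm volatile("cp.async.wait_all;" ::: "memory");
    __syncthreads();

    out[threadIdx.x] = buf[threadIdx.x];
}
```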
122 changes: 53 additions & 69 deletions ggml/src/ggml-cuda/fattn-common.cuh
@@ -516,27 +516,25 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
nullptr;
}

// The HIP compiler for some reason complains that it can't unroll a loop because of the jt*ncols + j >= ne01 conditional.
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wpass-failed"
#endif // __clang__

template<int D, int ncols, int KQ_stride> // D == head size
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
template<int D, int ncols1, int ncols2, int KQ_stride> // D == head size
__launch_bounds__(D, 1)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
static __global__ void flash_attn_stream_k_fixup(
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);

const int iter_k = ne11 / KQ_stride;
const int iter_j = (ne01 + (ncols - 1)) / ncols;
constexpr int ncols = ncols1*ncols2;

const int bidx0 = blockIdx.x;
const int j = blockIdx.y;
const int c = blockIdx.z;
const int jc = j*ncols2 + c;
const int tid = threadIdx.x;

const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);

const int iter_k = ne11 / FATTN_KQ_STRIDE;
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;

const int kbc0 = (bidx0 + 0)*iter_k*iter_j*ne02 / gridDim.x;
const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*ne02 / gridDim.x;
const int kbc0 = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;

const bool did_not_have_any_data = kbc0 == kbc0_stop;
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
@@ -548,59 +546,53 @@ static __global__ void flash_attn_stream_k_fixup(
const int channel = kbc0 / (iter_k*iter_j);
const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k;

dst += jt*ncols*ne02*D + channel*D;
if (jt*ncols1 + j >= ne01) {
return;
}

// Load the partial result that needs a fixup:
float dst_val[ncols] = {0.0f};
float max_val[ncols] = {0.0f};
float rowsum[ncols] = {0.0f};
#pragma unroll
for (int j = 0; j < ncols; ++j) {
if (jt*ncols + j >= ne01) {
break;
}
dst_val[j] = dst[j*ne02*D + threadIdx.x];
dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;

const float2 tmp = dst_fixup[bidx0*ncols + j];
max_val[j] = tmp.x;
rowsum[j] = tmp.y;
// Load the partial result that needs a fixup:
float dst_val = 0.0f;
float max_val = 0.0f;
float rowsum = 0.0f;
{
dst_val = *dst;

const float2 tmp = dst_fixup[bidx0*ncols + jc];
max_val = tmp.x;
rowsum = tmp.y;
}

// Iterate over previous blocks and compute the combined results.
// All CUDA blocks that get here must have a previous block that needs a fixup.
int bidx = bidx0 - 1;
int kbc_stop = kbc0;
while(true) {
const int kbc = bidx*iter_k*iter_j*ne02 / gridDim.x;
const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
if (kbc == kbc_stop) { // Did not have any data.
bidx--;
kbc_stop = kbc;
continue;
}

#pragma unroll
for (int j = 0; j < ncols; ++j) {
if (jt*ncols + j >= ne01) {
break;
}
const float dst_add = dst_fixup_data[bidx*ncols*D + j*D + threadIdx.x];
const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid];

const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + j];
const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + jc];

// Scale the current and new value accumulators depending on the max. values.
const float max_val_new = fmaxf(max_val[j], tmp.x);
// Scale the current and new value accumulators depending on the max. values.
const float max_val_new = fmaxf(max_val, tmp.x);

const float diff_val = max_val[j] - max_val_new;
const float diff_add = tmp.x - max_val_new;
const float diff_val = max_val - max_val_new;
const float diff_add = tmp.x - max_val_new;

const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;

dst_val[j] = scale_val*dst_val[j] + scale_add*dst_add;
rowsum[j] = scale_val*rowsum[j] + scale_add*tmp.y;
dst_val = scale_val*dst_val + scale_add*dst_add;
rowsum = scale_val*rowsum + scale_add*tmp.y;

max_val[j] = max_val_new;
}
max_val = max_val_new;

// If this block started in a previous tile we are done and don't need to combine additional partial results.
if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
@@ -611,19 +603,9 @@ static __global__ void flash_attn_stream_k_fixup(
}

// Write back final result:
#pragma unroll
for (int j = 0; j < ncols; ++j) {
if (jt*ncols + j >= ne01) {
return;
}
dst[j*ne02*D + threadIdx.x] = dst_val[j] / rowsum[j];
}
*dst = dst_val / rowsum;
}

#ifdef __clang__
#pragma clang diagnostic pop
#endif // __clang__
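The rewrite above replaces the per-thread `#pragma unroll` loop over `ncols` columns (the loop the clang pragmas existed to silence) with one (column, head) pair per CUDA block, indexed by `jc = j*ncols2 + c`. The combination step itself is the standard online-softmax merge of partial results. A small sketch of the same rule in plain C++; `Partial` and `merge` are illustrative names, not llama.cpp identifiers:

```cpp
#include <cmath>

// Each partial result carries a value accumulator o, a running max m,
// and a softmax denominator l (rowsum in the kernel).
struct Partial { float o, m, l; };

static Partial merge(const Partial a, const Partial b) {
    const float m_new   = fmaxf(a.m, b.m);
    // The kernel additionally flushes these scales to zero when the
    // exponent falls below SOFTMAX_FTZ_THRESHOLD.
    const float scale_a = expf(a.m - m_new);
    const float scale_b = expf(b.m - m_new);
    return { scale_a*a.o + scale_b*b.o, m_new, scale_a*a.l + scale_b*b.l };
}

// After all partials are merged, the final output for a row is o / l,
// matching `*dst = dst_val / rowsum;` at the end of the kernel.
```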

template<int D, int parallel_blocks> // D == head size
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(D, 1)
@@ -690,11 +672,13 @@ static void on_no_fattn_vec_case(const int D) {
}

// parallel_blocks == 0 is stream-k decomposition
template <int D, int cols_per_block, int parallel_blocks, int KQ_stride>
template <int D, int ncols1, int ncols2, int parallel_blocks, int KQ_stride>
void launch_fattn(
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
) {
constexpr int ncols = ncols1 * ncols2;

const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
@@ -763,25 +747,26 @@ void launch_fattn(
nb23 = nb23*bs*sizeof(half)/ts;
}

const int ntiles_x = ((Q->ne[1] + cols_per_block - 1) / cols_per_block);
const int ntiles_total = ntiles_x*Q->ne[2]*Q->ne[3];
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];

const dim3 block_dim(WARP_SIZE, nwarps, 1);
dim3 blocks_num;
if (parallel_blocks == 0) {
// For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
const int tiles_nwaves = (ntiles_total + 2*nsm - 1) / (2*nsm);
const int tiles_efficiency_percent = 100 * ntiles_total / (2*nsm*tiles_nwaves);
const int max_blocks = 2*nsm;
const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);

const int nblocks_stream_k = 2*nsm;
const int nblocks_stream_k = max_blocks;

const bool use_stream_k = tiles_efficiency_percent < 75 || cc >= GGML_CUDA_CC_ADA_LOVELACE;
const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || tiles_efficiency_percent < 75;

blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
blocks_num.y = 1;
blocks_num.z = 1;

dst_tmp_meta.alloc(blocks_num.x*cols_per_block * (2*2 + D) * sizeof(float));
dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
} else {
blocks_num.x = parallel_blocks*ntiles_x;
blocks_num.y = Q->ne[2];
@@ -793,7 +778,6 @@ void launch_fattn(
}
}
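For context on the wave-efficiency heuristic above, a worked example with hypothetical numbers (none of these values are from the PR):

```cpp
// Mirrors the launch_fattn decision: assumes 2 resident blocks per SM.
static bool would_use_stream_k(const int ntiles_total, const int nsm, const bool is_ada_or_newer) {
    const int max_blocks   = 2*nsm;
    const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
    const int eff_percent  = 100 * ntiles_total / (max_blocks*tiles_nwaves);
    return is_ada_or_newer || eff_percent < 75;
}

// would_use_stream_k(300, 108, false): max_blocks = 216, tiles_nwaves = 2,
// eff_percent = 69, so stream-k is chosen. A plain one-block-per-tile launch
// would idle 132 of 216 blocks in the second wave; stream-k splits the work
// evenly across blocks and pays for it with the fixup pass. On Ada Lovelace
// or newer, stream-k is used regardless of the efficiency estimate.
```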


float scale = 1.0f;
float max_bias = 0.0f;
float logit_softcap = 0.0f;
@@ -832,9 +816,9 @@ void launch_fattn(
if constexpr (parallel_blocks == 0) {
if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
const dim3 block_dim_combine(D, 1, 1);
const dim3 blocks_num_combine = blocks_num;
const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};

flash_attn_stream_k_fixup<D, cols_per_block, KQ_stride>
flash_attn_stream_k_fixup<D, ncols1, ncols2, KQ_stride>
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
}
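The fixup launch itself also changed shape: the grid is now (stream-k blocks, ncols1, ncols2) with D threads per block, so each block handles one (Q column, grouped head) pair and each thread one element of the head dimension. Read that way, the pointer arithmetic at the top of the kernel is ordinary 3D indexing; a sketch of the identity, assuming dst is laid out as [ne01 rows][ne02 heads][D elements]:

```cpp
// Illustrative only (names mirror the kernel):
//   row    = jt*ncols1 + j        // j comes from blockIdx.y
//   head   = channel*ncols2 + c   // c comes from blockIdx.z
//   offset = (row*ne02 + head)*D + tid
// Expanding row and head reproduces the kernel's increment:
//   jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid
```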