
Commit 4f91d57

JohannesGaessler authored and mostlyuseful committed
CUDA: optimize FA for GQA + large batches (ggml-org#12014)
1 parent 98a85b4 commit 4f91d57

Showing 32 changed files with 939 additions and 410 deletions.

ggml/src/ggml-cuda/cp-async.cuh

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ static __device__ __forceinline__ void cp_async_cg_16(const unsigned int dst, co
     } else
 #endif // CUDART_VERSION >= 11040
     {
-        asm volatile("cp.async.cg.shared.global.L2 [%0], [%1], 16;"
+        asm volatile("cp.async.cg.shared.global [%0], [%1], 16;"
             : : "r"(dst), "l"(src));
     }
 #else
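
For context, this fallback path issues the plain cp.async variant without the .L2 qualifier. Below is a minimal, hypothetical sketch (not part of this commit; assumes compute capability 8.0+ and a 32-thread block) of how such a 16-byte global-to-shared async copy is typically issued and awaited:

    #include <cuda_runtime.h>

    // Hypothetical demo kernel: each thread copies one float4 (16 bytes)
    // from global to shared memory with cp.async, then writes it back out.
    __global__ void cp_async_demo(const float4 * __restrict__ src, float4 * __restrict__ dst) {
        __shared__ float4 buf[32];
        // cp.async needs a 32-bit shared-memory address for its destination:
        const unsigned int dst_shared = (unsigned int) __cvta_generic_to_shared(buf + threadIdx.x);
        asm volatile("cp.async.cg.shared.global [%0], [%1], 16;"
            : : "r"(dst_shared), "l"(src + threadIdx.x));
        asm volatile("cp.async.commit_group;"); // close the group of pending copies
        asm volatile("cp.async.wait_group 0;"); // block until all committed copies landed
        __syncthreads();
        dst[threadIdx.x] = buf[threadIdx.x];
    }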

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 53 additions & 69 deletions
@@ -516,27 +516,25 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
         nullptr;
 }
 
-// The HIP compiler for some reason complains that it can't unroll a loop because of the jt*ncols + j >= ne01 conditional.
-#ifdef __clang__
-#pragma clang diagnostic push
-#pragma clang diagnostic ignored "-Wpass-failed"
-#endif // __clang__
-
-template<int D, int ncols, int KQ_stride> // D == head size
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+template<int D, int ncols1, int ncols2, int KQ_stride> // D == head size
 __launch_bounds__(D, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_stream_k_fixup(
         float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
-    const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
-
-    const int iter_k = ne11 / KQ_stride;
-    const int iter_j = (ne01 + (ncols - 1)) / ncols;
+    constexpr int ncols = ncols1*ncols2;
 
     const int bidx0 = blockIdx.x;
+    const int j     = blockIdx.y;
+    const int c     = blockIdx.z;
+    const int jc    = j*ncols2 + c;
+    const int tid   = threadIdx.x;
+
+    const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
+
+    const int iter_k = ne11 / FATTN_KQ_STRIDE;
+    const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
 
-    const int kbc0      = (bidx0 + 0)*iter_k*iter_j*ne02 / gridDim.x;
-    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*ne02 / gridDim.x;
+    const int kbc0      = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
 
     const bool did_not_have_any_data   = kbc0 == kbc0_stop;
     const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
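The fixup kernel now handles one (j, c) pair per CUDA block via blockIdx.y/z instead of looping over all columns, and the flattened work space shrinks by the GQA factor ncols2 since Q heads sharing a KV head are processed together. A toy host-side sketch (hypothetical sizes, not from the commit) of how the kbc0/kbc0_stop arithmetic above partitions the flattened tile space evenly across blocks:

    #include <cstdio>

    int main() {
        // Hypothetical sizes: KV tiles, Q-column tiles, heads, GQA group size.
        const int iter_k = 8, iter_j = 4, ne02 = 32, ncols2 = 4;
        const int grid_x = 114;                          // e.g. 2*nsm CUDA blocks
        const int total  = iter_k*iter_j*(ne02/ncols2);  // flattened work items
        for (int bidx = 0; bidx < 4; ++bidx) {           // first few blocks only
            const int kbc0      = (bidx + 0)*total / grid_x;
            const int kbc0_stop = (bidx + 1)*total / grid_x;
            printf("block %3d works on kbc [%d, %d)\n", bidx, kbc0, kbc0_stop);
        }
        return 0;
    }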
@@ -548,59 +546,53 @@ static __global__ void flash_attn_stream_k_fixup(
     const int channel = kbc0 / (iter_k*iter_j);
     const int jt      = (kbc0 - channel*iter_k*iter_j) / iter_k;
 
-    dst += jt*ncols*ne02*D + channel*D;
+    if (jt*ncols1 + j >= ne01) {
+        return;
+    }
 
-    // Load the partial result that needs a fixup:
-    float dst_val[ncols] = {0.0f};
-    float max_val[ncols] = {0.0f};
-    float rowsum[ncols]  = {0.0f};
-#pragma unroll
-    for (int j = 0; j < ncols; ++j) {
-        if (jt*ncols + j >= ne01) {
-            break;
-        }
-        dst_val[j] = dst[j*ne02*D + threadIdx.x];
+    dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
 
-        const float2 tmp = dst_fixup[bidx0*ncols + j];
-        max_val[j] = tmp.x;
-        rowsum[j]  = tmp.y;
+    // Load the partial result that needs a fixup:
+    float dst_val = 0.0f;
+    float max_val = 0.0f;
+    float rowsum  = 0.0f;
+    {
+        dst_val = *dst;
+
+        const float2 tmp = dst_fixup[bidx0*ncols + jc];
+        max_val = tmp.x;
+        rowsum  = tmp.y;
     }
 
     // Iterate over previous blocks and compute the combined results.
     // All CUDA blocks that get here must have a previous block that needs a fixup.
     int bidx = bidx0 - 1;
     int kbc_stop = kbc0;
     while(true) {
-        const int kbc = bidx*iter_k*iter_j*ne02 / gridDim.x;
+        const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
         if (kbc == kbc_stop) { // Did not have any data.
             bidx--;
             kbc_stop = kbc;
             continue;
         }
 
-#pragma unroll
-        for (int j = 0; j < ncols; ++j) {
-            if (jt*ncols + j >= ne01) {
-                break;
-            }
-            const float dst_add = dst_fixup_data[bidx*ncols*D + j*D + threadIdx.x];
+        const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid];
 
-            const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + j];
+        const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + jc];
 
-            // Scale the current and new value accumulators depending on the max. values.
-            const float max_val_new = fmaxf(max_val[j], tmp.x);
+        // Scale the current and new value accumulators depending on the max. values.
+        const float max_val_new = fmaxf(max_val, tmp.x);
 
-            const float diff_val = max_val[j] - max_val_new;
-            const float diff_add = tmp.x      - max_val_new;
+        const float diff_val = max_val - max_val_new;
+        const float diff_add = tmp.x   - max_val_new;
 
-            const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
-            const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
+        const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
+        const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
 
-            dst_val[j] = scale_val*dst_val[j] + scale_add*dst_add;
-            rowsum[j]  = scale_val*rowsum[j]  + scale_add*tmp.y;
+        dst_val = scale_val*dst_val + scale_add*dst_add;
+        rowsum  = scale_val*rowsum  + scale_add*tmp.y;
 
-            max_val[j] = max_val_new;
-        }
+        max_val = max_val_new;
 
         // If this block started in a previous tile we are done and don't need to combine additional partial results.
         if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
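The scale_val/scale_add arithmetic above is the standard online-softmax merge: two partial results computed under different running maxima are both rescaled to the larger maximum before their outputs and row sums are added. A scalar toy version of the same combine step (hypothetical names, not from the commit):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>

    // One partial flash-attention result: unnormalized output value,
    // running maximum of the logits, and running softmax denominator.
    struct Partial { float val, max, rowsum; };

    static Partial combine(const Partial a, const Partial b) {
        const float max_new = std::max(a.max, b.max);
        const float scale_a = std::exp(a.max - max_new); // the kernel additionally
        const float scale_b = std::exp(b.max - max_new); // flushes tiny exponents to zero
        return { scale_a*a.val + scale_b*b.val, max_new, scale_a*a.rowsum + scale_b*b.rowsum };
    }

    int main() {
        const Partial a = {1.5f, 2.0f, 3.0f};
        const Partial b = {0.5f, 4.0f, 1.0f};
        const Partial c = combine(a, b);
        printf("final output = %f\n", c.val / c.rowsum); // normalize once at the end
        return 0;
    }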
@@ -611,19 +603,9 @@ static __global__ void flash_attn_stream_k_fixup(
     }
 
     // Write back final result:
-#pragma unroll
-    for (int j = 0; j < ncols; ++j) {
-        if (jt*ncols + j >= ne01) {
-            return;
-        }
-        dst[j*ne02*D + threadIdx.x] = dst_val[j] / rowsum[j];
-    }
+    *dst = dst_val / rowsum;
 }
 
-#ifdef __clang__
-#pragma clang diagnostic pop
-#endif // __clang__
-
 template<int D, int parallel_blocks> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
@@ -690,11 +672,13 @@ static void on_no_fattn_vec_case(const int D) {
 }
 
 // parallel_blocks == 0 is stream-k decomposition
-template <int D, int cols_per_block, int parallel_blocks, int KQ_stride>
+template <int D, int ncols1, int ncols2, int parallel_blocks, int KQ_stride>
 void launch_fattn(
         ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
         const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
 ) {
+    constexpr int ncols = ncols1 * ncols2;
+
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
     const ggml_tensor * V = dst->src[2];
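The single cols_per_block parameter is split into ncols1 (Q columns, i.e. tokens per tile) and ncols2 (Q heads that share one KV head under grouped-query attention). A hypothetical illustration of that factorization (values invented, not from the commit):

    #include <cstdio>

    int main() {
        const int n_q_heads  = 32;                     // hypothetical model config
        const int n_kv_heads = 8;
        const int ncols2     = n_q_heads / n_kv_heads; // GQA ratio: Q heads per KV head
        const int ncols1     = 8;                      // Q columns (tokens) per tile
        // All ncols1*ncols2 Q rows in a tile attend to the same K/V data,
        // so one tile amortizes a single K/V load across the whole head group.
        printf("Q rows per tile: %d\n", ncols1*ncols2);
        return 0;
    }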
@@ -763,25 +747,26 @@ void launch_fattn(
         nb23 = nb23*bs*sizeof(half)/ts;
     }
 
-    const int ntiles_x = ((Q->ne[1] + cols_per_block - 1) / cols_per_block);
-    const int ntiles_total = ntiles_x*Q->ne[2]*Q->ne[3];
+    const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
+    const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
 
     const dim3 block_dim(WARP_SIZE, nwarps, 1);
     dim3 blocks_num;
     if (parallel_blocks == 0) {
         // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
-        const int tiles_nwaves = (ntiles_total + 2*nsm - 1) / (2*nsm);
-        const int tiles_efficiency_percent = 100 * ntiles_total / (2*nsm*tiles_nwaves);
+        const int max_blocks = 2*nsm;
+        const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
+        const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
 
-        const int nblocks_stream_k = 2*nsm;
+        const int nblocks_stream_k = max_blocks;
 
-        const bool use_stream_k = tiles_efficiency_percent < 75 || cc >= GGML_CUDA_CC_ADA_LOVELACE;
+        const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || tiles_efficiency_percent < 75;
 
         blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
         blocks_num.y = 1;
         blocks_num.z = 1;
 
-        dst_tmp_meta.alloc(blocks_num.x*cols_per_block * (2*2 + D) * sizeof(float));
+        dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
     } else {
         blocks_num.x = parallel_blocks*ntiles_x;
         blocks_num.y = Q->ne[2];
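The heuristic above estimates how evenly whole tiles would fill waves of 2*nsm resident blocks; stream-k (with its fixup pass) is used on Ada Lovelace and newer, or whenever that wave efficiency would drop below 75%. A toy calculation of the percentage (hypothetical SM count, not from the commit):

    #include <cstdio>

    int main() {
        const int nsm        = 57;    // hypothetical SM count
        const int max_blocks = 2*nsm; // two blocks per SM
        const int cases[] = {60, 114, 150, 456};
        for (const int ntiles_total : cases) {
            // Waves needed to drain all tiles, and how full those waves are:
            const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
            const int eff = 100 * ntiles_total / (max_blocks*tiles_nwaves);
            printf("%3d tiles -> %3d%% wave efficiency\n", ntiles_total, eff);
        }
        return 0;
    }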
@@ -793,7 +778,6 @@ void launch_fattn(
         }
     }
 
-
     float scale = 1.0f;
     float max_bias = 0.0f;
     float logit_softcap = 0.0f;
@@ -832,9 +816,9 @@ void launch_fattn(
     if constexpr (parallel_blocks == 0) {
         if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
             const dim3 block_dim_combine(D, 1, 1);
-            const dim3 blocks_num_combine = blocks_num;
+            const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
 
-            flash_attn_stream_k_fixup<D, cols_per_block, KQ_stride>
+            flash_attn_stream_k_fixup<D, ncols1, ncols2, KQ_stride>
                 <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
                 ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
         }
