
Commit 864a0b6

CUDA: use mma PTX instructions for FlashAttention (#11583)
* CUDA: use mma PTX instructions for FlashAttention
* __shfl_sync workaround for movmatrix
* add __shfl_sync to HIP

Co-authored-by: Diego Devesa <[email protected]>
1 parent 84ec8a5 commit 864a0b6

29 files changed: +2057 −997 lines changed
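
As background for the "mma PTX instructions" in the title: the new kernels issue tensor core operations through inline PTX rather than the nvcuda::wmma API. The snippet below is a minimal, self-contained sketch of that primitive, not code from this commit; the helper name and fragment layout are assumptions for illustration.

// Minimal sketch, not the commit's mma.cuh: one Turing-class m16n8k8 FP16
// multiply-accumulate issued directly as an mma PTX instruction.
// Per thread of the warp: a_frag holds 4 halves in 2 registers, b_frag holds
// 2 halves in 1 register, acc holds 4 floats of the accumulator tile.
static __device__ __forceinline__ void mma_m16n8k8_f16_f32(
        float (&acc)[4], const unsigned (&a_frag)[2], const unsigned b_frag) {
#if __CUDA_ARCH__ >= 750
    asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
        "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
        : "+f"(acc[0]), "+f"(acc[1]), "+f"(acc[2]), "+f"(acc[3])
        : "r"(a_frag[0]), "r"(a_frag[1]), "r"(b_frag));
#else
    // Pre-Turing architectures lack this instruction; real code falls back to other kernels.
    (void) acc; (void) a_frag; (void) b_frag;
#endif // __CUDA_ARCH__ >= 750
}

Issuing mma.sync directly gives the kernel explicit control over how fragments sit in registers, which the wmma API hides; this is presumably why the build files below switch the template instances from fattn-wmma*.cu to fattn-mma*.cu.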

Makefile

Lines changed: 1 addition & 1 deletion
@@ -596,7 +596,7 @@ ifdef GGML_RPC
 	OBJ_GGML_EXT += ggml/src/ggml-rpc.o
 endif # GGML_RPC
 
-OBJ_CUDA_TMPL = $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-wmma*.cu))
+OBJ_CUDA_TMPL = $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-mma*.cu))
 OBJ_CUDA_TMPL += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/mmq*.cu))
 
 ifdef GGML_CUDA_FA_ALL_QUANTS

ggml/include/ggml.h

Lines changed: 1 addition & 1 deletion
@@ -1775,7 +1775,7 @@ extern "C" {
         struct ggml_tensor * a,
         int k);
 
-#define GGML_KQ_MASK_PAD 32
+#define GGML_KQ_MASK_PAD 64
 
     // q: [n_embd, n_batch, n_head, 1]
     // k: [n_embd, n_kv, n_head_kv, 1]
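
For reference, GGML_KQ_MASK_PAD is the row padding that callers of ggml_flash_attn_ext() must apply to the KQ mask: the op checks that the mask provides at least GGML_PAD(n_queries, GGML_KQ_MASK_PAD) rows. A hedged sketch of the usual caller-side pattern follows; the context, tensor type and size variables are illustrative, not taken from this diff.

// Illustrative only: building a mask whose row count is padded to
// GGML_KQ_MASK_PAD (now 64) before handing it to ggml_flash_attn_ext().
// ctx, n_kv and n_tokens are placeholders, and the F16 type mirrors what the
// CUDA FlashAttention path expects for the mask.
struct ggml_tensor * kq_mask = ggml_new_tensor_2d(
        ctx, GGML_TYPE_F16, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));

Callers that already pad with the constant keep working unchanged; the value is raised from 32 to 64 presumably because the new mma-based kernels process larger Q tiles.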

ggml/src/ggml-cuda/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ if (CUDAToolkit_FOUND)
     list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")
 
     file(GLOB GGML_SOURCES_CUDA "*.cu")
-    file(GLOB SRCS "template-instances/fattn-wmma*.cu")
+    file(GLOB SRCS "template-instances/fattn-mma*.cu")
     list(APPEND GGML_SOURCES_CUDA ${SRCS})
     file(GLOB SRCS "template-instances/mmq*.cu")
     list(APPEND GGML_SOURCES_CUDA ${SRCS})

ggml/src/ggml-cuda/common.cuh

Lines changed: 4 additions & 2 deletions
@@ -148,7 +148,7 @@ typedef float2 dfloat2;
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
-#define INT8_MMA_AVAILABLE
+#define NEW_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 
 #if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
@@ -159,11 +159,13 @@ static constexpr bool fast_fp16_available(const int cc) {
     return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
 }
 
+// Any FP16 tensor cores are available.
 static constexpr bool fp16_mma_available(const int cc) {
     return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA;
 }
 
-static constexpr bool int8_mma_available(const int cc) {
+// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
+static constexpr bool new_mma_available(const int cc) {
     return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_TURING;
 }
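
The renamed pair is meant to be used together: NEW_MMA_AVAILABLE gates device code at compile time per architecture, while new_mma_available(cc) gates the host-side kernel choice for the device actually present. A minimal sketch of that pattern with made-up kernel and function names (NO_DEVICE_CODE, GGML_UNUSED, WARP_SIZE and ggml_cuda_info() are existing ggml-cuda helpers):

// Illustrative only, not code from this commit.
#include "common.cuh"

static __global__ void example_mma_kernel(float * dst) {
#ifdef NEW_MMA_AVAILABLE
    dst[threadIdx.x] = 1.0f; // Turing+ path: safe to use the new mma PTX instructions here.
#else
    GGML_UNUSED(dst);
    NO_DEVICE_CODE; // compiled for an architecture without Turing-style tensor cores
#endif // NEW_MMA_AVAILABLE
}

static void example_dispatch(float * dst, cudaStream_t stream) {
    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
    if (new_mma_available(cc)) {
        example_mma_kernel<<<1, WARP_SIZE, 0, stream>>>(dst);
    } else {
        // fall back to a kernel that does not need Turing-style tensor cores
    }
}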

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 154 additions & 25 deletions
@@ -516,6 +516,104 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
         nullptr;
 }
 
+template<int D, int ncols, int KQ_stride> // D == head size
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(D, 1)
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void flash_attn_stream_k_fixup(
+        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
+    const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
+
+    const int iter_k = ne11 / KQ_stride;
+    const int iter_j = (ne01 + (ncols - 1)) / ncols;
+
+    const int bidx0 = blockIdx.x;
+
+    const int kbc0      = (bidx0 + 0)*iter_k*iter_j*ne02 / gridDim.x;
+    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*ne02 / gridDim.x;
+
+    const bool did_not_have_any_data   = kbc0 == kbc0_stop;
+    const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
+    const bool did_not_write_last      = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0;
+    if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
+        return;
+    }
+
+    const int channel = kbc0 / (iter_k*iter_j);
+    const int jt      = (kbc0 - channel*iter_k*iter_j) / iter_k;
+
+    dst += jt*ncols*ne02*D + channel*D;
+
+    // Load the partial result that needs a fixup:
+    float dst_val[ncols] = {0.0f};
+    float max_val[ncols] = {0.0f};
+    float rowsum[ncols]  = {0.0f};
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        if (jt*ncols + j >= ne01) {
+            break;
+        }
+        dst_val[j] = dst[j*ne02*D + threadIdx.x];
+
+        const float2 tmp = dst_fixup[bidx0*ncols + j];
+        max_val[j] = tmp.x;
+        rowsum[j]  = tmp.y;
+    }
+
+    // Iterate over previous blocks and compute the combined results.
+    // All CUDA blocks that get here must have a previous block that needs a fixup.
+    int bidx = bidx0 - 1;
+    int kbc_stop = kbc0;
+    while(true) {
+        const int kbc = bidx*iter_k*iter_j*ne02 / gridDim.x;
+        if (kbc == kbc_stop) { // Did not have any data.
+            bidx--;
+            kbc_stop = kbc;
+            continue;
+        }
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            if (jt*ncols + j >= ne01) {
+                break;
+            }
+            const float dst_add = dst_fixup_data[bidx*ncols*D + j*D + threadIdx.x];
+
+            const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + j];
+
+            // Scale the current and new value accumulators depending on the max. values.
+            const float max_val_new = fmaxf(max_val[j], tmp.x);
+
+            const float diff_val = max_val[j] - max_val_new;
+            const float diff_add = tmp.x - max_val_new;
+
+            const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
+            const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
+
+            dst_val[j] = scale_val*dst_val[j] + scale_add*dst_add;
+            rowsum[j]  = scale_val*rowsum[j]  + scale_add*tmp.y;
+
+            max_val[j] = max_val_new;
+        }
+
+        // If this block started in a previous tile we are done and don't need to combine additional partial results.
+        if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
+            break;
+        }
+        bidx--;
+        kbc_stop = kbc;
+    }
+
+    // Write back final result:
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        if (jt*ncols + j >= ne01) {
+            return;
+        }
+        dst[j*ne02*D + threadIdx.x] = dst_val[j] / rowsum[j];
+    }
+}
+
 template<int D, int parallel_blocks> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
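
The fixup kernel above mirrors the work split of the stream-k main kernel: block b owns the half-open chunk range [b*total/gridDim.x, (b+1)*total/gridDim.x) over the iter_k*iter_j*ne02 chunks, and whenever a tile's chunks straddle a block boundary, the block that wrote the tile's output merges the earlier partial results with the usual online-softmax rescaling. A small self-contained illustration of that split with made-up sizes:

// Illustrative only: the stream-k partition used by flash_attn_stream_k_fixup,
// reproduced on the host for a toy problem size.
#include <cstdio>

int main() {
    const int iter_k  = 4; // K/V chunks per tile  (made up)
    const int iter_j  = 2; // Q column tiles       (made up)
    const int ne02    = 1; // attention heads      (made up)
    const int nblocks = 3; // stands in for gridDim.x == 2*nsm

    const int total = iter_k*iter_j*ne02;
    for (int b = 0; b < nblocks; ++b) {
        const int kbc0      = (b + 0)*total / nblocks;
        const int kbc0_stop = (b + 1)*total / nblocks;
        printf("block %d: kbc [%d, %d)%s\n", b, kbc0, kbc0_stop,
            kbc0_stop % iter_k != 0 ? "  <- ends mid-tile, a later block will run the fixup" : "");
    }
    return 0;
}

With these sizes, blocks 0 and 1 both end mid-tile, so block 1 (for tile 0) and block 2 (for tile 1) load their partial outputs plus the stored max/rowsum pairs and rescale them exactly as in the loop above.
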
@@ -581,10 +679,11 @@ static void on_no_fattn_vec_case(const int D) {
     }
 }
 
-template <int D, int parallel_blocks>
+// parallel_blocks == 0 is stream-k decomposition
+template <int D, int cols_per_block, int parallel_blocks, int KQ_stride>
 void launch_fattn(
         ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
-        const int nwarps, const int cols_per_block, const bool need_f16_K, const bool need_f16_V
+        const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
 ) {
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
@@ -603,20 +702,23 @@ void launch_fattn(
 
     GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
 
+    GGML_ASSERT(Q->ne[3] == 1);
+
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
+    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
 
     ggml_cuda_pool_alloc<half> K_f16(pool);
     ggml_cuda_pool_alloc<half> V_f16(pool);
     ggml_cuda_pool_alloc<float> dst_tmp(pool);
     ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
 
-    char * K_data = (char *) K->data;
+    const char * K_data = (const char *) K->data;
     size_t nb11 = K->nb[1];
     size_t nb12 = K->nb[2];
     size_t nb13 = K->nb[3];
 
-    char * V_data = (char *) V->data;
+    const char * V_data = (const char *) V->data;
     size_t nb21 = V->nb[1];
     size_t nb22 = V->nb[2];
     size_t nb23 = V->nb[3];
@@ -649,39 +751,60 @@ void launch_fattn(
         nb23 = nb23*bs*sizeof(half)/ts;
     }
 
-    if (parallel_blocks > 1) {
-        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
-        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
-    }
+    const int ntiles_x = ((Q->ne[1] + cols_per_block - 1) / cols_per_block);
+    const int ntiles_total = ntiles_x*Q->ne[2]*Q->ne[3];
 
     const dim3 block_dim(WARP_SIZE, nwarps, 1);
-    const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
-    const int shmem = 0;
+    dim3 blocks_num;
+    if (parallel_blocks == 0) {
+        // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
+        const int tiles_nwaves = (ntiles_total - nsm - 1) / nsm;
+        const bool tiles_inefficient = 3*nsm < 2*tiles_nwaves*ntiles_total;
+        const bool short_context = K->ne[1] < 4096;
+
+        const int nblocks_stream_k = 2*nsm;
+
+        blocks_num.x = short_context && !tiles_inefficient ? ntiles_total : nblocks_stream_k;
+        blocks_num.y = 1;
+        blocks_num.z = 1;
+
+        dst_tmp_meta.alloc(blocks_num.x*cols_per_block * (2*2 + D) * sizeof(float));
+    } else {
+        blocks_num.x = parallel_blocks*ntiles_x;
+        blocks_num.y = Q->ne[2];
+        blocks_num.z = Q->ne[3];
+
+        if (parallel_blocks > 1) {
+            dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
+            dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
+        }
+    }
+
 
     float scale = 1.0f;
     float max_bias = 0.0f;
     float logit_softcap = 0.0f;
 
-    memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
-    memcpy(&logit_softcap, (float *) KQV->op_params + 2, sizeof(float));
+    memcpy(&scale, (const float *) KQV->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
 
     if (logit_softcap != 0.0f) {
         scale /= logit_softcap;
     }
 
     const uint32_t n_head = Q->ne[2];
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+    const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head))));
 
     const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    fattn_kernel<<<blocks_num, block_dim, shmem, main_stream>>>(
+    fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
         (const char *) Q->data,
         K_data,
         V_data,
         mask ? ((const char *) mask->data) : nullptr,
-        (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
+        (parallel_blocks) > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
        scale, max_bias, m0, m1, n_head_log2, logit_softcap,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
@@ -693,16 +816,22 @@ void launch_fattn(
     );
     CUDA_CHECK(cudaGetLastError());
 
-    if ((parallel_blocks) == 1) {
-        return;
-    }
+    if constexpr (parallel_blocks == 0) {
+        if (blocks_num.x % ntiles_total != 0) { // Fixup is only needed if the SMs work on fractional tiles.
+            const dim3 block_dim_combine(D, 1, 1);
+            const dim3 blocks_num_combine = blocks_num;
 
-    const dim3 block_dim_combine(D, 1, 1);
-    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
-    const int shmem_combine = 0;
+            flash_attn_stream_k_fixup<D, cols_per_block, KQ_stride>
+                <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
+                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
+        }
+    } else if constexpr (parallel_blocks > 1) {
+        const dim3 block_dim_combine(D, 1, 1);
+        const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
 
-        flash_attn_combine_results<D, parallel_blocks>
-            <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
-            (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+        flash_attn_combine_results<D, parallel_blocks>
+            <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
+            (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+    }
     CUDA_CHECK(cudaGetLastError());
 }
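
Under the new signature the decomposition is chosen by template parameter: parallel_blocks > 0 keeps the old behaviour, while parallel_blocks == 0 selects stream-k plus the fixup kernel above. A hedged sketch of how a kernel wrapper might instantiate it; the constants are illustrative and not taken from this commit's kernel files.

// Illustrative only: launching a FlashAttention kernel with stream-k
// decomposition (parallel_blocks == 0) through the new launch_fattn().
#include "fattn-common.cuh"

static void launch_example_fattn_mma(ggml_backend_cuda_context & ctx, ggml_tensor * dst,
                                     fattn_kernel_t fattn_kernel, const size_t nbytes_shared) {
    constexpr int D              = 128; // head size                    (assumed)
    constexpr int cols_per_block = 32;  // Q columns per CUDA block     (assumed)
    constexpr int KQ_stride      = 64;  // K/V rows per inner iteration (assumed)
    constexpr int nwarps         = 4;   // warps per CUDA block         (assumed)

    launch_fattn<D, cols_per_block, /*parallel_blocks=*/0, KQ_stride>(
        ctx, dst, fattn_kernel, nwarps, nbytes_shared, /*need_f16_K=*/true, /*need_f16_V=*/true);
}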
