Skip to content

CUDA: use async data loading for FlashAttention #11894

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions ggml/src/ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,13 @@
#define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
#define CUDART_HMASK 12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons

#define GGML_CUDA_CC_PASCAL 600
#define GGML_CUDA_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
#define GGML_CUDA_CC_VOLTA 700
#define GGML_CUDA_CC_TURING 750
#define GGML_CUDA_CC_AMPERE 800
#define GGML_CUDA_CC_OFFSET_AMD 0x1000000
#define GGML_CUDA_CC_PASCAL 600
#define GGML_CUDA_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
#define GGML_CUDA_CC_VOLTA 700
#define GGML_CUDA_CC_TURING 750
#define GGML_CUDA_CC_AMPERE 800
#define GGML_CUDA_CC_ADA_LOVELACE 890
#define GGML_CUDA_CC_OFFSET_AMD 0x1000000

// GCN/CNDA, wave size is 64
#define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16
Expand Down Expand Up @@ -199,6 +200,10 @@ typedef float2 dfloat2;
#define NEW_MMA_AVAILABLE
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING

#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#define CP_ASYNC_AVAILABLE
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE

#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
#define FLASH_ATTN_AVAILABLE
#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
Expand Down Expand Up @@ -231,6 +236,10 @@ static bool new_mma_available(const int cc) {
return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
}

static bool cp_async_available(const int cc) {
return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
}

static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
return __AMDGCN_WAVEFRONT_SIZE;
Expand Down
46 changes: 46 additions & 0 deletions ggml/src/ggml-cuda/cp-async.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Simplified API for asynchronous data loading.

#include "common.cuh"

// Copies data from global to shared memory, cg == cache global.
// Both the src and dst pointers must be aligned to 16 bit.
// Shared memory uses 32 bit addressing, the pointer is passed as unsigned int.
// Generic pointers can be converted to 32 bit shared memory pointers using __cvta_generic_to_shared.
// Only the 16 bit copy is exposed because 4 and 8 bit copies did not yield performance improvements.
template <int preload>
static __device__ __forceinline__ void cp_async_cg_16(const unsigned int dst, const void * src) {
static_assert(preload == 0 || preload == 64 || preload == 128 || preload == 256, "bad preload");
#ifdef CP_ASYNC_AVAILABLE
#if CUDART_VERSION >= 11040
if (preload == 256) {
asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], 16;"
: : "r"(dst), "l"(src));
} else if (preload == 128) {
asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], 16;"
: : "r"(dst), "l"(src));
} else if (preload == 64) {
asm volatile("cp.async.cg.shared.global.L2::64B [%0], [%1], 16;"
: : "r"(dst), "l"(src));
} else
#endif // CUDART_VERSION >= 11040
{
asm volatile("cp.async.cg.shared.global.L2 [%0], [%1], 16;"
: : "r"(dst), "l"(src));
}
#else
GGML_UNUSED(dst);
GGML_UNUSED(src);
NO_DEVICE_CODE;
#endif // CP_ASYNC_AVAILABLE
}

// Makes each thread wait until its asynchronous data copies are done.
// This does NOT provide any additional synchronization.
// In particular, when copying data with multiple warps a call to __syncthreads will be needed.
static __device__ __forceinline__ void cp_async_wait_all() {
#ifdef CP_ASYNC_AVAILABLE
asm volatile("cp.async.wait_all;");
#else
NO_DEVICE_CODE;
#endif // CP_ASYNC_AVAILABLE
}
15 changes: 9 additions & 6 deletions ggml/src/ggml-cuda/fattn-common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,9 @@ void launch_fattn(

ggml_cuda_pool & pool = ctx.pool();
cudaStream_t main_stream = ctx.stream();
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
const int id = ggml_cuda_get_device();
const int cc = ggml_cuda_info().devices[id].cc;
const int nsm = ggml_cuda_info().devices[id].nsm;

ggml_cuda_pool_alloc<half> K_f16(pool);
ggml_cuda_pool_alloc<half> V_f16(pool);
Expand Down Expand Up @@ -768,13 +770,14 @@ void launch_fattn(
dim3 blocks_num;
if (parallel_blocks == 0) {
// For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
const int tiles_nwaves = (ntiles_total - nsm - 1) / nsm;
const bool tiles_inefficient = 3*nsm < 2*tiles_nwaves*ntiles_total;
const bool short_context = K->ne[1] < 4096;
const int tiles_nwaves = (ntiles_total + 2*nsm - 1) / (2*nsm);
const int tiles_efficiency_percent = 100 * ntiles_total / (2*nsm*tiles_nwaves);

const int nblocks_stream_k = 2*nsm;

blocks_num.x = short_context && !tiles_inefficient ? ntiles_total : nblocks_stream_k;
const bool use_stream_k = tiles_efficiency_percent < 75 || cc >= GGML_CUDA_CC_ADA_LOVELACE;

blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
blocks_num.y = 1;
blocks_num.z = 1;

Expand Down Expand Up @@ -827,7 +830,7 @@ void launch_fattn(
CUDA_CHECK(cudaGetLastError());

if constexpr (parallel_blocks == 0) {
if (blocks_num.x % ntiles_total != 0) { // Fixup is only needed if the SMs work on fractional tiles.
if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
const dim3 block_dim_combine(D, 1, 1);
const dim3 blocks_num_combine = blocks_num;

Expand Down
Loading
Loading