CUDA: use async data loading for FlashAttention #11894


Merged
merged 3 commits into ggml-org:master on Feb 17, 2025

Conversation

JohannesGaessler
Collaborator

This PR adds asynchronous data copies to the CUDA FlashAttention kernels, which enables simultaneous data loading and computation. This feature requires Ampere or newer. I exposed the related PTX instructions via a simple API in a new header cp-async.cuh because I don't like the existing CUDA API: I think it obfuscates what the hardware is actually doing, and it needlessly causes compilation failures when included with the wrong compute capabilities. While I was working on this I refactored the API in mma.cuh to use the same fundamental data structures for all matrix tiles. This is more convenient for kernels where multiple matrix multiplications are chained and the output tiles of one operation are the input tiles of another.
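For readers unfamiliar with cp.async, the sketch below shows the general double-buffering pattern these instructions enable: the copy for the next tile is issued while the current tile is still being processed. The helper names and the toy kernel are purely illustrative and are not the actual API in cp-async.cuh; it assumes compute capability >= 8.0 (Ampere).

```cuda
// Hypothetical wrappers around the cp.async PTX instructions (illustrative
// names, not the cp-async.cuh API). Requires compute capability >= 8.0.
static __device__ __forceinline__ void cp_async_16(void * dst_shared, const void * src_global) {
    const unsigned dst = (unsigned) __cvta_generic_to_shared(dst_shared); // shared-memory address
    asm volatile("cp.async.cg.shared.global [%0], [%1], 16;" :: "r"(dst), "l"(src_global));
}
static __device__ __forceinline__ void cp_async_commit() {
    asm volatile("cp.async.commit_group;");
}
static __device__ __forceinline__ void cp_async_wait_newest_only() {
    asm volatile("cp.async.wait_group 1;" ::: "memory"); // allow at most 1 group still in flight
}
static __device__ __forceinline__ void cp_async_wait_all() {
    asm volatile("cp.async.wait_all;" ::: "memory");
}

// Toy double-buffered kernel (one warp, 16 bytes per thread): while tile i is
// being processed, the copy for tile i+1 is already in flight.
__global__ void toy_async_loop(const float4 * __restrict__ src, float4 * __restrict__ dst, const int ntiles) {
    __shared__ float4 buf[2][32];
    int cur = 0;

    cp_async_16(&buf[cur][threadIdx.x], &src[threadIdx.x]); // preload tile 0
    cp_async_commit();

    for (int i = 0; i < ntiles; ++i) {
        if (i + 1 < ntiles) {
            // Kick off the asynchronous copy of the next tile into the other buffer.
            cp_async_16(&buf[cur ^ 1][threadIdx.x], &src[(i + 1)*32 + threadIdx.x]);
            cp_async_commit();
            cp_async_wait_newest_only(); // tile i has arrived, tile i+1 may still be loading
        } else {
            cp_async_wait_all();
        }
        __syncthreads();

        // "Compute" on buf[cur]; a real kernel would feed this into mma tiles.
        dst[i*32 + threadIdx.x].x = buf[cur][threadIdx.x].x + 1.0f;

        __syncthreads(); // everyone must be done with buf[cur] before it is overwritten
        cur ^= 1;
    }
}
```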

This PR also fixes a bug on master where the condition for skipping the stream-k fixup in FlashAttention was wrong. In practice, though, it is very unlikely that this bug resulted in incorrect results, due to the specific numbers of streaming multiprocessors found on contemporary NVIDIA GPUs.

Performance changes
| GPU | Model | Microbatch size | Test | t/s master | t/s cuda-fa-mma-17 | Speedup |
|---|---|---|---|---|---|---|
| RTX 3090 | gemma 2B F16 | 8 | pp16384 | 884.99 | 960.45 | 1.09 |
| RTX 3090 | gemma 2B F16 | 16 | pp16384 | 1723.90 | 1751.24 | 1.02 |
| RTX 3090 | gemma 2B F16 | 32 | pp16384 | 2930.83 | 2833.06 | 0.97 |
| RTX 3090 | gemma 2B F16 | 64 | pp16384 | 3919.85 | 3783.78 | 0.97 |
| RTX 3090 | gemma 2B F16 | 128 | pp16384 | 7502.87 | 7534.81 | 1.00 |
| RTX 3090 | gemma 2B F16 | 256 | pp16384 | 9887.99 | 9922.51 | 1.00 |
| RTX 3090 | gemma 2B F16 | 512 | pp16384 | 11016.88 | 11048.05 | 1.00 |
| RTX 3090 | gemma 2B F16 | 1024 | pp16384 | 11479.26 | 11647.72 | 1.01 |
| RTX 3090 | gemma 2B F16 | 2048 | pp16384 | 11766.60 | 11825.53 | 1.01 |
| RTX 3090 | gemma 2B F16 | 4096 | pp16384 | 11775.10 | 11840.24 | 1.01 |
| RTX 3090 | gemma 2B F16 | 8192 | pp16384 | 11770.49 | 11834.37 | 1.01 |
| RTX 3090 | gemma 2B F16 | 16384 | pp16384 | 11797.73 | 11853.17 | 1.00 |
| RTX 3090 | llama 8B Q4_0 | 8 | pp16384 | 423.95 | 433.73 | 1.02 |
| RTX 3090 | llama 8B Q4_0 | 16 | pp16384 | 854.51 | 857.10 | 1.00 |
| RTX 3090 | llama 8B Q4_0 | 32 | pp16384 | 1465.20 | 1458.56 | 1.00 |
| RTX 3090 | llama 8B Q4_0 | 64 | pp16384 | 2247.54 | 2239.60 | 1.00 |
| RTX 3090 | llama 8B Q4_0 | 128 | pp16384 | 2855.43 | 2845.70 | 1.00 |
| RTX 3090 | llama 8B Q4_0 | 256 | pp16384 | 3298.58 | 3565.78 | 1.08 |
| RTX 3090 | llama 8B Q4_0 | 512 | pp16384 | 3461.29 | 3694.70 | 1.07 |
| RTX 3090 | llama 8B Q4_0 | 1024 | pp16384 | 3501.73 | 3818.37 | 1.09 |
| RTX 3090 | llama 8B Q4_0 | 2048 | pp16384 | 3465.74 | 3816.77 | 1.10 |
| RTX 3090 | llama 8B Q4_0 | 4096 | pp16384 | 3477.07 | 3830.18 | 1.10 |
| RTX 3090 | llama 8B Q4_0 | 8192 | pp16384 | 3516.95 | 3835.93 | 1.09 |
| RTX 3090 | llama 8B Q4_0 | 16384 | pp16384 | 3413.82 | 3660.37 | 1.07 |
| RTX 3090 | phi2 3B F16 | 8 | pp16384 | 535.35 | 559.82 | 1.05 |
| RTX 3090 | phi2 3B F16 | 16 | pp16384 | 1045.29 | 1069.07 | 1.02 |
| RTX 3090 | phi2 3B F16 | 32 | pp16384 | 1944.48 | 1906.19 | 0.98 |
| RTX 3090 | phi2 3B F16 | 64 | pp16384 | 3257.62 | 3172.20 | 0.97 |
| RTX 3090 | phi2 3B F16 | 128 | pp16384 | 4727.21 | 4608.92 | 0.97 |
| RTX 3090 | phi2 3B F16 | 256 | pp16384 | 5465.76 | 6035.44 | 1.10 |
| RTX 3090 | phi2 3B F16 | 512 | pp16384 | 5834.88 | 6498.62 | 1.11 |
| RTX 3090 | phi2 3B F16 | 1024 | pp16384 | 5728.60 | 6542.47 | 1.14 |
| RTX 3090 | phi2 3B F16 | 2048 | pp16384 | 5710.91 | 6725.69 | 1.18 |
| RTX 3090 | phi2 3B F16 | 4096 | pp16384 | 5708.27 | 6719.02 | 1.18 |
| RTX 3090 | phi2 3B F16 | 8192 | pp16384 | 5712.93 | 6725.94 | 1.18 |
| RTX 3090 | phi2 3B F16 | 16384 | pp16384 | 5676.58 | 6731.76 | 1.19 |
| RTX 4090 | gemma 2B F16 | 8 | pp16384 | 1204.39 | 1209.22 | 1.00 |
| RTX 4090 | gemma 2B F16 | 16 | pp16384 | 2439.13 | 2168.01 | 0.89 |
| RTX 4090 | gemma 2B F16 | 32 | pp16384 | 4686.86 | 3690.27 | 0.79 |
| RTX 4090 | gemma 2B F16 | 64 | pp16384 | 7842.68 | 5194.76 | 0.66 |
| RTX 4090 | gemma 2B F16 | 128 | pp16384 | 13437.09 | 11235.78 | 0.84 |
| RTX 4090 | gemma 2B F16 | 256 | pp16384 | 20299.78 | 19070.07 | 0.94 |
| RTX 4090 | gemma 2B F16 | 512 | pp16384 | 25635.41 | 25250.59 | 0.98 |
| RTX 4090 | gemma 2B F16 | 1024 | pp16384 | 24842.82 | 24671.27 | 0.99 |
| RTX 4090 | gemma 2B F16 | 2048 | pp16384 | 24328.35 | 24944.02 | 1.03 |
| RTX 4090 | gemma 2B F16 | 4096 | pp16384 | 24343.09 | 24838.79 | 1.02 |
| RTX 4090 | gemma 2B F16 | 8192 | pp16384 | 24244.10 | 24918.96 | 1.03 |
| RTX 4090 | gemma 2B F16 | 16384 | pp16384 | 24280.23 | 24897.67 | 1.03 |
| RTX 4090 | llama 8B Q4_0 | 8 | pp16384 | 886.31 | 904.86 | 1.02 |
| RTX 4090 | llama 8B Q4_0 | 16 | pp16384 | 1590.18 | 1533.52 | 0.96 |
| RTX 4090 | llama 8B Q4_0 | 32 | pp16384 | 2823.95 | 2694.92 | 0.95 |
| RTX 4090 | llama 8B Q4_0 | 64 | pp16384 | 4963.35 | 4331.06 | 0.87 |
| RTX 4090 | llama 8B Q4_0 | 128 | pp16384 | 6873.26 | 6517.14 | 0.95 |
| RTX 4090 | llama 8B Q4_0 | 256 | pp16384 | 8614.69 | 8414.13 | 0.98 |
| RTX 4090 | llama 8B Q4_0 | 512 | pp16384 | 9350.85 | 9609.05 | 1.03 |
| RTX 4090 | llama 8B Q4_0 | 1024 | pp16384 | 9363.54 | 9615.67 | 1.03 |
| RTX 4090 | llama 8B Q4_0 | 2048 | pp16384 | 8976.52 | 9201.50 | 1.03 |
| RTX 4090 | llama 8B Q4_0 | 4096 | pp16384 | 8971.77 | 9191.37 | 1.02 |
| RTX 4090 | llama 8B Q4_0 | 8192 | pp16384 | 8975.75 | 9185.91 | 1.02 |
| RTX 4090 | llama 8B Q4_0 | 16384 | pp16384 | 8971.71 | 9201.87 | 1.03 |
| RTX 4090 | phi2 3B F16 | 8 | pp16384 | 690.30 | 700.12 | 1.01 |
| RTX 4090 | phi2 3B F16 | 16 | pp16384 | 1367.85 | 1315.42 | 0.96 |
| RTX 4090 | phi2 3B F16 | 32 | pp16384 | 2677.82 | 2528.37 | 0.94 |
| RTX 4090 | phi2 3B F16 | 64 | pp16384 | 5014.41 | 4357.40 | 0.87 |
| RTX 4090 | phi2 3B F16 | 128 | pp16384 | 8427.73 | 7898.26 | 0.94 |
| RTX 4090 | phi2 3B F16 | 256 | pp16384 | 12136.08 | 11752.74 | 0.97 |
| RTX 4090 | phi2 3B F16 | 512 | pp16384 | 13649.55 | 14318.49 | 1.05 |
| RTX 4090 | phi2 3B F16 | 1024 | pp16384 | 15067.15 | 15920.39 | 1.06 |
| RTX 4090 | phi2 3B F16 | 2048 | pp16384 | 13654.12 | 14460.66 | 1.06 |
| RTX 4090 | phi2 3B F16 | 4096 | pp16384 | 13662.36 | 14443.62 | 1.06 |
| RTX 4090 | phi2 3B F16 | 8192 | pp16384 | 13664.72 | 14455.61 | 1.06 |
| RTX 4090 | phi2 3B F16 | 16384 | pp16384 | 13663.50 | 14460.25 | 1.06 |

@github-actions bot added the Nvidia GPU (Issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) labels on Feb 15, 2025
@LostRuins
Collaborator

Hi @JohannesGaessler, thank you for this PR. I think it fixes a regression that started after #11583 was merged. I was just about to create a new issue/bug report here, but you beat me to it, so I'll still note my findings for posterity even though it appears to have been fixed, just in case.

Basically, since #11583 we have had users complaining about incoherent outputs at high contexts with Mistral Small 3 based models, but surprisingly only on a desktop RTX 4090 with flash attention enabled. I myself use a laptop RTX 4090, which I thought had the exact same compute capability, but I was never able to reproduce it locally, nor was anyone else running older cards; we tested quite a few.

Finally I rented a desktop RTX 4090 on runpod, and I was able to reproduce the incoherent outputs issue.

I knew it was a problem with the new FA implementation because avoiding ggml_cuda_flash_attn_ext_mma_f16 with
LostRuins@7502af9 and reverting to ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); worked perfectly fine in all scenarios.

I just couldn't understand why it was limited to desktop RTX 4090s, and I was considering swapping back to the old FA implementation wholesale.

But it seems your new PR solves it, at least from my initial tests. I'll update if I notice reports of this issue again.

@sorasoras

sorasoras commented Feb 17, 2025

> I myself use a laptop RTX 4090, which I thought had the exact same compute capability, but I was never able to reproduce it locally […]

The laptop 4090 (AD103) is a downclocked desktop 4080 chip, so it might be more variant-specific. You could try it on a 4070-class card to see if it behaves any differently :)

@JohannesGaessler
Collaborator Author

The reason the desktop 4090 specifically can produce incorrect results is its number of streaming multiprocessors. I changed the FlashAttention kernels to use stream-k decomposition, which in most cases needs a fixup after the regular kernel because it can work on fractional tiles that may need to be combined. However, if the number of tiles is exactly divisible by the number of SMs, the fixup can be skipped. In my original PR I had swapped the arguments of the modulo operator, so GPUs with an SM count that is a power of 2 could incorrectly skip the fixup for some configurations. A desktop 4090 has 128 SMs; a laptop 4090 has 76.
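To make the failure mode concrete, here is a hedged illustration of the kind of condition being described; it is not the actual kernel code, and ntiles/nsm are placeholder names:

```cuda
// Illustrative only (placeholder names, not the real kernel code).
// The fixup may be skipped only when the tile count is an exact multiple of the SM count:
const bool skip_fixup = ntiles % nsm == 0;

// With the operands swapped the check becomes:
//     nsm % ntiles == 0
// On a GPU whose SM count is a power of two (e.g. 128 on a desktop RTX 4090), this is
// also true whenever ntiles happens to divide the SM count (1, 2, 4, ..., 128), so for
// some configurations the fixup could be skipped even though it was still required.
// On a laptop 4090 with 76 SMs, far fewer tile counts divide the SM count, which is
// why the problem was so hard to reproduce there.
```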

@JohannesGaessler JohannesGaessler merged commit 73e2ed3 into ggml-org:master Feb 17, 2025
42 checks passed
orca-zhang pushed a commit to orca-zhang/llama.cpp that referenced this pull request Feb 26, 2025
* CUDA: use async data loading for FlashAttention

---------

Co-authored-by: Diego Devesa <[email protected]>
arthw pushed a commit to arthw/llama.cpp that referenced this pull request Feb 26, 2025
* CUDA: use async data loading for FlashAttention

---------

Co-authored-by: Diego Devesa <[email protected]>
mglambda pushed a commit to mglambda/llama.cpp that referenced this pull request Mar 8, 2025
* CUDA: use async data loading for FlashAttention

---------

Co-authored-by: Diego Devesa <[email protected]>