CUDA: use async data loading for FlashAttention #11894
Conversation
Hi @JohannesGaessler, thank you for this PR. I think it fixes a regression that started when #11583 was merged. I was just about to create a new issue/bug report here, but you beat me to it; I'll still note my findings for posterity even though it appears to have been fixed, just in case.

Basically, since #11583 we have had users complaining about incoherent outputs at high contexts with Mistral Small 3 based models, but surprisingly only on a desktop RTX 4090 with flash attention enabled. I myself use a laptop RTX 4090, which I thought had the exact same compute capability, but I was never able to reproduce it locally, nor was anyone else running older cards; we tested quite a few. Finally I rented a desktop RTX 4090 on runpod, and I was able to reproduce the incoherent outputs issue. I knew it was a problem with the new FA implementation because avoiding it worked around the issue, but I just couldn't understand why it was limited to desktop RTX 4090s, and I was considering swapping back to the old FA wholesale. But it seems like your new PR solves it, at least from my initial tests. I'll update if I notice reports of this issue again.
The laptop 4090 (AD103) is a downclocked desktop 4080 chip, so it might be more variant specific. You can try that on a 4070-ish card to see if it behaves any differently :)
The reason the desktop 4090 specifically can produce incorrect results is the number of streaming multiprocessors. I changed the FlashAttention kernels to use stream-k decomposition, which in most cases needs a fixup after the regular kernel because it can work on fractional tiles that may need to be combined. However, if the number of tiles can be exactly divided by the number of SMs, then the fixup can be skipped. In my original PR I had swapped the arguments for the modulo operator, so GPUs with an SM count that is a power of 2 could, for some configurations, skip the fixup incorrectly. A desktop 4090 has 128 SMs, a laptop 4090 has 76 SMs.
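For illustration, here is a minimal sketch of the skip condition described above; the variable and function names (`ntiles`, `nsm`, `can_skip_fixup`) are hypothetical and not the actual llama.cpp code.

```cpp
#include <cstdio>

// Correct: the fixup can only be skipped when the tiles divide evenly across
// the SMs, i.e. no SM is left with a fractional tile to combine afterwards.
static bool can_skip_fixup(int ntiles, int nsm) {
    return ntiles % nsm == 0;
}

// The bug on master: the modulo operands were swapped.
static bool can_skip_fixup_buggy(int ntiles, int nsm) {
    return nsm % ntiles == 0;
}

int main() {
    // Desktop RTX 4090: 128 SMs (a power of 2). With e.g. 64 tiles the buggy
    // check returns true (128 % 64 == 0), so the fixup is skipped incorrectly.
    printf("desktop 4090 (128 SMs), 64 tiles: correct=%d buggy=%d\n",
           can_skip_fixup(64, 128), can_skip_fixup_buggy(64, 128));
    // Laptop RTX 4090: 76 SMs. 76 % 64 != 0, so the same workload does not
    // trigger the bug, which is why it was hard to reproduce on that card.
    printf("laptop 4090  (76 SMs), 64 tiles: correct=%d buggy=%d\n",
           can_skip_fixup(64, 76), can_skip_fixup_buggy(64, 76));
    return 0;
}
```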
* CUDA: use async data loading for FlashAttention
  Co-authored-by: Diego Devesa <[email protected]>
This PR adds the use of asynchronous data copies to the CUDA FlashAttention kernels, which enables simultaneous data loading and computation. This feature needs Ampere or newer. I exposed the related PTX instructions via a simple API in a new header cp-async.cuh because I don't like the existing CUDA API; I think it obfuscates what the hardware is actually doing and it needlessly causes compilation failures when included with the wrong compute capabilities.

While I was working on this I refactored the API in mma.cuh to use the same fundamental data structures for all matrix tiles. This is more convenient for kernels where multiple matrix multiplications are being chained and the output tiles of one operation are the input tiles of another operation.

This PR also fixes a bug on master where the condition for when the stream-k fixup for FlashAttention can be skipped was wrong. In practice it is very unlikely though that this bug resulted in incorrect results due to the specific number of streaming multiprocessors found on contemporary NVIDIA GPUs.
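For readers unfamiliar with these instructions, here is a minimal sketch of what a thin wrapper around the cp.async PTX instructions can look like. The wrapper names and the example kernel are assumptions for illustration, not the actual contents of cp-async.cuh; only the PTX instructions and the __cvta_generic_to_shared() intrinsic are standard CUDA/PTX features (Ampere, __CUDA_ARCH__ >= 800).

```cuda
#include <cuda_runtime.h>

// Issue an asynchronous 16-byte copy from global to shared memory
// (cp.async.cg only supports 16-byte transfers). dst is a shared-memory
// address obtained via __cvta_generic_to_shared().
static __device__ __forceinline__ void cp_async_cg_16(const unsigned int dst, const void * src) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile("cp.async.cg.shared.global [%0], [%1], 16;" : : "r"(dst), "l"(src) : "memory");
#else
    (void) dst; (void) src; // no-op on pre-Ampere architectures
#endif
}

// Wait until all previously issued asynchronous copies have completed.
static __device__ __forceinline__ void cp_async_wait_all() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile("cp.async.wait_all;" ::: "memory");
#endif
}

// Hypothetical usage pattern, launched with 32 threads per block: issue the
// loads for a tile, overlap them with independent work, then wait and sync
// before the freshly loaded data is consumed.
__global__ void example_kernel(const float4 * __restrict__ src) {
    __shared__ float4 tile[32];
    const unsigned int dst = (unsigned int) __cvta_generic_to_shared(&tile[threadIdx.x]);
    cp_async_cg_16(dst, &src[threadIdx.x]); // copy proceeds asynchronously
    // ... independent computation can be issued here ...
    cp_async_wait_all();
    __syncthreads();
    // ... compute on tile ...
}
```

The point of issuing the copies asynchronously is that the loads for the next tile can be in flight while the math on the current tile is still running, which is what "simultaneous data loading and computation" refers to above.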
Performance changes