fix: attempt forward on flash attn2 to check hardware support #2335
Conversation
@@ -173,6 +174,41 @@ def paged_attention(
    try:
        import flash_attn_2_cuda

        # try forwarding to see if it works with all dummy inputs
Wouldn't it be easier to require the minimum needed CUDA capability? We could even skip the import altogether if the hardware doesn't have the right capability.
I agree, it would be simpler. We already have is_sm75, which was probably there for that reason.
Yeah, I agree that's a better method. Updated to use is_ampere_or_newer = major >= 8 and minor >= 0 in the latest commits and avoid trying the forward pass.
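For illustration, a minimal sketch of the capability-based gate described above. torch.cuda.get_device_capability() is a real PyTorch API and the variable names mirror the comments, but the surrounding import wiring is an assumption, not the exact PR code.

```python
import torch

# Sketch only: gate the flash_attn_2_cuda import on compute capability
# instead of attempting a dummy forward pass.
V2 = False
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    is_sm75 = major == 7 and minor == 5
    is_ampere_or_newer = major >= 8 and minor >= 0
    if is_ampere_or_newer:
        try:
            import flash_attn_2_cuda
            V2 = True
        except ImportError:
            V2 = False
```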
@@ -254,9 +290,11 @@ def attention(
    softcap=None,
):
    if window_size_left != -1:
        raise NotImplementedError(
            "window_size_left is only available with flash attn v2"
        warnings.warn(
Can you also not change that? It's important to hard crash rather than silently ignore it.
Updated in the latest commit to still throw when an invalid value is passed.
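For reference, a minimal sketch of the hard-crash behavior the reviewer asked to keep. The standalone helper name is illustrative (in the PR the check sits inside attention()); the error message comes from the diff above.

```python
def check_window_size_left(window_size_left: int) -> None:
    # Raise instead of warning: an unsupported window size should hard crash
    # rather than be silently ignored.
    if window_size_left != -1:
        raise NotImplementedError(
            "window_size_left is only available with flash attn v2"
        )
```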
    if (
        (sliding_window is not None and sliding_window != -1)
        and not SUPPORTS_WINDOWING
        and max_input_tokens > sliding_window
        and is_max_input_within_sliding_window
Maybe I'm misreading this, but shouldn't the exception be raised when the max input is not inside the window size?
Yes, this line was correct before. The only thing that needs to happen is that if max_input_tokens <= sliding_window, we can set sliding_window to -1 and forget about the windowing. The code changes can be much smaller.
Updated to set sliding_window to -1 if max_input_tokens <= sliding_window in the latest commit.
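A minimal sketch of the agreed-upon logic, assuming a hypothetical helper resolve_sliding_window; the actual PR applies this inline where the model configuration is validated.

```python
def resolve_sliding_window(sliding_window, max_input_tokens, supports_windowing):
    # Sketch only (names are illustrative): if the longest possible input fits
    # inside the window, windowing can never kick in, so disable it instead of
    # erroring out on backends without windowing support.
    if sliding_window is None or sliding_window == -1:
        return -1
    if max_input_tokens is not None and max_input_tokens <= sliding_window:
        return -1
    if not supports_windowing:
        raise ValueError("Sliding window attention is not supported by this backend")
    return sliding_window
```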
        self.max_past_tensor = (
            torch.tensor(config.sliding_window, device=weights.device)
            if self.max_past is not None
            torch.tensor(self.max_past, device=weights.device)
Doesn't this try to create a tensor with the None value if windowing is not supported? Same below.
Ahh yes, you are right 😅. Fixed in the latest changes to avoid this altogether.
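A small sketch of the guarded construction discussed here; the helper name and the explicit device argument are illustrative (in the model this lives in __init__ and uses weights.device).

```python
import torch

def make_max_past_tensor(max_past, device):
    # Sketch only: build the tensor solely when a sliding window is configured;
    # torch.tensor(None) would raise otherwise.
    return torch.tensor(max_past, device=device) if max_past is not None else None

print(make_max_past_tensor(4096, "cpu"))  # tensor(4096)
print(make_max_past_tensor(None, "cpu"))  # None
```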
This PR attempts to execute a forward pass to validate that flash attention 2 works on the current hardware. Prior to this change, import flash_attn_2_cuda could load but flash_attn_2_cuda.varlen_fwd would fail with "FlashAttention only supports Ampere GPUs or newer". This change surfaces the runtime error when the library is loaded, which sets V2 to False and avoids using flash attn 2 in the forward pass.
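A rough sketch of the loading pattern the description refers to; the dummy forward call itself is elided because the exact flash_attn_2_cuda.varlen_fwd signature is not shown in this excerpt.

```python
# Sketch only: set V2 based on whether flash attention 2 actually works at
# import time. On pre-Ampere GPUs the import may succeed, so the PR also
# attempts a forward pass with dummy inputs (elided) to trigger the error.
V2 = True
try:
    import flash_attn_2_cuda
    # ... run flash_attn_2_cuda.varlen_fwd on dummy inputs here; on unsupported
    # hardware it raises "FlashAttention only supports Ampere GPUs or newer".
except (ImportError, RuntimeError):
    V2 = False
```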