
fix: attempt forward on flash attn2 to check hardware support #2335


Merged

drbh merged 8 commits into main from validate-flash-attn2-on-arch on Aug 5, 2024

Conversation

drbh (Collaborator) commented on Jul 30, 2024

This PR attempts to execute a forward pass to validate that flash attention 2 works on the current hardware. Prior to this change, the import of flash_attn_2_cuda could succeed, but flash_attn_2_cuda.varlen_fwd would then fail with "FlashAttention only supports Ampere GPUs or newer".

This change surfaces the runtime error when the library is loaded, which sets V2 to False and avoids using flash attn 2 in the forward pass.
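
As a rough illustration of the probe idea (not the PR's actual code), here is a minimal sketch that attempts a tiny dummy forward pass at import time and falls back when the kernel rejects the hardware. It uses the high-level flash_attn.flash_attn_func wrapper instead of the raw flash_attn_2_cuda.varlen_fwd binding, and the V2 flag name follows the description above:

```python
import torch

V2 = False
try:
    from flash_attn import flash_attn_func

    # Tiny dummy (batch, seqlen, nheads, headdim) fp16 tensors on the current GPU.
    q = torch.zeros(1, 1, 1, 32, dtype=torch.float16, device="cuda")
    flash_attn_func(q, q, q)
    V2 = True
except (ImportError, RuntimeError):
    # Pre-Ampere cards fail here with
    # "FlashAttention only supports Ampere GPUs or newer".
    V2 = False
```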

@@ -173,6 +174,41 @@ def paged_attention(
try:
    import flash_attn_2_cuda

    # try forwarding to see if it works with all dummy inputs
Member

Wouldn't it be easier to require the minimum needed CUDA capability? We could even skip the import altogether if the hardware doesn't have the right capability.

Narsil (Collaborator) commented on Jul 31, 2024

I agree, it would be simpler. We already have is_sm75, which was probably added for that reason.

drbh (Collaborator, author)

Yeah, I agree that's a better method. Updated in the latest commits to use is_ampere_or_newer = major >= 8 and minor >= 0 and to avoid attempting the forward pass.
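
For reference, a sketch of the capability check discussed here, assuming torch.cuda.get_device_capability(); the flag names mirror the comments above and are illustrative:

```python
import torch

major, minor = torch.cuda.get_device_capability()

is_sm75 = major == 7 and minor == 5             # Turing, e.g. a T4
is_ampere_or_newer = major >= 8 and minor >= 0  # Ampere (SM 8.x) and newer

# Decide up front instead of probing the kernel with a forward pass.
V2 = is_ampere_or_newer
```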

@@ -254,9 +290,11 @@ def attention(
     softcap=None,
 ):
     if window_size_left != -1:
-        raise NotImplementedError(
-            "window_size_left is only available with flash attn v2"
+        warnings.warn(
Collaborator

Can you also not change that.

It's important to hard crash here and not silently ignore it.

drbh (Collaborator, author)

Updated in the latest commit to still throw when an invalid value is passed.
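
A small illustrative sketch of the requested behaviour: keep the hard failure on the v1 path instead of downgrading it to a warning. The signature is trimmed down and the kernel call is elided:

```python
def attention(q, k, v, window_size_left=-1, softcap=None):
    if window_size_left != -1:
        # Hard crash rather than silently ignoring an unsupported option.
        raise NotImplementedError(
            "window_size_left is only available with flash attn v2"
        )
    ...  # flash attention v1 kernel call goes here
```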

     if (
         (sliding_window is not None and sliding_window != -1)
         and not SUPPORTS_WINDOWING
-        and max_input_tokens > sliding_window
+        and is_max_input_within_sliding_window
Member

Maybe I'm misreading this, but shouldn't the exception be raised when the max input is not inside the window size?

Collaborator

Yes, this line was correct before.

The only thing that needs to happen is that if max_input_tokens <= sliding_window, we can set sliding_window to -1 and forget about the windowing.

The code changes should be much smaller.

drbh (Collaborator, author)

Updated in the latest commit to set sliding_window to -1 when max_input_tokens <= sliding_window.
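
A sketch of the agreed simplification, with a hypothetical helper name; the exception type and message are placeholders:

```python
def resolve_sliding_window(sliding_window, max_input_tokens, supports_windowing):
    # No window configured: nothing to do.
    if sliding_window is None or sliding_window == -1:
        return -1
    # Every possible input already fits inside the window, so windowing
    # never triggers and can simply be disabled.
    if max_input_tokens <= sliding_window:
        return -1
    # A real sliding window is needed; fail hard if the backend can't do it.
    if not supports_windowing:
        raise ValueError("Sliding window attention is not supported by this backend")
    return sliding_window
```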

self.max_past_tensor = (
    torch.tensor(config.sliding_window, device=weights.device)
    if self.max_past is not None
    torch.tensor(self.max_past, device=weights.device)
Member

Doesn't this try to create a tensor with the None value if windowing is not supported? Same below.

drbh (Collaborator, author)

Ahh yes, you are right 😅. Fixed in the latest changes to avoid this altogether.
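
A hedged sketch of the fix with a hypothetical helper name; the point is simply to never call torch.tensor(None, ...) when no sliding window is configured:

```python
import torch


def build_max_past_tensor(sliding_window, device):
    # sliding_window can be None when the model has no sliding window;
    # only materialize the tensor when windowing is actually in use.
    if sliding_window is None or sliding_window == -1:
        return None
    return torch.tensor(sliding_window, device=device)
```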

drbh merged commit 215ed3a into main on Aug 5, 2024
9 checks passed
drbh deleted the validate-flash-attn2-on-arch branch on August 5, 2024 13:11