Improve support for GPUs with capability < 8 #2575

Merged 4 commits from bugfix/pre-cc-8 into main on Sep 27, 2024.
Conversation

danieldk (Member) commented Sep 26, 2024

What does this PR do?

Improve support for GPUs with capability < 8

  • For models that cannot use flashinfer, use flash-attn v1 + paged attention on GPUs with a compute capability lower than 8 (a sketch follows this list).
  • Disable prefix caching when using paged attention.
  • When using flash-attn v1, pass the key/value rather than the cache, since v1 cannot use block tables.
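To make the fallback concrete, here is a minimal sketch of the selection the bullets describe, assuming a (major, minor) capability tuple and illustrative backend names; the real logic lives in the launcher's resolve_attention, quoted further down.

// Hedged sketch of the backend resolution described above; the function
// name and the backend strings are illustrative assumptions, not the
// launcher's exact code.
fn resolve_attention_sketch(compute_capability: Option<(usize, usize)>) -> (String, String) {
    let attention = match compute_capability {
        // Pre-Ampere GPUs (capability < 8) cannot run flashinfer, so fall
        // back to flash-attn v1 combined with paged attention.
        Some((major, _)) if major < 8 => "paged",
        // Capability >= 8 (or unknown): flashinfer can be used.
        _ => "flashinfer",
    };
    // Paged attention does not support prefix caching, so it is disabled
    // whenever the paged fallback is chosen.
    let prefix_caching = if attention == "paged" { "false" } else { "true" };
    (prefix_caching.to_string(), attention.to_string())
}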

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

danieldk mentioned this pull request Sep 27, 2024.
danieldk marked this pull request as ready for review September 27, 2024 09:21.
@@ -65,6 +66,7 @@ fn get_config(
}

fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
    let compute_capability = *gpu::COMPUTE_CAPABILITY;
Collaborator:

Why not just get_cuda_capability()?

Collaborator:

No need to take the (slow) Python GIL if it's not needed.

danieldk (Member, Author) Sep 27, 2024:

I'm not sure I understand. This only evaluates the compute capability once; afterwards it's just a cheap Rust lock. If we called get_cuda_capability directly, we'd call into Python land on every call (it's only called once now).
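The pattern being discussed, sketched with once_cell (an assumption for illustration; the real gpu.rs may be defined differently): the costly lookup that has to enter Python runs exactly once, and every later dereference is cheap.

use once_cell::sync::Lazy;

// Sketch of the caching pattern described above, assuming once_cell; the
// expensive lookup runs once, on first dereference of COMPUTE_CAPABILITY.
pub static COMPUTE_CAPABILITY: Lazy<Option<(usize, usize)>> =
    Lazy::new(get_cuda_capability);

// Stand-in for the real lookup, which calls into Python (and therefore
// takes the GIL) to query the device capability.
fn get_cuda_capability() -> Option<(usize, usize)> {
    None
}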

Collaborator:

Yes, I understand, but since we're only calling it once there's no need. It's really not important, though.

Comment on lines 123 to 131
let prefix_caching = if attention == "paged"
    && prefix_caching.is_none()
    && compute_capability.is_some()
{
    tracing::info!("Disabling prefix caching because it is not supported with 'flashinfer'");
    "false".to_string()
} else {
    prefix_caching.unwrap_or("true".to_string())
};
Collaborator:

Can't we move this upstairs? We're not allowed to change anything if the user sets it directly, so this line must always be unwrap_or.
And in the previous code, every modification must be guarded by a check that the value is not None (there might even be a clearer method for this on Option).

danieldk (Member, Author):

Do we need an unwrap_or? The first branch of the conditional only fires when prefix_caching is None, so the user hasn't set anything. I don't know why I added the compute_capability.is_some() though.

danieldk (Member, Author):

I'll reread the preceding code to see if we check this everywhere.

Collaborator:

I don't know how to make it the most obvious, tbh.

But what I was really aiming for was:

  • User signal takes precedence over everything else.
  • If not user-specified, resolve attention (prefix + attention) as simply as possible:
    • When some conditions are met, we use non-defaults.
    • If nothing else happens, we use the defaults (so the defaults are more easily seen, imho).

I wasn't super happy with the design, but that's the reason for the big ugly block on top. This is where I hide the 'non-default resolutions'.

So the spirit of my comment is just to try to keep the same rationale and have the messy thing in only one place. (With the hidden agenda that this mess should disappear whenever flashinfer supports all the features we hope it does.)

danieldk (Member, Author):
Moved the additional messy stuff into the messy block, so there are now two unwrap_ors again.
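A hedged sketch of the resulting shape (illustrative names and strings, not the exact launcher code): the non-default resolutions sit together in one block that only fills in values the user left unset, and the two unwrap_or calls at the end supply the defaults in one visible place.

fn resolve(
    mut prefix_caching: Option<String>,
    mut attention: Option<String>,
    compute_capability: Option<(usize, usize)>,
) -> (String, String) {
    // The "messy block": non-default resolutions, which only apply when
    // the user did not set the value themselves.
    if attention.is_none() {
        if let Some((major, _)) = compute_capability {
            if major < 8 {
                attention = Some("paged".to_string());
            }
        }
    }
    if prefix_caching.is_none() && attention.as_deref() == Some("paged") {
        prefix_caching = Some("false".to_string());
    }
    // The defaults, visible in one place as two plain unwrap_or calls.
    (
        prefix_caching.unwrap_or("true".to_string()),
        attention.unwrap_or("flashinfer".to_string()),
    )
}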

-    kv_cache[0] if SYSTEM != "ipex" else key,
-    kv_cache[1] if SYSTEM != "ipex" else value,
+    kv_cache[0] if PREFILL_IN_KV_CACHE else key,
+    kv_cache[1] if PREFILL_IN_KV_CACHE else value,
Collaborator:

Another option would be to send both to every implementation and let V2/flashinfer use the cache directly with block tables, and let V1 and the other backends use the raw values.

Let's keep this for now, but maybe food for thought if this logic gets more complex.
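As a minimal sketch of that alternative (in Python, matching the snippet above; the helper name and signature are hypothetical, not the project's actual API):

def select_kv(key, value, kv_cache, supports_block_tables: bool):
    """Hypothetical helper illustrating the alternative above: backends
    that understand block tables (flash-attn v2, flashinfer) read K/V from
    the paged cache, while flash-attn v1 and others get the raw tensors."""
    if supports_block_tables:
        return kv_cache[0], kv_cache[1]
    return key, value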

danieldk (Member, Author):
Yeah, that crossed my mind as well, certainly worth considering in the future.

danieldk requested a review from Narsil September 27, 2024 12:35.
Narsil (Collaborator) reviewed and left a comment:

LGTM

danieldk merged commit 5b6b74e into main Sep 27, 2024 (12 checks passed).
danieldk deleted the bugfix/pre-cc-8 branch September 27, 2024 14:19.