Improve support for GPUs with capability < 8 #2575
Conversation
Force-pushed from 491c026 to 5bd79b8.
- For models that cannot use flashinfer, use flash-attn v1 + paged attention when the compute capability is older than 8.
- Disable prefix caching when using paged attention.
- When using flash-attn v1, pass the key/value rather than the cache, since v1 cannot use block tables.
Force-pushed from 5bd79b8 to 8c0f931.
@@ -65,6 +66,7 @@ fn get_config(
}

fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
    let compute_capability = *gpu::COMPUTE_CAPABILITY;
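For context, the cached static being discussed presumably looks something like the sketch below (the LazyLock, the get_cuda_capability body, and the (usize, usize) return type are assumptions for illustration; the real lookup calls into Python):

```rust
use std::sync::LazyLock;

// Sketch: the compute capability is resolved at most once and then cached.
// Dereferencing the static afterwards is only a cheap initialization check
// in Rust, with no repeated Python calls (and no GIL acquisition).
pub static COMPUTE_CAPABILITY: LazyLock<Option<(usize, usize)>> =
    LazyLock::new(get_cuda_capability);

// Placeholder for the real lookup, which calls into Python land (and thus
// has to take the GIL); that cost is what the cached static avoids paying
// at every call site.
fn get_cuda_capability() -> Option<(usize, usize)> {
    Some((7, 5)) // e.g. a T4; the real code queries the CUDA driver
}

fn main() {
    if let Some((major, minor)) = *COMPUTE_CAPABILITY {
        println!("compute capability: {major}.{minor}");
    }
}
```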
Why not just get_cuda_capability()?
No need to take the (slow) GIL if it's not needed.
I am not sure I understand. This will only evaluate the compute capability once; it's just a cheap Rust lock. If we called get_cuda_capability directly, we'd call into Python land every time it's called (it's only called once now).
Yes, yes, I understand, but since we're only calling it once there's no need. It's really not important, though.
launcher/src/main.rs (Outdated)
let prefix_caching = if attention == "paged"
    && prefix_caching.is_none()
    && compute_capability.is_some()
{
    tracing::info!("Disabling prefix caching because it is not supported with 'flashinfer'");
    "false".to_string()
} else {
    prefix_caching.unwrap_or("true".to_string())
};
Can't we move this upstairs? We're not allowed to change anything if the user sets it directly, so this line must always be unwrap_or. And in the previous code, every modification must be protected by a check that the value is not None (there might even be a clearer method for it on Option).
Do we need the unwrap_or? The first branch of the conditional only fires when prefix_caching is None, so in the else branch the user has set something. I don't know why I added the compute_capability.is_some() though.
I'll reread the preceding code to see if we check this everywhere.
I don't know how to make it the most obvious, tbh. But what I was really aiming for was:
- User signal takes precedence over everything else.
- If not user-specified, resolve attention (prefix + attention) as simply as possible:
  - When some conditions are met, we use non-defaults.
  - If nothing else applies, we use the defaults (so the defaults are more easily seen, imho).

I wasn't super happy with the design, but that's the reason for the big ugly block on top: it is where I hide the 'non-default resolutions'. So the spirit of my comment is just to try to keep the same rationale and have the messy thing in only one place (with the hidden agenda that this mess should disappear whenever flashinfer supports all the features we hope it does).
Moved the additional messy stuff into the messy block, so there are now two unwrap_ors again.
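For illustration, here is a minimal sketch of that layout (the signature, the default strings, and the capability threshold are assumptions, not the actual launcher code): the messy block computes non-default candidates in one place, and user-set flags win via Option::or before the final unwrap_or defaults.

```rust
// Sketch of the intent described above: one messy block for non-default
// resolutions, user-provided values taking precedence, defaults last.
fn resolve_attention(
    user_attention: Option<String>,
    user_prefix_caching: Option<String>,
    compute_capability: Option<(usize, usize)>,
) -> (String, String) {
    // Non-default resolutions all live in this one block.
    let (attention, prefix_caching) = match compute_capability {
        // flashinfer needs compute capability >= 8; older GPUs fall back to
        // paged attention, which does not support prefix caching.
        Some((major, _)) if major < 8 => {
            (Some("paged".to_string()), Some("false".to_string()))
        }
        _ => (None, None),
    };

    (
        // User signal takes precedence; defaults only apply at the very end.
        user_attention
            .or(attention)
            .unwrap_or_else(|| "flashinfer".to_string()),
        user_prefix_caching
            .or(prefix_caching)
            .unwrap_or_else(|| "true".to_string()),
    )
}

fn main() {
    // User set nothing; an sm_75 GPU resolves to paged attention, no caching.
    let (attention, caching) = resolve_attention(None, None, Some((7, 5)));
    assert_eq!((attention.as_str(), caching.as_str()), ("paged", "false"));
}
```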
-    kv_cache[0] if SYSTEM != "ipex" else key,
-    kv_cache[1] if SYSTEM != "ipex" else value,
+    kv_cache[0] if PREFILL_IN_KV_CACHE else key,
+    kv_cache[1] if PREFILL_IN_KV_CACHE else value,
Another option would be to send both to every implementation: let V2/flashinfer use the cache directly with block tables, and let V1 and other backends use the raw values. Let's keep this for now, but it's food for thought if this logic gets more complex.
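As a rough sketch of that alternative (all names and signatures here are invented for illustration, not the project's actual API), every backend would receive both the raw key/value tensors and the paged cache, and pick whichever representation it supports:

```python
# Hypothetical stubs; only the dispatch shape matters for this sketch.
def paged_attention(query, kv_cache, block_tables):
    ...  # v2/flashinfer path: read prefill K/V from the cache via block tables


def contiguous_attention(query, key, value):
    ...  # v1 path: no block-table support, use the raw contiguous tensors


def attention(query, key, value, kv_cache, block_tables, *, supports_block_tables):
    """Send both representations to every implementation and let each choose."""
    if supports_block_tables:
        # flash-attn v2 / flashinfer can consume the paged KV cache directly.
        return paged_attention(query, kv_cache, block_tables)
    # flash-attn v1 and other backends fall back to the raw key/value.
    return contiguous_attention(query, key, value)
```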
Yeah, that crossed my mind as well, certainly worth considering in the future.
LGTM
What does this PR do?
Improve support for GPUs with capability < 8
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.