-
Notifications
You must be signed in to change notification settings - Fork 327
feat: Holistic persistent kernel template with global scheduler #1026
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Update attention.py
Update bench_mixed_attention.py
Add tool kernel for sspec
Thanks for your attention. I guess the issue with POD is when each prefill/decode seq has varying lengths, there's significant wave quantization between blocks? I might try to adapt it to your template. |
Sorry, I've been a bit busy recently. I'll try taking a look at your illegal access issue. It's what's mentioned in #967 right? |
@AKKamath Hi, it's #1022, and I will try to upstream a cleaner reproduction. Apart from the illegal access, I'm not sure if the slowdown with irregular decode & prefill seqlens is due to global sync in each wave--I don't any global sync points blocking the next CTA apart from the final merge states |
Profiler for persistent kernels
auto [cluster_idx, accum_cost] = cluster_cost_heap.pop(); | ||
int actual_len = std::min(remaining_len, kv_len_limit); | ||
cluster_cost_heap.insert( | ||
{cluster_idx, accum_cost + cost_function(cluster_tile_q, actual_len)}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems that with load balancing across tile sizes, persistent kernel achieves the same goal as POD Attention (and even more balanced workload because POD Attention currently calls plan()
to balance workload for decode?)
It should also have lower CTA launch and quantization overheads?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Though if we can carefully schedule POD Attn to run two CTAs per SM concurrently, each having a different tile size, we could increase tensor core util
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One of the goal of POD-Attention is to overlap compute-bound (prefill) and IO-bound (decode) by concurrent execution of two kind of workload within a SM (two CTA per SM), which we didn't explore in this PR, but I suppose it's feasible by carefully design scheduler.
BlockPersistentRunner1::Run(params_1, &smem_storage_1); | ||
PROFILER_EVENT_END(profiler_closure, PersistentProfileEventType::kRunner1); | ||
|
||
__syncthreads(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could this sync be removed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes!
Moved development to #1137 |
<!-- .github/pull_request_template.md --> ## 📌 Description Follow up of #858, #967, and #1026, this PR aims to provide an efficient and unified API for processing prefill and decode requests within a single kernel launch. Key features include: 1. Single CUDA graph capture for all batch sizes and sequence lengths. Prior to this PR, FA2 template is implemented with a non-persistent kernel way, which dispatches `padded_batch_sizes` CTA and uses static information (ref: https://github.com/flashinfer-ai/flashinfer/blob/f484fd3c7f09a1d0afb75d779872b9762a35e445/include/flashinfer/attention/scheduler.cuh#L527). This necessitates a specialized CUDA graph for each batch with different seqlens and batch sizes, to maximize throughput. Furthermore, prefill and decode are executed by different kernel launches, increasing the number of CUDA graphs by combination. This PR implements a persistent-style kernel, which enables a single CUDA graph to capture work for all seqlens and batch sizes. 2. Dynamic specialization for prefill and decode. Implemented as a persistent kernel, prefill and decode requests are dynamically executed by an efficient kernel template with suitable hyperparameters. For example, decode requests with `qo_len=1` are processed by `CTA_TILE_Q=16` while prefill requests with `qo_len>=128` are processed by `CTA_TILE_Q=128`. ## Perf Benchmarks: The benchmark script is at `benchmarks/bench_batch_attention.py` and was tested with Qwen-2.5-7B configurations and a single H200. Visualization: <img width="594" alt="image" src="https://github.com/user-attachments/assets/735aca14-387d-4013-b3f4-e199b6cff5f3" /> 1. 30% bandwidth boost in hybrid scenarios 2. slightly worse perf at pure workloads, which may be caused by the reduction overhead ## Unit Tests: Unit tests can be located at `tests/bench_batch_attention.py`. <img width="1527" alt="image" src="https://github.com/user-attachments/assets/fff06c6d-c121-497c-9f62-039653149a4d" /> ## Future works: 1. Add profiler to analyze perf bottleneck 4. Optimize the reduction kernel schedule <!-- What does this PR do? Briefly describe the changes and why they’re needed. --> ## 🔍 Related Issues #1022 Advised by @yzh119. CC @AKKamath @Edenzzzz <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> Co-authored-by: yzh119 <[email protected]> Co-authored-by: happierpig <[email protected]>
Follow up of #858 , #967 , this PR implements the persistent kernel template that supports sequential execution of multiple kernels (e.g. one wave for prefill attention, one wave for decode attention and one wave for attention reduction) in a single kernel, with a globla scheduler for load-balancing:
POD-Attention can be implemented as different scheduler implementation within this framework.
This PR should also resolve the issue mentioned in #1022
Co-authored-by: Yilong Zhao [email protected]