
feat: Holistic persistent kernel template with global scheduler #1026


Closed
yzh119 wants to merge 41 commits

Conversation

@yzh119 (Collaborator) commented Apr 20, 2025

Follow-up of #858 and #967: this PR implements a persistent kernel template that supports sequential execution of multiple kernels (e.g., one wave for prefill attention, one wave for decode attention, and one wave for attention reduction) in a single kernel launch, with a global scheduler for load balancing:

[figure: persistent kernel waves dispatched by the global scheduler]

POD-Attention can be implemented as a different scheduler within this framework.
This PR should also resolve the issue mentioned in #1022.
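
A minimal sketch of the execution model (illustrative only; `WorkTile`, `work_counter`, and `run_tile` are hypothetical placeholders, not the actual template):

```cpp
// Sketch: each CTA stays resident and pulls work tiles from a global queue.
struct WorkTile {
  int kind;      // 0 = prefill, 1 = decode, 2 = attention reduction
  int tile_idx;  // which tile within the wave
};

__device__ int work_counter = 0;  // global scheduler state

__device__ void run_tile(const WorkTile& tile) {
  // Placeholder for the per-tile body (prefill / decode / reduction).
}

__global__ void persistent_kernel(const WorkTile* tiles, int num_tiles) {
  __shared__ int tile_id;
  for (;;) {
    // One thread per CTA pops the next tile from the global counter, so
    // faster CTAs naturally pick up more work (dynamic load balancing).
    if (threadIdx.x == 0) tile_id = atomicAdd(&work_counter, 1);
    __syncthreads();
    if (tile_id >= num_tiles) break;  // all waves drained
    run_tile(tiles[tile_id]);
    __syncthreads();  // reuse shared memory safely before the next tile
  }
}
```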

Co-authored-by: Yilong Zhao [email protected]

@Edenzzzz (Contributor)
Thanks for your attention. I guess the issue with POD is that when the prefill/decode sequences have varying lengths, there's significant wave quantization between blocks? I might try to adapt it to your template.
Also, I wonder if you have any hints on debugging the illegal memory access issue? Sometimes I also see `operation not supported on global/shared address space`.

@AKKamath (Contributor)
Sorry, I've been a bit busy recently. I'll try taking a look at your illegal access issue. It's what's mentioned in #967, right?

@Edenzzzz (Contributor)
@AKKamath Hi, it's #1022, and I will try to upstream a cleaner reproduction.

Apart from the illegal access, I'm not sure if the slowdown with irregular decode & prefill seqlens is due to a global sync in each wave; I don't see any global sync points blocking the next CTA apart from the final merge states.
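
(By a global sync point I mean something like a grid-wide barrier between waves; a rough cooperative-groups sketch, not code from this PR:)

```cpp
#include <cooperative_groups.h>
namespace cg = cooperative_groups;

// Rough sketch: a grid-wide sync between two waves (requires a cooperative
// launch via cudaLaunchCooperativeKernel); not code from this PR.
__global__ void two_wave_kernel() {
  cg::grid_group grid = cg::this_grid();
  // ... wave 1, e.g. prefill attention ...
  grid.sync();  // every CTA waits here before any CTA starts wave 2
  // ... wave 2, e.g. decode attention ...
}
```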

Comment on lines +1149 to +1152
```cpp
auto [cluster_idx, accum_cost] = cluster_cost_heap.pop();
int actual_len = std::min(remaining_len, kv_len_limit);
cluster_cost_heap.insert(
    {cluster_idx, accum_cost + cost_function(cluster_tile_q, actual_len)});
```
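
For context, the greedy balancing around these lines works roughly as follows (a simplified host-side sketch; the heap type, cost model, and variable names are illustrative, not the actual scheduler code):

```cpp
#include <algorithm>
#include <cstdint>
#include <functional>
#include <queue>
#include <utility>
#include <vector>

// Simplified sketch of greedy least-loaded assignment: each KV chunk of a
// request goes to the cluster with the smallest accumulated cost so far, so
// long requests are split across clusters and work stays balanced.
int main() {
  const int num_clusters = 4, cluster_tile_q = 16, kv_len_limit = 512;
  std::vector<int> request_kv_lens = {2048, 64, 768, 128};  // toy inputs

  auto cost_function = [](int tile_q, int kv_len) {
    return static_cast<int64_t>(tile_q) * kv_len;  // toy cost model
  };

  // Min-heap keyed on accumulated cost: (accum_cost, cluster_idx).
  using Entry = std::pair<int64_t, int>;
  std::priority_queue<Entry, std::vector<Entry>, std::greater<Entry>> heap;
  for (int i = 0; i < num_clusters; ++i) heap.push({0, i});

  for (int remaining_len : request_kv_lens) {
    while (remaining_len > 0) {
      auto [accum_cost, cluster_idx] = heap.top();  // least-loaded cluster
      heap.pop();
      int actual_len = std::min(remaining_len, kv_len_limit);
      heap.push({accum_cost + cost_function(cluster_tile_q, actual_len),
                 cluster_idx});
      remaining_len -= actual_len;
      // ... record that this KV chunk is assigned to cluster_idx ...
    }
  }
  return 0;
}
```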
Contributor

It seems that with load balancing across tile sizes, the persistent kernel achieves the same goal as POD Attention (and perhaps an even more balanced workload, because POD Attention currently calls `plan()` to balance the decode workload?).
It should also have lower CTA launch and quantization overheads?

Contributor

Though if we can carefully schedule POD Attention to run two CTAs per SM concurrently, each with a different tile size, we could increase tensor core utilization.

Collaborator Author

One of the goals of POD-Attention is to overlap compute-bound (prefill) and IO-bound (decode) work through concurrent execution of the two kinds of workload within an SM (two CTAs per SM). We didn't explore that in this PR, but I suppose it's feasible with a carefully designed scheduler.
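
A rough sketch of what such a scheduler could look like (purely illustrative, not part of this PR): launch two CTAs per SM and let them drain different queues, so a compute-bound prefill CTA and an IO-bound decode CTA can overlap on the same SM:

```cpp
// Purely illustrative: even blocks drain the prefill queue, odd blocks drain
// the decode queue, so (with two resident CTAs per SM) the two kinds of work
// can overlap on one SM. Names and the parity trick are hypothetical.
__device__ int prefill_counter = 0;
__device__ int decode_counter = 0;

__global__ void pod_style_persistent_kernel(int num_prefill_tiles,
                                            int num_decode_tiles) {
  const bool is_prefill = (blockIdx.x % 2 == 0);
  int* counter = is_prefill ? &prefill_counter : &decode_counter;
  const int num_tiles = is_prefill ? num_prefill_tiles : num_decode_tiles;

  __shared__ int tile_id;
  for (;;) {
    if (threadIdx.x == 0) tile_id = atomicAdd(counter, 1);
    __syncthreads();
    if (tile_id >= num_tiles) break;
    // ... run a prefill tile (large CTA_TILE_Q) or a decode tile (small
    //     CTA_TILE_Q) depending on is_prefill ...
    __syncthreads();  // make shared memory reusable before the next tile
  }
}
```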

```cpp
BlockPersistentRunner1::Run(params_1, &smem_storage_1);
PROFILER_EVENT_END(profiler_closure, PersistentProfileEventType::kRunner1);

__syncthreads();
```
Contributor

Could this sync be removed?

Collaborator Author

yes!

@yzh119 (Collaborator, Author) commented Jun 12, 2025

Moved development to #1137

@yzh119 closed this Jun 12, 2025
yzh119 added a commit that referenced this pull request Jun 12, 2025

## 📌 Description
Follow-up of #858, #967, and #1026, this PR aims to provide an efficient and
unified API for processing prefill and decode requests within a single kernel
launch. Key features include:
1. Single CUDA graph capture for all batch sizes and sequence lengths.
Prior to this PR, the FA2 template was implemented as a non-persistent
kernel, which dispatches `padded_batch_sizes` CTAs and relies on static
information (ref:
https://github.com/flashinfer-ai/flashinfer/blob/f484fd3c7f09a1d0afb75d779872b9762a35e445/include/flashinfer/attention/scheduler.cuh#L527).
This necessitates a specialized CUDA graph for each combination of seqlens
and batch size to maximize throughput. Furthermore, prefill and decode are
executed by different kernel launches, which further multiplies the number
of CUDA graphs. This PR implements a persistent-style kernel, which enables
a single CUDA graph to capture the work for all seqlens and batch sizes.
2. Dynamic specialization for prefill and decode. Implemented as a
persistent kernel, prefill and decode requests are dynamically executed
by an efficient kernel template with suitable hyperparameters. For
example, decode requests with `qo_len=1` are processed with `CTA_TILE_Q=16`,
while prefill requests with `qo_len>=128` are processed with `CTA_TILE_Q=128`
(see the sketch below).
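
A simplified sketch of the dispatch rule described above (the thresholds, names, and tile struct are illustrative, not the exact hyperparameters used by the kernel template):

```cpp
#include <vector>

// Illustrative only: pick a query tile size that roughly matches each
// request's qo_len, then emit tiles for the persistent kernel's scheduler.
struct QTile {
  int request_idx;
  int q_start;
  int cta_tile_q;
};

int select_cta_tile_q(int qo_len) {
  if (qo_len <= 16) return 16;   // decode-like requests
  if (qo_len <= 64) return 64;   // short prefill
  return 128;                    // long prefill
}

std::vector<QTile> plan_q_tiles(const std::vector<int>& qo_lens) {
  std::vector<QTile> tiles;
  for (int r = 0; r < static_cast<int>(qo_lens.size()); ++r) {
    int tile_q = select_cta_tile_q(qo_lens[r]);
    for (int q = 0; q < qo_lens[r]; q += tile_q) {
      tiles.push_back({r, q, tile_q});
    }
  }
  return tiles;  // consumed by the persistent kernel's global scheduler
}
```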

## Perf Benchmarks
The benchmark script is at `benchmarks/bench_batch_attention.py` and was
tested with Qwen-2.5-7B configurations on a single H200. Visualization:
<img width="594" alt="image"
src="https://github.com/user-attachments/assets/735aca14-387d-4013-b3f4-e199b6cff5f3"
/>
1. 30% bandwidth boost in hybrid scenarios.
2. Slightly worse performance on pure workloads, which may be caused by the
reduction overhead.

## Unit Tests
Unit tests are located at `tests/bench_batch_attention.py`.
<img width="1527" alt="image"
src="https://github.com/user-attachments/assets/fff06c6d-c121-497c-9f62-039653149a4d"
/>

## Future work
1. Add a profiler to analyze perf bottlenecks.
2. Optimize the reduction kernel schedule.

## 🔍 Related Issues
#1022

Advised by @yzh119. CC @AKKamath @Edenzzzz 

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes


Co-authored-by: yzh119 <[email protected]>
Co-authored-by: happierpig <[email protected]>