# [Inductor] Support tiling reduction dimensions #137243
@blaine-rister has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

Just to be safe, this draft PR tests the CI with tiled reductions enabled by default: #144008

Tiled reductions by default turned out to break a few things. Given the weight of this PR, I think it makes sense to merge this as is and handle the missing reduction ranges in a follow-up.

@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
# Issue

This PR cleans up an edge case that wasn't handled by #137243. The existing tiling code assumes that `node.get_ranges()` is a reliable source of pointwise and reduction numels. This is true for pointwise kernels, but the situation is more complicated with reductions. Since reductions change the number of elements in a tensor, not all ops within a reduction kernel will have the same number of iterations. For example, `var_mean` fuses pointwise division with the output of reduction sum, and the division lacks the corresponding reduction ranges.

# Fix

Instead of getting numels from `node.get_ranges()`, explicitly pass the global pointwise and reduction numels to the relevant tiling functions. In `SIMDKernel.complete_partial_tiling`, we solve for the missing numel by dividing the global numel by the partial tiling's numel. This ensures all tilings have the correct global numel.

Also, in `SIMDKernel.is_compatible`, add the global reduction numel to node ranges that are missing it. For example, `{"x": 8, "r0_": 8}` is compatible with a node of ranges `([8], [])` when we have `reduction_numel=8`.

Finally, this PR generalizes some of the existing codegen to handle multiple reduction dims. We already had code to ignore reduction splits for pointwise kernels, but it only worked for 1D reductions. Now it can handle ND.

# Test plan

This PR parametrizes the existing CI test for `var_mean` to also run with tiled reductions. It also adds a new test checking that `var_mean` generates 2D tilings (with tiled reductions enabled). These new tests would fail on the current main branch.

Pull Request resolved: #144041
Approved by: https://github.com/jansel
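The numel-solving step described above can be sketched in plain Python. This is a hypothetical stand-in, not Inductor's actual `SIMDKernel.complete_partial_tiling` implementation; the solved dimension name `"r0_"` is assumed for illustration.

```python
from math import prod

def complete_partial_tiling(partial_tiling, pointwise_numel, reduction_numel):
    """Sketch: fill in the missing dimension of a partial tiling.

    A partial tiling like {"x": 8} with global numels (8, 8) lacks a
    reduction dimension; we solve for it by dividing the global numel by
    the product of the numels already present in the tiling.
    """
    tiling = dict(partial_tiling)
    global_numel = pointwise_numel * reduction_numel
    known = prod(tiling.values())
    assert global_numel % known == 0, "partial tiling must divide the global numel"
    tiling["r0_"] = global_numel // known  # hypothetical name for the solved dim
    return tiling
```

For example, `complete_partial_tiling({"x": 8}, 8, 8)` yields `{"x": 8, "r0_": 8}`, matching the compatibility example above.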
# Issue

#137243 introduced a feature where the ND tiling algorithm analyzes memory dependencies. It iterates over all `Dep`'s of the kernel. However, the analysis is only applicable to `MemoryDep` instances, which are a subclass of `Dep`. In particular, it doesn't work for `StarDep`'s, for the reasons described here: https://github.com/pytorch/pytorch/blob/main/torch/_inductor/codegen/simd.py#L1653

# Fix

This PR changes the algorithm to only iterate over `MemoryDep` instances.

# Testing

Parameterized an existing test for `torch.bucketize` to also run with ND tiling. This test emits a node with `StarDep`'s. Without this PR, the compiler would crash on this test case.

Pull Request resolved: #144497
Approved by: https://github.com/eellison
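The fix amounts to an `isinstance` filter over the kernel's dependencies. A minimal sketch with stand-in classes (not Inductor's real `Dep` hierarchy):

```python
class Dep:
    """Base class for dependencies (stand-in for Inductor's Dep)."""

class MemoryDep(Dep):
    """Carries an indexing expression the tiling analysis can inspect."""
    def __init__(self, index_expr):
        self.index_expr = index_expr

class StarDep(Dep):
    """No indexing expression; the tiling analysis doesn't apply."""

def deps_for_tiling_analysis(deps):
    # Only MemoryDep instances are usable; skip StarDep and any
    # other Dep subclass instead of crashing on them.
    return [d for d in deps if isinstance(d, MemoryDep)]

deps = [MemoryDep("x0 + 8*x1"), StarDep(), MemoryDep("x1")]
assert len(deps_for_tiling_analysis(deps)) == 2
```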
Fixes #134277 and #142317.
Sub-PRs containing refactors from this one:
These refactor PRs should land before the main one.
# Feature
Note: to minimize risk, multi-dimensional reductions are gated by the flag `config.triton.tile_reductions`, which defaults to `False`.

Instead of having a single reduction dimension called `"r"`, we can now support 2D reductions with `"r0_"` and `"r1_"` dimensions. 2D reductions generate two nested loops, with different block pointer advancements in each loop body. Most of the implementation is generic to ND reductions, but for now the tiling algorithm sets a hard limit at 2D.

Here's an example of a 2D persistent reduction kernel:
There are a few main differences between this kernel and what Inductor would generate without this PR:
- Instead of a single `r`/`RBLOCK` reduction dimension, we have two: `r0_`/`R0_BLOCK` and `r1_`/`R1_BLOCK`.
- The kernel still defines the 1D reduction variables (`rindex`, `rnumel`, `RBLOCK`, and `roffset`). These collapse N-D reduction sizes and indices into 1D. This simplifies the codegen for reductions, which sometimes want to access linear indices instead of N-dimensional ones. Doing things this way allows us to generate N-D loads and stores, but access this data as if it were 1D, minimizing the blast radius of this PR. Although this makes the code more verbose, it shouldn't have a perf impact because the Triton compiler eliminates dead code.
- The kernel computes `tmp4 = tl.reshape(tmp3, [XBLOCK, RBLOCK])` before performing the actual reduction. This reshapes the N reduction dimensions into 1D, which allows us to reduce over all N dimensions at once, simplifying the codegen and letting the Triton compiler decide the order of processing under the hood.

Here's an example of a looped reduction:
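As with the persistent kernel, the generated looped kernel isn't reproduced here. This NumPy sketch emulates the two nested reduction loops, with `r0_base`/`r1_base` tile offsets and a running accumulator (in the real kernel, each loop body advances its block pointer instead of slicing):

```python
import numpy as np

def looped_2d_sum(x, R0_BLOCK, R1_BLOCK):
    """Emulate a looped 2D reduction: two nested loops over reduction tiles."""
    xnumel, r0_numel, r1_numel = x.shape
    acc = np.zeros(xnumel, dtype=x.dtype)
    # Outer loop over the r0_ dimension, inner loop over r1_; each level
    # steps by its own block size, like the generated nested loops.
    for r0_base in range(0, r0_numel, R0_BLOCK):
        for r1_base in range(0, r1_numel, R1_BLOCK):
            tile = x[:, r0_base:r0_base + R0_BLOCK, r1_base:r1_base + R1_BLOCK]
            # Flatten both reduction dims of the tile and accumulate.
            acc += tile.reshape(xnumel, -1).sum(axis=1)
    return acc

x = np.arange(2 * 7 * 9, dtype=np.float64).reshape(2, 7, 9)
assert np.allclose(looped_2d_sum(x, R0_BLOCK=4, R1_BLOCK=4), x.sum(axis=(1, 2)))
```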
In addition to the aforementioned changes to the persistent reduction, multidimensional looped reductions have a few more lines of code:
- There are separate base offsets for each reduction dimension, `r0_base` and `r1_base`. For compatibility with existing codegen, these are collapsed to the 1D variant `rbase`.
- Each loop level has a `tl.advance` line which not only increments the pointer in its own dimension, but also undoes the cumulative increments of the previous loop level. This is equivalent to the usual practice in nested loops of starting with a fresh iteration variable at each level. Implementing this required refactoring the way we generate pointer advancements into a new `self.pointer_advancements` field of the kernel, which categorizes advancements by dimension.

The biggest difficulty in implementing this feature was that we represented tiling with a tuple like `(5, 2)`. In the existing codebase, the compiler can infer that the reduction dimension of `(5, 2)` is `2`, since reductions are always the last dimension. This became cumbersome now that we have to support multiple reduction dimensions, so I refactored tiling into a dict like `{"x": 5, "r0_": 2, "r1_": 4}`. This required quite a few code changes, but I don't think it makes the underlying logic much more complex. This will also make it easier to eventually support simultaneous pointwise and reduction tiling, like `{"x": 5, "y": 5, "r0_": 2, "r1_": 4}`.
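A tiny sketch of the old versus new tiling representation (stand-in code, not Inductor's actual data structures); the `"r"` prefix convention follows the dimension names above:

```python
from math import prod

# Old: a tuple, where the reduction dim is implicitly the last element.
old_tiling = (5, 2)
reduction_numel_old = old_tiling[-1]  # inferable only because it's always last

# New: a dict keyed by dimension name, so multiple reduction dims are explicit.
new_tiling = {"x": 5, "r0_": 2, "r1_": 4}
pointwise_numel = prod(v for k, v in new_tiling.items() if not k.startswith("r"))
reduction_numel = prod(v for k, v in new_tiling.items() if k.startswith("r"))
assert (pointwise_numel, reduction_numel) == (5, 8)
```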
(This is not supported today, but we might want to do it eventually.)

The existing tiling algorithm generalized naturally to support reductions. For pointwise kernels, we tile the pointwise dimensions (`"x"`, `"y"`) as is. For reduction kernels, we never tile the `"x"` dimension, and only tile the reduction dimensions (`"r0_"`, `"r1_"`). Thus we only ever tile pointwise OR reduction dimensions, but not both. In principle it seems possible to support both, but it would likely require changes to the kernel fusion and autotuning logic. I thought it best to keep this PR as minimal as possible since it already touched a lot of different files.

Unfortunately, these changes weren't enough to get block pointers in some seemingly simple test cases. In some tests for `argmax` and `var_mean`, we already collapse reduction dimensions into 1D and generate modular indexing expressions, prior to tiling. So it's not trivial to figure out how to expand the collapsed reduction dimension back to a shape that would simplify the indexing.

To address these cases, this PR adds a new feature to the `config.prefer_nd_tiling` option, which analyzes reads and writes in the kernel, using the same mod-div pattern matching logic that generates block pointers later on. By matching this pattern, we can solve for the tiling splits which would simplify the indexing expression, and then use that tiling to eliminate the modular indexing and emit a block pointer. This tiling mode is still off by default, but it's important for certain applications where we need to get as many block pointers as possible.
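The mod-div idea can be illustrated in isolation: an indexing expression built from `i // n` and `i % n` over a collapsed 1D range is exactly a row-major walk over a 2D tiling with inner size `n`, so splitting the dimension by `n` makes the indexing affine. This is a hypothetical helper, not Inductor's actual pattern matcher:

```python
def split_index(i, n):
    """Rewrite a 1D index over a collapsed range as 2D coordinates.

    An indexing expression like (i // n) * stride + (i % n) suggests
    that splitting the collapsed dimension into (numel // n, n) removes
    the modular indexing, which is what block-pointer emission needs.
    """
    return i // n, i % n

# A collapsed 1D reduction over 6*4 = 24 elements...
coords = [split_index(i, 4) for i in range(24)]
# ...is exactly a row-major walk over a (6, 4) tiling.
expected = [(q, r) for q in range(6) for r in range(4)]
assert coords == expected
```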
# Test plan

This touches pretty much anything that uses the Triton and Halide backends, so the existing CI provides good coverage. However, 2D reductions are gated behind a few feature flags like `config.prefer_nd_tiling` and `config.tile_reductions`, so this really only checks that the PR doesn't break 1D reductions.

In addition to existing CI tests, this PR also adds some new tests that specifically stress 2D reductions:
- `test_2d_reduction_odd_shapes`: tests 2D reductions with a variety of ops and sizes. This covers the typical persistent and looped reductions.
- `test_2d_reduce_no_x_dim`: tests 2D reductions with no x dimension.
- `test_2d_welford_reduction`: tests 2D Welford reductions with block pointers.
- `test_welford_non_block_pointer`: tests a 2D Welford reduction when block pointer analysis fails.
- `test_reduction_multiple_discontiguous_dims`: tests reducing over more than one discontiguous dimension. We won't get a block pointer for this case, since that would require 3D tiling, but we're currently limited to 2D.
- `test_2d_reduction_multi_kernel`: tests multi-kernel autotuning on a 2D softmax kernel.
- `test_enable_tiled_reductions`: tests that `config.triton.tile_reductions` enables/disables this feature.

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @gujinghui @PenghuiCheng @jianyuh @min-jean-cho @yanbing-j @Guobing-Chen @Xia-Weiwen @snadampal @voznesenskym @penguinwu @EikanWang @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov