[mlir][gpu] Disjoint patterns for lowering clustered subgroup reduce #109158
Conversation
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir

Author: Andrea Faulds (andfau-amd)

Changes

Making the existing populateGpuLowerSubgroupReduceToShufflePatterns() function also cover the new "clustered" subgroup reductions is proving to be inconvenient, because certain backends may have more specific lowerings that only cover the non-clustered type, and this creates pass ordering constraints. This commit removes coverage of clustered reductions from this function in favour of a new separate function, which makes controlling the lowering much more straightforward.

Full diff: https://github.com/llvm/llvm-project/pull/109158.diff

3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
index 67baa8777a6fcc..8eb711962583da 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -73,10 +73,20 @@ void populateGpuBreakDownSubgroupReducePatterns(
/// Collect a set of patterns to lower `gpu.subgroup_reduce` into `gpu.shuffle`
/// ops over `shuffleBitwidth` scalar types. Assumes that the subgroup has
/// `subgroupSize` lanes. Uses the butterfly shuffle algorithm.
+///
+/// The patterns populated by this function will ignore ops with the
+/// `cluster_size` attribute.
+/// `populateGpuLowerClusteredSubgroupReduceToShufflePatterns` is the opposite.
void populateGpuLowerSubgroupReduceToShufflePatterns(
RewritePatternSet &patterns, unsigned subgroupSize,
unsigned shuffleBitwidth = 32, PatternBenefit benefit = 1);
+/// Disjoint counterpart of `populateGpuLowerSubgroupReduceToShufflePatterns`
+/// that only matches `gpu.subgroup_reduce` ops with a `cluster_size`.
+void populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
+ RewritePatternSet &patterns, unsigned subgroupSize,
+ unsigned shuffleBitwidth = 32, PatternBenefit benefit = 1);
+
/// Collect all patterns to rewrite ops within the GPU dialect.
inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
populateGpuAllReducePatterns(patterns);
diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
index b166f1cd469a4d..56e53a806843ed 100644
--- a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
@@ -210,13 +210,24 @@ Value createSubgroupShuffleReduction(OpBuilder &builder, Location loc,
struct ScalarSubgroupReduceToShuffles final
: OpRewritePattern<gpu::SubgroupReduceOp> {
ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
- unsigned shuffleBitwidth,
+ unsigned shuffleBitwidth, bool matchClustered,
PatternBenefit benefit)
: OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
- shuffleBitwidth(shuffleBitwidth) {}
+ shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}
LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
PatternRewriter &rewriter) const override {
+ if (op.getClusterSize().has_value() != matchClustered) {
+ if (matchClustered)
+ return rewriter.notifyMatchFailure(
+ op, "op is non-clustered but pattern is configured to only match "
+ "clustered ops");
+ else
+ return rewriter.notifyMatchFailure(
+ op, "op is clustered but pattern is configured to only match "
+ "non-clustered ops");
+ }
+
auto ci = getAndValidateClusterInfo(op, subgroupSize);
if (failed(ci))
return failure();
@@ -262,19 +273,31 @@ struct ScalarSubgroupReduceToShuffles final
private:
unsigned subgroupSize = 0;
unsigned shuffleBitwidth = 0;
+ bool matchClustered;
};
/// Lowers vector gpu subgroup reductions to a series of shuffles.
struct VectorSubgroupReduceToShuffles final
: OpRewritePattern<gpu::SubgroupReduceOp> {
VectorSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
- unsigned shuffleBitwidth,
+ unsigned shuffleBitwidth, bool matchClustered,
PatternBenefit benefit)
: OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
- shuffleBitwidth(shuffleBitwidth) {}
+ shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}
LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
PatternRewriter &rewriter) const override {
+ if (op.getClusterSize().has_value() != matchClustered) {
+ if (matchClustered)
+ return rewriter.notifyMatchFailure(
+ op, "op is non-clustered but pattern is configured to only match "
+ "clustered ops");
+ else
+ return rewriter.notifyMatchFailure(
+ op, "op is clustered but pattern is configured to only match "
+ "non-clustered ops");
+ }
+
auto ci = getAndValidateClusterInfo(op, subgroupSize);
if (failed(ci))
return failure();
@@ -343,6 +366,7 @@ struct VectorSubgroupReduceToShuffles final
private:
unsigned subgroupSize = 0;
unsigned shuffleBitwidth = 0;
+ bool matchClustered;
};
} // namespace
@@ -358,5 +382,14 @@ void mlir::populateGpuLowerSubgroupReduceToShufflePatterns(
RewritePatternSet &patterns, unsigned subgroupSize,
unsigned shuffleBitwidth, PatternBenefit benefit) {
patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
- patterns.getContext(), subgroupSize, shuffleBitwidth, benefit);
+ patterns.getContext(), subgroupSize, shuffleBitwidth,
+ /*matchClustered=*/false, benefit);
+}
+
+void mlir::populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
+ RewritePatternSet &patterns, unsigned subgroupSize,
+ unsigned shuffleBitwidth, PatternBenefit benefit) {
+ patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
+ patterns.getContext(), subgroupSize, shuffleBitwidth,
+ /*matchClustered=*/true, benefit);
}
diff --git a/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp b/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
index 99a914506b011a..74d057c0b7b6cb 100644
--- a/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
+++ b/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
@@ -78,9 +78,12 @@ struct TestGpuSubgroupReduceLoweringPass
populateGpuBreakDownSubgroupReducePatterns(patterns,
/*maxShuffleBitwidth=*/32,
PatternBenefit(2));
- if (expandToShuffles)
+ if (expandToShuffles) {
populateGpuLowerSubgroupReduceToShufflePatterns(
patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32);
+ populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
+ patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32);
+ }
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
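
For illustration, here is a minimal sketch of how a downstream pipeline might combine the split entry points after this change. `populateMyBackendSubgroupReducePatterns` is a hypothetical backend-specific hook assumed for the sketch, not part of this PR or of MLIR; only the two `populateGpu...` functions come from the patch.

#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical backend-specific lowering for non-clustered reductions
// (an assumption for this sketch; such a hook is not part of this PR).
void populateMyBackendSubgroupReducePatterns(RewritePatternSet &patterns,
                                             PatternBenefit benefit);

void populateReductionLowerings(RewritePatternSet &patterns) {
  // Backend-specific patterns handle plain (non-clustered) reductions.
  populateMyBackendSubgroupReducePatterns(patterns, PatternBenefit(2));
  // The new disjoint entry point only matches ops with a `cluster_size`,
  // so both sets can coexist in one greedy rewrite with no pass ordering.
  populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
      patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32);
}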
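To show what the shuffle lowering computes for a clustered reduction, below is a standalone simulation of the butterfly (XOR-shuffle) algorithm named in the `Passes.h` comment. This is plain C++ mimicking `gpu.shuffle xor` over lanes, not MLIR, and it assumes an additive reduction over contiguous clusters (cluster stride 1) and a power-of-two cluster size, as the op's verifier requires.

#include <iostream>
#include <vector>

// Simulates `gpu.shuffle xor`: every lane reads the value of lane ^ mask.
static std::vector<int> shuffleXor(const std::vector<int> &lanes,
                                   unsigned mask) {
  std::vector<int> out(lanes.size());
  for (unsigned lane = 0; lane < lanes.size(); ++lane)
    out[lane] = lanes[lane ^ mask];
  return out;
}

int main() {
  // One value per lane of an 8-lane subgroup, reduced in clusters of 4.
  std::vector<int> lanes{1, 2, 3, 4, 5, 6, 7, 8};
  unsigned clusterSize = 4;
  // Butterfly: XOR-shuffle with masks clusterSize/2, ..., 1 and combine.
  for (unsigned mask = clusterSize / 2; mask >= 1; mask /= 2) {
    std::vector<int> shuffled = shuffleXor(lanes, mask);
    for (unsigned lane = 0; lane < lanes.size(); ++lane)
      lanes[lane] += shuffled[lane]; // reduction operator: addition
  }
  // Every lane now holds its cluster's sum: 10 10 10 10 26 26 26 26.
  for (int v : lanes)
    std::cout << v << ' ';
  std::cout << '\n';
}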