Skip to content

Commit a800ffa

Browse files
authored
[mlir][gpu] Disjoint patterns for lowering clustered subgroup reduce (#109158)
Making the existing populateGpuLowerSubgroupReduceToShufflePatterns() function also cover the new "clustered" subgroup reductions is proving to be inconvenient, because certain backends may have more specific lowerings that only cover the non-clustered type, and this creates pass ordering constraints. This commit removes coverage of clustered reductions from this function in favour of a new separate function, which makes controlling the lowering much more straightforward.
1 parent 1e19e1e commit a800ffa

File tree

3 files changed

+46
-6
lines changed

3 files changed

+46
-6
lines changed

mlir/include/mlir/Dialect/GPU/Transforms/Passes.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,20 @@ void populateGpuBreakDownSubgroupReducePatterns(
7373
/// Collect a set of patterns to lower `gpu.subgroup_reduce` into `gpu.shuffle`
7474
/// ops over `shuffleBitwidth` scalar types. Assumes that the subgroup has
7575
/// `subgroupSize` lanes. Uses the butterfly shuffle algorithm.
76+
///
77+
/// The patterns populated by this function will ignore ops with the
78+
/// `cluster_size` attribute.
79+
/// `populateGpuLowerClusteredSubgroupReduceToShufflePatterns` is the opposite.
7680
void populateGpuLowerSubgroupReduceToShufflePatterns(
7781
RewritePatternSet &patterns, unsigned subgroupSize,
7882
unsigned shuffleBitwidth = 32, PatternBenefit benefit = 1);
7983

84+
/// Disjoint counterpart of `populateGpuLowerSubgroupReduceToShufflePatterns`
85+
/// that only matches `gpu.subgroup_reduce` ops with a `cluster_size`.
86+
void populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
87+
RewritePatternSet &patterns, unsigned subgroupSize,
88+
unsigned shuffleBitwidth = 32, PatternBenefit benefit = 1);
89+
8090
/// Collect all patterns to rewrite ops within the GPU dialect.
8191
inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
8292
populateGpuAllReducePatterns(patterns);

mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -210,13 +210,21 @@ Value createSubgroupShuffleReduction(OpBuilder &builder, Location loc,
210210
struct ScalarSubgroupReduceToShuffles final
211211
: OpRewritePattern<gpu::SubgroupReduceOp> {
212212
ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
213-
unsigned shuffleBitwidth,
213+
unsigned shuffleBitwidth, bool matchClustered,
214214
PatternBenefit benefit)
215215
: OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
216-
shuffleBitwidth(shuffleBitwidth) {}
216+
shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}
217217

218218
LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
219219
PatternRewriter &rewriter) const override {
220+
if (op.getClusterSize().has_value() != matchClustered) {
221+
return rewriter.notifyMatchFailure(
222+
op, llvm::formatv("op is {0}clustered but pattern is configured to "
223+
"only match {1}clustered ops",
224+
matchClustered ? "non-" : "",
225+
matchClustered ? "" : "non-"));
226+
}
227+
220228
auto ci = getAndValidateClusterInfo(op, subgroupSize);
221229
if (failed(ci))
222230
return failure();
@@ -262,19 +270,28 @@ struct ScalarSubgroupReduceToShuffles final
262270
private:
263271
unsigned subgroupSize = 0;
264272
unsigned shuffleBitwidth = 0;
273+
bool matchClustered = false;
265274
};
266275

267276
/// Lowers vector gpu subgroup reductions to a series of shuffles.
268277
struct VectorSubgroupReduceToShuffles final
269278
: OpRewritePattern<gpu::SubgroupReduceOp> {
270279
VectorSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
271-
unsigned shuffleBitwidth,
280+
unsigned shuffleBitwidth, bool matchClustered,
272281
PatternBenefit benefit)
273282
: OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
274-
shuffleBitwidth(shuffleBitwidth) {}
283+
shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}
275284

276285
LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
277286
PatternRewriter &rewriter) const override {
287+
if (op.getClusterSize().has_value() != matchClustered) {
288+
return rewriter.notifyMatchFailure(
289+
op, llvm::formatv("op is {0}clustered but pattern is configured to "
290+
"only match {1}clustered ops",
291+
matchClustered ? "non-" : "",
292+
matchClustered ? "" : "non-"));
293+
}
294+
278295
auto ci = getAndValidateClusterInfo(op, subgroupSize);
279296
if (failed(ci))
280297
return failure();
@@ -343,6 +360,7 @@ struct VectorSubgroupReduceToShuffles final
343360
private:
344361
unsigned subgroupSize = 0;
345362
unsigned shuffleBitwidth = 0;
363+
bool matchClustered = false;
346364
};
347365
} // namespace
348366

@@ -358,5 +376,14 @@ void mlir::populateGpuLowerSubgroupReduceToShufflePatterns(
358376
RewritePatternSet &patterns, unsigned subgroupSize,
359377
unsigned shuffleBitwidth, PatternBenefit benefit) {
360378
patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
361-
patterns.getContext(), subgroupSize, shuffleBitwidth, benefit);
379+
patterns.getContext(), subgroupSize, shuffleBitwidth,
380+
/*matchClustered=*/false, benefit);
381+
}
382+
383+
void mlir::populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
384+
RewritePatternSet &patterns, unsigned subgroupSize,
385+
unsigned shuffleBitwidth, PatternBenefit benefit) {
386+
patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
387+
patterns.getContext(), subgroupSize, shuffleBitwidth,
388+
/*matchClustered=*/true, benefit);
362389
}

mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,12 @@ struct TestGpuSubgroupReduceLoweringPass
7878
populateGpuBreakDownSubgroupReducePatterns(patterns,
7979
/*maxShuffleBitwidth=*/32,
8080
PatternBenefit(2));
81-
if (expandToShuffles)
81+
if (expandToShuffles) {
8282
populateGpuLowerSubgroupReduceToShufflePatterns(
8383
patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32);
84+
populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
85+
patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32);
86+
}
8487

8588
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
8689
}

0 commit comments

Comments
 (0)