Skip to content

Commit 309408f

Browse files
committed
[mlir][gpu] Disjoint patterns for lowering clustered subgroup reduce
Making the existing populateGpuLowerSubgroupReduceToShufflePatterns() function also cover the new "clustered" subgroup reductions is proving to be inconvenient, because certain backends may have more specific lowerings that only cover the non-clustered type, and this creates pass ordering constraints. This commit removes coverage of clustered reductions from this function in favour of a new separate function, which makes controlling the lowering much more straightforward.
1 parent b334ca6 commit 309408f

File tree

3 files changed

+46
-6
lines changed

3 files changed

+46
-6
lines changed

mlir/include/mlir/Dialect/GPU/Transforms/Passes.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,20 @@ void populateGpuBreakDownSubgroupReducePatterns(
7373
/// Collect a set of patterns to lower `gpu.subgroup_reduce` into `gpu.shuffle`
7474
/// ops over `shuffleBitwidth` scalar types. Assumes that the subgroup has
7575
/// `subgroupSize` lanes. Uses the butterfly shuffle algorithm.
76+
///
77+
/// The patterns populated by this function will ignore ops with the
78+
/// `cluster_size` attribute.
79+
/// Use `populateGpuLowerClusteredSubgroupReduceToShufflePatterns` for those.
7680
void populateGpuLowerSubgroupReduceToShufflePatterns(
7781
RewritePatternSet &patterns, unsigned subgroupSize,
7882
unsigned shuffleBitwidth = 32, PatternBenefit benefit = 1);
7983

84+
/// Disjoint counterpart of `populateGpuLowerSubgroupReduceToShufflePatterns`
85+
/// that only matches `gpu.subgroup_reduce` ops with a `cluster_size`.
86+
void populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
87+
RewritePatternSet &patterns, unsigned subgroupSize,
88+
unsigned shuffleBitwidth = 32, PatternBenefit benefit = 1);
89+
8090
/// Collect all patterns to rewrite ops within the GPU dialect.
8191
inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
8292
populateGpuAllReducePatterns(patterns);

mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -210,13 +210,21 @@ Value createSubgroupShuffleReduction(OpBuilder &builder, Location loc,
210210
struct ScalarSubgroupReduceToShuffles final
211211
: OpRewritePattern<gpu::SubgroupReduceOp> {
212212
ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
213-
unsigned shuffleBitwidth,
213+
unsigned shuffleBitwidth, bool matchClustered,
214214
PatternBenefit benefit)
215215
: OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
216-
shuffleBitwidth(shuffleBitwidth) {}
216+
shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}
217217

218218
LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
219219
PatternRewriter &rewriter) const override {
220+
if (op.getClusterSize().has_value() != matchClustered) {
221+
return rewriter.notifyMatchFailure(
222+
op, llvm::formatv("op is {0}clustered but pattern is configured to "
223+
"only match {1}clustered ops",
224+
matchClustered ? "non-" : "",
225+
matchClustered ? "" : "non-"));
226+
}
227+
220228
auto ci = getAndValidateClusterInfo(op, subgroupSize);
221229
if (failed(ci))
222230
return failure();
@@ -262,19 +270,28 @@ struct ScalarSubgroupReduceToShuffles final
262270
private:
263271
unsigned subgroupSize = 0;
264272
unsigned shuffleBitwidth = 0;
273+
bool matchClustered = false;
265274
};
266275

267276
/// Lowers vector gpu subgroup reductions to a series of shuffles.
268277
struct VectorSubgroupReduceToShuffles final
269278
: OpRewritePattern<gpu::SubgroupReduceOp> {
270279
VectorSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
271-
unsigned shuffleBitwidth,
280+
unsigned shuffleBitwidth, bool matchClustered,
272281
PatternBenefit benefit)
273282
: OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
274-
shuffleBitwidth(shuffleBitwidth) {}
283+
shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}
275284

276285
LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
277286
PatternRewriter &rewriter) const override {
287+
if (op.getClusterSize().has_value() != matchClustered) {
288+
return rewriter.notifyMatchFailure(
289+
op, llvm::formatv("op is {0}clustered but pattern is configured to "
290+
"only match {1}clustered ops",
291+
matchClustered ? "non-" : "",
292+
matchClustered ? "" : "non-"));
293+
}
294+
278295
auto ci = getAndValidateClusterInfo(op, subgroupSize);
279296
if (failed(ci))
280297
return failure();
@@ -343,6 +360,7 @@ struct VectorSubgroupReduceToShuffles final
343360
private:
344361
unsigned subgroupSize = 0;
345362
unsigned shuffleBitwidth = 0;
363+
bool matchClustered = false;
346364
};
347365
} // namespace
348366

@@ -358,5 +376,14 @@ void mlir::populateGpuLowerSubgroupReduceToShufflePatterns(
358376
RewritePatternSet &patterns, unsigned subgroupSize,
359377
unsigned shuffleBitwidth, PatternBenefit benefit) {
360378
patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
361-
patterns.getContext(), subgroupSize, shuffleBitwidth, benefit);
379+
patterns.getContext(), subgroupSize, shuffleBitwidth,
380+
/*matchClustered=*/false, benefit);
381+
}
382+
383+
/// Registers the scalar and vector subgroup-reduce-to-shuffle patterns with
/// `matchClustered` set to true, so they only rewrite `gpu.subgroup_reduce`
/// ops whose cluster size is present; ops without one are reported as a match
/// failure by these patterns.
void mlir::populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
384+
RewritePatternSet &patterns, unsigned subgroupSize,
385+
unsigned shuffleBitwidth, PatternBenefit benefit) {
386+
patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
387+
patterns.getContext(), subgroupSize, shuffleBitwidth,
388+
/*matchClustered=*/true, benefit);
362389
}

mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,12 @@ struct TestGpuSubgroupReduceLoweringPass
7878
populateGpuBreakDownSubgroupReducePatterns(patterns,
7979
/*maxShuffleBitwidth=*/32,
8080
PatternBenefit(2));
81-
if (expandToShuffles)
81+
if (expandToShuffles) {
8282
populateGpuLowerSubgroupReduceToShufflePatterns(
8383
patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32);
84+
populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
85+
patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32);
86+
}
8487

8588
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
8689
}

0 commit comments

Comments
 (0)