Skip to content

Commit a05036b

Browse files
committed
[mlir][gpu] Disjoint patterns for lowering clustered subgroup reduce
Making the existing populateGpuLowerSubgroupReduceToShufflePatterns() function also cover the new "clustered" subgroup reductions is proving to be inconvenient, because certain backends may have more specific lowerings that only cover the non-clustered type, and this creates pass ordering constraints. This commit removes coverage of clustered reductions from this function in favour of a new separate function, which makes controlling the lowering much more straightforward.
1 parent b334ca6 commit a05036b

File tree

3 files changed

+52
-6
lines changed

3 files changed

+52
-6
lines changed

mlir/include/mlir/Dialect/GPU/Transforms/Passes.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,20 @@ void populateGpuBreakDownSubgroupReducePatterns(
7373
/// Collect a set of patterns to lower `gpu.subgroup_reduce` into `gpu.shuffle`
7474
/// ops over `shuffleBitwidth` scalar types. Assumes that the subgroup has
7575
/// `subgroupSize` lanes. Uses the butterfly shuffle algorithm.
76+
///
77+
/// The patterns populated by this function will ignore ops with the
78+
/// `cluster_size` attribute.
79+
/// `populateGpuLowerClusteredSubgroupReduceToShufflePatterns` is the opposite.
7680
void populateGpuLowerSubgroupReduceToShufflePatterns(
7781
RewritePatternSet &patterns, unsigned subgroupSize,
7882
unsigned shuffleBitwidth = 32, PatternBenefit benefit = 1);
7983

84+
/// Disjoint counterpart of `populateGpuLowerSubgroupReduceToShufflePatterns`
85+
/// that only matches `gpu.subgroup_reduce` ops with a `cluster_size`.
86+
void populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
87+
RewritePatternSet &patterns, unsigned subgroupSize,
88+
unsigned shuffleBitwidth = 32, PatternBenefit benefit = 1);
89+
8090
/// Collect all patterns to rewrite ops within the GPU dialect.
8191
inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
8292
populateGpuAllReducePatterns(patterns);

mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -210,13 +210,24 @@ Value createSubgroupShuffleReduction(OpBuilder &builder, Location loc,
210210
struct ScalarSubgroupReduceToShuffles final
211211
: OpRewritePattern<gpu::SubgroupReduceOp> {
212212
ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
213-
unsigned shuffleBitwidth,
213+
unsigned shuffleBitwidth, bool matchClustered,
214214
PatternBenefit benefit)
215215
: OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
216-
shuffleBitwidth(shuffleBitwidth) {}
216+
shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}
217217

218218
LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
219219
PatternRewriter &rewriter) const override {
220+
if (op.getClusterSize().has_value() != matchClustered) {
221+
if (matchClustered)
222+
return rewriter.notifyMatchFailure(
223+
op, "op is non-clustered but pattern is configured to only match "
224+
"clustered ops");
225+
else
226+
return rewriter.notifyMatchFailure(
227+
op, "op is clustered but pattern is configured to only match "
228+
"non-clustered ops");
229+
}
230+
220231
auto ci = getAndValidateClusterInfo(op, subgroupSize);
221232
if (failed(ci))
222233
return failure();
@@ -262,19 +273,31 @@ struct ScalarSubgroupReduceToShuffles final
262273
private:
263274
unsigned subgroupSize = 0;
264275
unsigned shuffleBitwidth = 0;
276+
bool matchClustered;
265277
};
266278

267279
/// Lowers vector gpu subgroup reductions to a series of shuffles.
268280
struct VectorSubgroupReduceToShuffles final
269281
: OpRewritePattern<gpu::SubgroupReduceOp> {
270282
VectorSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
271-
unsigned shuffleBitwidth,
283+
unsigned shuffleBitwidth, bool matchClustered,
272284
PatternBenefit benefit)
273285
: OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
274-
shuffleBitwidth(shuffleBitwidth) {}
286+
shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}
275287

276288
LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
277289
PatternRewriter &rewriter) const override {
290+
if (op.getClusterSize().has_value() != matchClustered) {
291+
if (matchClustered)
292+
return rewriter.notifyMatchFailure(
293+
op, "op is non-clustered but pattern is configured to only match "
294+
"clustered ops");
295+
else
296+
return rewriter.notifyMatchFailure(
297+
op, "op is clustered but pattern is configured to only match "
298+
"non-clustered ops");
299+
}
300+
278301
auto ci = getAndValidateClusterInfo(op, subgroupSize);
279302
if (failed(ci))
280303
return failure();
@@ -343,6 +366,7 @@ struct VectorSubgroupReduceToShuffles final
343366
private:
344367
unsigned subgroupSize = 0;
345368
unsigned shuffleBitwidth = 0;
369+
bool matchClustered;
346370
};
347371
} // namespace
348372

@@ -358,5 +382,14 @@ void mlir::populateGpuLowerSubgroupReduceToShufflePatterns(
358382
RewritePatternSet &patterns, unsigned subgroupSize,
359383
unsigned shuffleBitwidth, PatternBenefit benefit) {
360384
patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
361-
patterns.getContext(), subgroupSize, shuffleBitwidth, benefit);
385+
patterns.getContext(), subgroupSize, shuffleBitwidth,
386+
/*matchClustered=*/false, benefit);
387+
}
388+
389+
void mlir::populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
390+
RewritePatternSet &patterns, unsigned subgroupSize,
391+
unsigned shuffleBitwidth, PatternBenefit benefit) {
392+
patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
393+
patterns.getContext(), subgroupSize, shuffleBitwidth,
394+
/*matchClustered=*/true, benefit);
362395
}

mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,12 @@ struct TestGpuSubgroupReduceLoweringPass
7878
populateGpuBreakDownSubgroupReducePatterns(patterns,
7979
/*maxShuffleBitwidth=*/32,
8080
PatternBenefit(2));
81-
if (expandToShuffles)
81+
if (expandToShuffles) {
8282
populateGpuLowerSubgroupReduceToShufflePatterns(
8383
patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32);
84+
populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
85+
patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32);
86+
}
8487

8588
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
8689
}

0 commit comments

Comments (0)