Skip to content

[mlir][gpu] Disjoint patterns for lowering clustered subgroup reduce #109158

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,20 @@ void populateGpuBreakDownSubgroupReducePatterns(
/// Collect a set of patterns to lower `gpu.subgroup_reduce` into `gpu.shuffle`
/// ops over `shuffleBitwidth` scalar types. Assumes that the subgroup has
/// `subgroupSize` lanes. Uses the butterfly shuffle algorithm.
///
/// The patterns populated by this function will ignore ops with the
/// `cluster_size` attribute; those are matched only by the disjoint
/// counterpart, `populateGpuLowerClusteredSubgroupReduceToShufflePatterns`.
void populateGpuLowerSubgroupReduceToShufflePatterns(
    RewritePatternSet &patterns, unsigned subgroupSize,
    unsigned shuffleBitwidth = 32, PatternBenefit benefit = 1);

/// Disjoint counterpart of `populateGpuLowerSubgroupReduceToShufflePatterns`
/// that only matches `gpu.subgroup_reduce` ops with a `cluster_size`
/// attribute. Populating both sets covers clustered and non-clustered
/// reductions without any pattern matching the same op twice.
void populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
    RewritePatternSet &patterns, unsigned subgroupSize,
    unsigned shuffleBitwidth = 32, PatternBenefit benefit = 1);

/// Collect all patterns to rewrite ops within the GPU dialect.
inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
populateGpuAllReducePatterns(patterns);
Expand Down
37 changes: 32 additions & 5 deletions mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,13 +210,21 @@ Value createSubgroupShuffleReduction(OpBuilder &builder, Location loc,
struct ScalarSubgroupReduceToShuffles final
: OpRewritePattern<gpu::SubgroupReduceOp> {
ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
unsigned shuffleBitwidth,
unsigned shuffleBitwidth, bool matchClustered,
PatternBenefit benefit)
: OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
shuffleBitwidth(shuffleBitwidth) {}
shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}

LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
PatternRewriter &rewriter) const override {
if (op.getClusterSize().has_value() != matchClustered) {
return rewriter.notifyMatchFailure(
op, llvm::formatv("op is {0}clustered but pattern is configured to "
"only match {1}clustered ops",
matchClustered ? "non-" : "",
matchClustered ? "" : "non-"));
}

auto ci = getAndValidateClusterInfo(op, subgroupSize);
if (failed(ci))
return failure();
Expand Down Expand Up @@ -262,19 +270,28 @@ struct ScalarSubgroupReduceToShuffles final
private:
unsigned subgroupSize = 0;
unsigned shuffleBitwidth = 0;
bool matchClustered = false;
};

/// Lowers vector gpu subgroup reductions to a series of shuffles.
struct VectorSubgroupReduceToShuffles final
: OpRewritePattern<gpu::SubgroupReduceOp> {
VectorSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
unsigned shuffleBitwidth,
unsigned shuffleBitwidth, bool matchClustered,
PatternBenefit benefit)
: OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
shuffleBitwidth(shuffleBitwidth) {}
shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}

LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
PatternRewriter &rewriter) const override {
if (op.getClusterSize().has_value() != matchClustered) {
return rewriter.notifyMatchFailure(
op, llvm::formatv("op is {0}clustered but pattern is configured to "
"only match {1}clustered ops",
matchClustered ? "non-" : "",
matchClustered ? "" : "non-"));
}

auto ci = getAndValidateClusterInfo(op, subgroupSize);
if (failed(ci))
return failure();
Expand Down Expand Up @@ -343,6 +360,7 @@ struct VectorSubgroupReduceToShuffles final
private:
unsigned subgroupSize = 0;
unsigned shuffleBitwidth = 0;
bool matchClustered = false;
};
} // namespace

Expand All @@ -358,5 +376,14 @@ void mlir::populateGpuLowerSubgroupReduceToShufflePatterns(
RewritePatternSet &patterns, unsigned subgroupSize,
unsigned shuffleBitwidth, PatternBenefit benefit) {
patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
patterns.getContext(), subgroupSize, shuffleBitwidth, benefit);
patterns.getContext(), subgroupSize, shuffleBitwidth,
/*matchClustered=*/false, benefit);
}

/// Registers the scalar and vector shuffle-based lowerings for
/// `gpu.subgroup_reduce` ops that carry a `cluster_size` attribute. The
/// patterns are instantiated with `matchClustered=true`, so they reject
/// non-clustered ops (see the non-clustered counterpart,
/// `populateGpuLowerSubgroupReduceToShufflePatterns`).
void mlir::populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
    RewritePatternSet &patterns, unsigned subgroupSize,
    unsigned shuffleBitwidth, PatternBenefit benefit) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
      ctx, subgroupSize, shuffleBitwidth, /*matchClustered=*/true, benefit);
}
5 changes: 4 additions & 1 deletion mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,12 @@ struct TestGpuSubgroupReduceLoweringPass
populateGpuBreakDownSubgroupReducePatterns(patterns,
/*maxShuffleBitwidth=*/32,
PatternBenefit(2));
if (expandToShuffles)
if (expandToShuffles) {
populateGpuLowerSubgroupReduceToShufflePatterns(
patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32);
populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32);
}

(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
Expand Down
Loading