Skip to content

[mlir][gpu] Disjoint patterns for lowering clustered subgroup reduce #109158

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

andfau-amd
Copy link
Contributor

@andfau-amd andfau-amd commented Sep 18, 2024

Making the existing populateGpuLowerSubgroupReduceToShufflePatterns() function also cover the new "clustered" subgroup reductions is proving to be inconvenient, because certain backends may have more specific lowerings that only cover the non-clustered type, and this creates pass ordering constraints. This commit removes coverage of clustered reductions from this function in favour of a new separate function, which makes controlling the lowering much more straightforward.

@llvmbot
Copy link
Member

llvmbot commented Sep 18, 2024

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Andrea Faulds (andfau-amd)

Changes

Making the existing populateGpuLowerSubgroupReduceToShufflePatterns() function also cover the new "clustered" subgroup reductions is proving to be inconvenient, because certain backends may have more specific lowerings that only cover the non-clustered type, and this creates pass ordering constraints. This commit removes coverage of clustered reductions from this function in favour of a new separate function, which makes makes controlling the lowering much more straightforward.


Full diff: https://github.com/llvm/llvm-project/pull/109158.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/GPU/Transforms/Passes.h (+10)
  • (modified) mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp (+38-5)
  • (modified) mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp (+4-1)
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
index 67baa8777a6fcc..8eb711962583da 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -73,10 +73,20 @@ void populateGpuBreakDownSubgroupReducePatterns(
 /// Collect a set of patterns to lower `gpu.subgroup_reduce` into `gpu.shuffle`
 /// ops over `shuffleBitwidth` scalar types. Assumes that the subgroup has
 /// `subgroupSize` lanes. Uses the butterfly shuffle algorithm.
+///
+/// The patterns populated by this function will ignore ops with the
+/// `cluster_size` attribute.
+/// `populateGpuLowerClusteredSubgroupReduceToShufflePatterns` is the opposite.
 void populateGpuLowerSubgroupReduceToShufflePatterns(
     RewritePatternSet &patterns, unsigned subgroupSize,
     unsigned shuffleBitwidth = 32, PatternBenefit benefit = 1);
 
+/// Disjoint counterpart of `populateGpuLowerSubgroupReduceToShufflePatterns`
+/// that only matches `gpu.subgroup_reduce` ops with a `cluster_size`.
+void populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
+    RewritePatternSet &patterns, unsigned subgroupSize,
+    unsigned shuffleBitwidth = 32, PatternBenefit benefit = 1);
+
 /// Collect all patterns to rewrite ops within the GPU dialect.
 inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
   populateGpuAllReducePatterns(patterns);
diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
index b166f1cd469a4d..56e53a806843ed 100644
--- a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
@@ -210,13 +210,24 @@ Value createSubgroupShuffleReduction(OpBuilder &builder, Location loc,
 struct ScalarSubgroupReduceToShuffles final
     : OpRewritePattern<gpu::SubgroupReduceOp> {
   ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
-                                 unsigned shuffleBitwidth,
+                                 unsigned shuffleBitwidth, bool matchClustered,
                                  PatternBenefit benefit)
       : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
-        shuffleBitwidth(shuffleBitwidth) {}
+        shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}
 
   LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                 PatternRewriter &rewriter) const override {
+    if (op.getClusterSize().has_value() != matchClustered) {
+      if (matchClustered)
+        return rewriter.notifyMatchFailure(
+            op, "op is non-clustered but pattern is configured to only match "
+                "clustered ops");
+      else
+        return rewriter.notifyMatchFailure(
+            op, "op is clustered but pattern is configured to only match "
+                "non-clustered ops");
+    }
+
     auto ci = getAndValidateClusterInfo(op, subgroupSize);
     if (failed(ci))
       return failure();
@@ -262,19 +273,31 @@ struct ScalarSubgroupReduceToShuffles final
 private:
   unsigned subgroupSize = 0;
   unsigned shuffleBitwidth = 0;
+  bool matchClustered;
 };
 
 /// Lowers vector gpu subgroup reductions to a series of shuffles.
 struct VectorSubgroupReduceToShuffles final
     : OpRewritePattern<gpu::SubgroupReduceOp> {
   VectorSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
-                                 unsigned shuffleBitwidth,
+                                 unsigned shuffleBitwidth, bool matchClustered,
                                  PatternBenefit benefit)
       : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
-        shuffleBitwidth(shuffleBitwidth) {}
+        shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}
 
   LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                 PatternRewriter &rewriter) const override {
+    if (op.getClusterSize().has_value() != matchClustered) {
+      if (matchClustered)
+        return rewriter.notifyMatchFailure(
+            op, "op is non-clustered but pattern is configured to only match "
+                "clustered ops");
+      else
+        return rewriter.notifyMatchFailure(
+            op, "op is clustered but pattern is configured to only match "
+                "non-clustered ops");
+    }
+
     auto ci = getAndValidateClusterInfo(op, subgroupSize);
     if (failed(ci))
       return failure();
@@ -343,6 +366,7 @@ struct VectorSubgroupReduceToShuffles final
 private:
   unsigned subgroupSize = 0;
   unsigned shuffleBitwidth = 0;
+  bool matchClustered;
 };
 } // namespace
 
@@ -358,5 +382,14 @@ void mlir::populateGpuLowerSubgroupReduceToShufflePatterns(
     RewritePatternSet &patterns, unsigned subgroupSize,
     unsigned shuffleBitwidth, PatternBenefit benefit) {
   patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
-      patterns.getContext(), subgroupSize, shuffleBitwidth, benefit);
+      patterns.getContext(), subgroupSize, shuffleBitwidth,
+      /*matchClustered=*/false, benefit);
+}
+
+void mlir::populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
+    RewritePatternSet &patterns, unsigned subgroupSize,
+    unsigned shuffleBitwidth, PatternBenefit benefit) {
+  patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
+      patterns.getContext(), subgroupSize, shuffleBitwidth,
+      /*matchClustered=*/true, benefit);
 }
diff --git a/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp b/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
index 99a914506b011a..74d057c0b7b6cb 100644
--- a/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
+++ b/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
@@ -78,9 +78,12 @@ struct TestGpuSubgroupReduceLoweringPass
     populateGpuBreakDownSubgroupReducePatterns(patterns,
                                                /*maxShuffleBitwidth=*/32,
                                                PatternBenefit(2));
-    if (expandToShuffles)
+    if (expandToShuffles) {
       populateGpuLowerSubgroupReduceToShufflePatterns(
           patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32);
+      populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
+          patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32);
+    }
 
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }

@andfau-amd andfau-amd force-pushed the 18142-clustered-disjoint-populate-functions branch from 3e1a9fe to a05036b Compare September 18, 2024 15:19
@kuhar kuhar self-requested a review September 18, 2024 16:35
@andfau-amd andfau-amd force-pushed the 18142-clustered-disjoint-populate-functions branch from a05036b to 3e66ebd Compare September 18, 2024 16:50
Making the existing populateGpuLowerSubgroupReduceToShufflePatterns()
function also cover the new "clustered" subgroup reductions is proving
to be inconvenient, because certain backends may have more specific
lowerings that only cover the non-clustered type, and this creates pass
ordering constraints. This commit removes coverage of clustered
reductions from this function in favour of a new separate function,
which makes controlling the lowering much more straightforward.
@andfau-amd andfau-amd force-pushed the 18142-clustered-disjoint-populate-functions branch from 3e66ebd to 309408f Compare September 18, 2024 17:07
@kuhar kuhar merged commit a800ffa into llvm:main Sep 18, 2024
8 checks passed
tmsri pushed a commit to tmsri/llvm-project that referenced this pull request Sep 19, 2024
…lvm#109158)

Making the existing populateGpuLowerSubgroupReduceToShufflePatterns()
function also cover the new "clustered" subgroup reductions is proving
to be inconvenient, because certain backends may have more specific
lowerings that only cover the non-clustered type, and this creates pass
ordering constraints. This commit removes coverage of clustered
reductions from this function in favour of a new separate function,
which makes controlling the lowering much more straightforward.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants