[mlir] Do not merge blocks during canonicalization by default #95057

Merged
merged 1 commit into from
Jun 14, 2024

Conversation

joker-eph
Collaborator

Block merging is a heavy process, and it can trigger a massive explosion in the number of block arguments. While it can reduce code size, the resulting merged blocks hide part of the def-use chain behind their arguments and can even hinder further analyses and optimizations: a merged block does not have its own path-sensitive context; instead, its context is merged from all of its predecessors.

The previous behavior can be restored by setting region-simplify=aggressive, for example by passing:

{test-convergence region-simplify=aggressive}

to the canonicalize pass.
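
To make the block-argument growth concrete, here is a minimal toy model in plain C++. It is a sketch only: `ToyOp`, `ToyBlock`, `structurallySimilar`, and `blockArgsNeededToMerge` are hypothetical names and do not correspond to the MLIR implementation. The assumption is that two blocks can be merged when they contain the same operations in the same order, and that every operand position where they use different values costs one new block argument on the merged block.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Toy stand-in for an operation: a name plus the SSA values it uses.
// Hypothetical type, not part of MLIR.
struct ToyOp {
  std::string name;
  std::vector<std::string> operands;
};

using ToyBlock = std::vector<ToyOp>;

// Returns true if both blocks hold the same op names with the same operand
// counts in the same order, i.e. they differ at most in which values are used.
bool structurallySimilar(const ToyBlock &a, const ToyBlock &b) {
  if (a.size() != b.size())
    return false;
  for (std::size_t i = 0; i < a.size(); ++i) {
    if (a[i].name != b[i].name ||
        a[i].operands.size() != b[i].operands.size())
      return false;
  }
  return true;
}

// Counts the operand positions where the two blocks use different values.
// In this toy model each such position needs one new block argument, with
// every predecessor forwarding its own value.
// Precondition: structurallySimilar(a, b) is true.
std::size_t blockArgsNeededToMerge(const ToyBlock &a, const ToyBlock &b) {
  std::size_t needed = 0;
  for (std::size_t i = 0; i < a.size(); ++i)
    for (std::size_t j = 0; j < a[i].operands.size(); ++j)
      if (a[i].operands[j] != b[i].operands[j])
        ++needed;
  return needed;
}
```

In this model, two large blocks that are identical except for the values they read would merge into one block with one argument per differing operand, which is the kind of growth the description warns about.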

@joker-eph joker-eph requested review from jpienaar and Mogball June 10, 2024 23:11
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir:spirv mlir labels Jun 10, 2024
@llvmbot
Member

llvmbot commented Jun 10, 2024

@llvm/pr-subscribers-flang-fir-hlfir
@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir-core

Author: Mehdi Amini (joker-eph)


Full diff: https://github.com/llvm/llvm-project/pull/95057.diff

12 Files Affected:

  • (modified) mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h (+12-1)
  • (modified) mlir/include/mlir/Transforms/Passes.h (+1)
  • (modified) mlir/include/mlir/Transforms/Passes.td (+11-3)
  • (modified) mlir/include/mlir/Transforms/RegionUtils.h (+4-1)
  • (modified) mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp (+7-2)
  • (modified) mlir/lib/Transforms/Utils/RegionUtils.cpp (+5-3)
  • (modified) mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir (+1-1)
  • (modified) mlir/test/Pass/run-reproducer.mlir (+2-2)
  • (modified) mlir/test/Transforms/canonicalize-block-merge.mlir (+17-1)
  • (modified) mlir/test/Transforms/canonicalize-dce.mlir (+1-1)
  • (modified) mlir/test/Transforms/canonicalize.mlir (+6-1)
  • (modified) mlir/test/Transforms/test-canonicalize.mlir (+1-1)
diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index 763146aac15b9..eaff85804f6b3 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -29,6 +29,16 @@ enum class GreedyRewriteStrictness {
   ExistingOps
 };
 
+enum class GreedySimplifyRegionLevel {
+  /// Disable region control-flow simplification.
+  Disabled,
+  /// Run the normal simplification (e.g. dead args elimination).
+  Normal,
+  /// Run extra simplificiations (e.g. block merging), these can be
+  /// more costly or have some tradeoffs associated.
+  Aggressive
+};
+
 /// This class allows control over how the GreedyPatternRewriteDriver works.
 class GreedyRewriteConfig {
 public:
@@ -45,7 +55,8 @@ class GreedyRewriteConfig {
   /// patterns.
   ///
   /// Note: Only applicable when simplifying entire regions.
-  bool enableRegionSimplification = true;
+  GreedySimplifyRegionLevel enableRegionSimplification =
+      GreedySimplifyRegionLevel::Aggressive;
 
   /// This specifies the maximum number of times the rewriter will iterate
   /// between applying patterns and simplifying regions. Use `kNoLimit` to
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 58bd61b2ae8b8..8e4a43c3f2458 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -15,6 +15,7 @@
 #define MLIR_TRANSFORMS_PASSES_H
 
 #include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/LocationSnapshot.h"
 #include "mlir/Transforms/ViewOpGraph.h"
 #include "llvm/Support/Debug.h"
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 1b40a87c63f27..000d9f697618e 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -32,9 +32,17 @@ def Canonicalizer : Pass<"canonicalize"> {
     Option<"topDownProcessingEnabled", "top-down", "bool",
            /*default=*/"true",
            "Seed the worklist in general top-down order">,
-    Option<"enableRegionSimplification", "region-simplify", "bool",
-           /*default=*/"true",
-           "Perform control flow optimizations to the region tree">,
+    Option<"enableRegionSimplification", "region-simplify", "mlir::GreedySimplifyRegionLevel",
+           /*default=*/"mlir::GreedySimplifyRegionLevel::Normal",
+           "Perform control flow optimizations to the region tree",
+             [{::llvm::cl::values(
+               clEnumValN(mlir::GreedySimplifyRegionLevel::Disabled, "disabled",
+                "Don't run any control-flow simplification."),
+               clEnumValN(mlir::GreedySimplifyRegionLevel::Normal, "normal",
+                "Perform simple control-flow simplifications (e.g. dead args elimination)."),
+               clEnumValN(mlir::GreedySimplifyRegionLevel::Aggressive, "aggressive",
+                "Perform aggressive control-flow simplification (e.g. block merging).")
+              )}]>,
     Option<"maxIterations", "max-iterations", "int64_t",
            /*default=*/"10",
            "Max. iterations between applying patterns / simplifying regions">,
diff --git a/mlir/include/mlir/Transforms/RegionUtils.h b/mlir/include/mlir/Transforms/RegionUtils.h
index 192ff71384059..86b22839f6335 100644
--- a/mlir/include/mlir/Transforms/RegionUtils.h
+++ b/mlir/include/mlir/Transforms/RegionUtils.h
@@ -74,8 +74,11 @@ SmallVector<Value> makeRegionIsolatedFromAbove(
 /// elimination, as well as some other DCE. This function returns success if any
 /// of the regions were simplified, failure otherwise. The provided rewriter is
 /// used to notify callers of operation and block deletion.
+/// Structurally similar blocks will be merged if the `mergeBlock` argument is
+/// true. Note this can lead to merged blocks with extra arguments.
 LogicalResult simplifyRegions(RewriterBase &rewriter,
-                              MutableArrayRef<Region> regions);
+                              MutableArrayRef<Region> regions,
+                              bool mergeBlocks = true);
 
 /// Erase the unreachable blocks within the provided regions. Returns success
 /// if any blocks were erased, failure otherwise.
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index cfd4f9c03aaff..d22b3d3672425 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -871,8 +871,13 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
 
           // After applying patterns, make sure that the CFG of each of the
           // regions is kept up to date.
-          if (config.enableRegionSimplification)
-            continueRewrites |= succeeded(simplifyRegions(*this, region));
+          if (config.enableRegionSimplification !=
+              GreedySimplifyRegionLevel::Disabled) {
+            continueRewrites |= succeeded(simplifyRegions(
+                *this, region,
+                /*mergeBlocks=*/config.enableRegionSimplification ==
+                    GreedySimplifyRegionLevel::Aggressive));
+          }
         },
         {&region}, iteration);
   } while (continueRewrites);
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index e25867b527b71..a1bebc4809c45 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -828,11 +828,13 @@ static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
 /// elimination, as well as some other DCE. This function returns success if any
 /// of the regions were simplified, failure otherwise.
 LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
-                                    MutableArrayRef<Region> regions) {
+                                    MutableArrayRef<Region> regions,
+                                    bool mergeBlocks) {
   bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
   bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
-  bool mergedIdenticalBlocks =
-      succeeded(mergeIdenticalBlocks(rewriter, regions));
+  bool mergedIdenticalBlocks = false;
+  if (mergeBlocks)
+    mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions));
   return success(eliminatedBlocks || eliminatedOpsOrArgs ||
                  mergedIdenticalBlocks);
 }
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 977d31a6bfe54..d07389d6822ce 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize{test-convergence}))' | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize{test-convergence region-simplify=aggressive}))' | FileCheck %s
 
 //===----------------------------------------------------------------------===//
 // spirv.AccessChain
diff --git a/mlir/test/Pass/run-reproducer.mlir b/mlir/test/Pass/run-reproducer.mlir
index 57a58dbaa5b96..bf3ab2dae2ff8 100644
--- a/mlir/test/Pass/run-reproducer.mlir
+++ b/mlir/test/Pass/run-reproducer.mlir
@@ -14,8 +14,8 @@ func.func @bar() {
   external_resources: {
     mlir_reproducer: {
       verify_each: true,
-      // CHECK:  builtin.module(func.func(cse,canonicalize{ max-iterations=1 max-num-rewrites=-1 region-simplify=false test-convergence=false top-down=false}))
-      pipeline: "builtin.module(func.func(cse,canonicalize{max-iterations=1 max-num-rewrites=-1 region-simplify=false top-down=false}))",
+      // CHECK:  builtin.module(func.func(cse,canonicalize{ max-iterations=1 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=false}))
+      pipeline: "builtin.module(func.func(cse,canonicalize{max-iterations=1 max-num-rewrites=-1 region-simplify=normal top-down=false}))",
       disable_threading: true
     }
   }
diff --git a/mlir/test/Transforms/canonicalize-block-merge.mlir b/mlir/test/Transforms/canonicalize-block-merge.mlir
index bf44973ab646c..122bfcca66a63 100644
--- a/mlir/test/Transforms/canonicalize-block-merge.mlir
+++ b/mlir/test/Transforms/canonicalize-block-merge.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(canonicalize))' -split-input-file | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(canonicalize{region-simplify=aggressive}))' -split-input-file | FileCheck %s
 
 // Check the simple case of single operation blocks with a return.
 
@@ -275,3 +275,19 @@ func.func @mismatch_dominance() -> i32 {
 ^bb4(%3: i32):
   return %3 : i32
 }
+
+
+// CHECK-LABEL: func @dead_dealloc_fold_multi_use
+func.func @dead_dealloc_fold_multi_use(%cond : i1) {
+  // CHECK-NEXT: return
+  %a = memref.alloc() : memref<4xf32>
+  cf.cond_br %cond, ^bb1, ^bb2
+
+^bb1:
+  memref.dealloc %a: memref<4xf32>
+  return
+
+^bb2:
+  memref.dealloc %a: memref<4xf32>
+  return
+}
diff --git a/mlir/test/Transforms/canonicalize-dce.mlir b/mlir/test/Transforms/canonicalize-dce.mlir
index 3048a7fed636b..ac034d567a26a 100644
--- a/mlir/test/Transforms/canonicalize-dce.mlir
+++ b/mlir/test/Transforms/canonicalize-dce.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize{region-simplify=aggressive}))' | FileCheck %s
 
 // Test case: Simple case of deleting a dead pure op.
 
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index d2c2c12d32389..6927189fc626f 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -387,16 +387,21 @@ func.func @dead_dealloc_fold() {
 
 // CHECK-LABEL: func @dead_dealloc_fold_multi_use
 func.func @dead_dealloc_fold_multi_use(%cond : i1) {
-  // CHECK-NEXT: return
+  // CHECK-NOT: alloc
   %a = memref.alloc() : memref<4xf32>
+  // CHECK: cond_br
   cf.cond_br %cond, ^bb1, ^bb2
 
 ^bb1:
+  // CHECK-NOT: alloc
   memref.dealloc %a: memref<4xf32>
+  // CHECK: return
   return
 
 ^bb2:
+  // CHECK-NOT: alloc
   memref.dealloc %a: memref<4xf32>
+  // CHECK: return
   return
 }
 
diff --git a/mlir/test/Transforms/test-canonicalize.mlir b/mlir/test/Transforms/test-canonicalize.mlir
index 4f0095ed7e8cf..0fc822b0a23ae 100644
--- a/mlir/test/Transforms/test-canonicalize.mlir
+++ b/mlir/test/Transforms/test-canonicalize.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s
-// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{region-simplify=false}))' | FileCheck %s --check-prefixes=CHECK,NO-RS
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{region-simplify=disabled}))' | FileCheck %s --check-prefixes=CHECK,NO-RS
 
 // CHECK-LABEL: func @remove_op_with_inner_ops_pattern
 func.func @remove_op_with_inner_ops_pattern() {
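
The way the three levels map onto the driver's behavior in the diff above can be sketched as a standalone snippet. The enum mirrors `GreedySimplifyRegionLevel` from the diff; `RegionSimplifyPlan`, `planFor`, and `parseLevel` are hypothetical helpers introduced only for illustration and are not part of MLIR.

```cpp
#include <optional>
#include <string_view>

// Mirrors the enum added to GreedyPatternRewriteDriver.h in this PR.
enum class GreedySimplifyRegionLevel { Disabled, Normal, Aggressive };

// Hypothetical summary of what the driver does for a given level.
struct RegionSimplifyPlan {
  bool runSimplifyRegions; // call simplifyRegions at all
  bool mergeBlocks;        // value passed as the mergeBlocks argument
};

// Same two comparisons as in RegionPatternRewriteDriver::simplify.
inline RegionSimplifyPlan planFor(GreedySimplifyRegionLevel level) {
  return {level != GreedySimplifyRegionLevel::Disabled,
          level == GreedySimplifyRegionLevel::Aggressive};
}

// Hypothetical parser for the three spellings registered in Passes.td.
// Anything else, including the old boolean spellings, yields std::nullopt
// in this sketch.
inline std::optional<GreedySimplifyRegionLevel>
parseLevel(std::string_view text) {
  if (text == "disabled")
    return GreedySimplifyRegionLevel::Disabled;
  if (text == "normal")
    return GreedySimplifyRegionLevel::Normal;
  if (text == "aggressive")
    return GreedySimplifyRegionLevel::Aggressive;
  return std::nullopt;
}
```

Note that, per the diff, `GreedyRewriteConfig::enableRegionSimplification` still defaults to `Aggressive`, while the `region-simplify` option of the canonicalize pass defaults to `Normal`.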

Contributor

@Mogball Mogball left a comment

yeah I noticed the explosion behaviour in #63230

Contributor

@bondhugula bondhugula left a comment

This looks good to me. Minor comments.

@@ -29,6 +29,16 @@ enum class GreedyRewriteStrictness {
ExistingOps
};

enum class GreedySimplifyRegionLevel {
Contributor

Doc comment for this.

Comment on lines +834 to +836
  bool mergedIdenticalBlocks = false;
  if (mergeBlocks)
    mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions));
Contributor

Can be written in one line:

bool mergedIdenticalBlocks = mergeBlocks && succeeded(...);

Collaborator Author

I don't really like the short circuit with side effects: this would obscure the control flow IMO.
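
For readers following the exchange, the two forms under discussion behave the same; the disagreement is about readability. A small standalone sketch, where `mergeWithExplicitBranch` and `mergeWithShortCircuit` are hypothetical names and the `merge` callback stands in for the side-effecting `succeeded(mergeIdenticalBlocks(...))` call:

```cpp
#include <functional>

// Form kept in the PR: an explicit branch around the side-effecting call.
bool mergeWithExplicitBranch(bool mergeBlocks,
                             const std::function<bool()> &merge) {
  bool merged = false;
  if (mergeBlocks)
    merged = merge();
  return merged;
}

// Form suggested in review: rely on && short-circuiting, so merge() only
// runs when mergeBlocks is true.
bool mergeWithShortCircuit(bool mergeBlocks,
                           const std::function<bool()> &merge) {
  return mergeBlocks && merge();
}
```

In both versions the callback is invoked only when `mergeBlocks` is true; the explicit branch simply makes it more visible that the call mutates the IR.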

@@ -275,3 +275,19 @@ func.func @mismatch_dominance() -> i32 {
^bb4(%3: i32):
return %3 : i32
}

Contributor

Drop extra line.

Contributor

@giuseros giuseros left a comment

Thanks for the PR, I had one ready to introduce a numBlocksArgs threshold. I am not sure which one is better, but my idea was that reducing the number of blocks, up to a point, was something useful (especially on GPUs where branches are so bad)

@joker-eph
Collaborator Author

I had one ready to introduce a numBlocksArgs threshold. I am not sure which one is better, but my idea was that reducing the number of blocks, up to a point, was something useful (especially on GPUs where branches are so bad)

What is the problem you're seeing with the number of block arguments? (in terms of code generation they are lowered to "nothing", basically just a register constraint, worst case a register copy)

@giuseros
Contributor

giuseros commented Jun 11, 2024

What is the problem you're seeing with the number of block arguments? (in terms of code generation they are lowered to "nothing", basically just a register constraint, worst case a register copy)

The problem is that blocks with >32k arguments are very slow to manage and hang the compiler (with 32k arguments it takes a few minutes, but the original IR I had never finished compiling).

This PR fixes the problem, but I thought that partially enabling the merger would give better codegen (especially on GPUs, where branching is so bad).
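
The threshold idea mentioned above was not posted in this thread, so the following is only a guess at its general shape: accept a merge only when it would add at most a fixed number of block arguments. `MergeCandidate`, `countAcceptedMerges`, and `maxNewBlockArgs` are hypothetical names chosen for this sketch.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical record for one potential merge: how many block arguments
// the merged block would gain.
struct MergeCandidate {
  std::size_t newBlockArgs;
};

// Counts the candidates that stay within the per-merge argument budget.
// A real pass would perform the merge instead of counting.
std::size_t countAcceptedMerges(const std::vector<MergeCandidate> &candidates,
                                std::size_t maxNewBlockArgs) {
  std::size_t accepted = 0;
  for (const MergeCandidate &candidate : candidates)
    if (candidate.newBlockArgs <= maxNewBlockArgs)
      ++accepted;
  return accepted;
}
```

With a budget of 16, for example, merges that need 0 or 3 new arguments would go ahead, while one that needs tens of thousands would be skipped.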

@joker-eph
Collaborator Author

The problem is that blocks with >32k arguments are very slow to manage and hang the compiler (with 32k arguments it takes a few minutes, but the original IR I had never finished compiling).

It'd be nice to check if we have some obvious issues with our block argument scaling! Do you have some IR reproducer you could share?

@giuseros
Contributor

giuseros commented Jun 12, 2024

It'd be nice to check if we have some obvious issues with our block argument scaling! Do you have some IR reproducer you could share?

Yes, I am attaching the IRs I used to investigate this issue to this comment.

If you do:

$ ./bin/mlir-opt --canonicalize disable_simplify.mlir

The compiler will hang for about 4 seconds and then complete (this is the 32k parameters case I was mentioning). But if you run:

$ ./bin/mlir-opt --canonicalize disable_simplify_hang.mlir

The compiler will hang for a long time (I never got it to finish). If you pass region-simplify=false it will finish almost immediately.

disable_simplify.zip

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Jun 12, 2024

github-actions bot commented Jun 12, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@giuseros
Contributor

Hi @joker-eph , should we merge this in?

@joker-eph
Collaborator Author

I was hoping to hear from @jpienaar actually :)

Member

@jpienaar jpienaar left a comment

This makes sense to me. The aggressive configuration could get a param later (if needed) but almost feels like it could be a separate pass too. Not a blocker here though.

This keeps the current behavior available while also allowing folks (probably the majority, but we'll probably only find out post-commit if I'm wrong :-)) to avoid the cost.

@joker-eph joker-eph merged commit a506279 into llvm:main Jun 14, 2024
7 checks passed
@joker-eph joker-eph deleted the block-merge branch June 14, 2024 20:39
ThomasRaoux pushed a commit to triton-lang/triton that referenced this pull request Jun 17, 2024
Upgrading LLVM repo again, because we need a feature that has been
recently submitted in llvm/llvm-project#95057

Changes made:
- MathExtras has been merged with its LLVM version, so I had to replace
mlir::ceilDiv with llvm::divideCeilSigned
chsigg pushed a commit to triton-lang/triton that referenced this pull request Jun 18, 2024
antiagainst pushed a commit to triton-lang/triton that referenced this pull request Jun 19, 2024
@youngar youngar mentioned this pull request Jun 24, 2024
bertmaher pushed a commit to bertmaher/triton that referenced this pull request Dec 10, 2024
Labels
flang:fir-hlfir flang Flang issues not falling into any other category mlir:core MLIR Core Infrastructure mlir:spirv mlir
6 participants