-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir] Do not merge blocks during canonicalization by default #95057
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-flang-fir-hlfir @llvm/pr-subscribers-mlir-core Author: Mehdi Amini (joker-eph) ChangesThis is a heavy process, and it can trigger a massive explosion in adding block arguments. While potentially reducing the code size, the resulting merged blocks with arguments are hiding some of the def-use chain and can even hinder some further analyses/optimizations: a merge block does not have it's own path-sensitive context, instead the context is merged from all the predecessors. Previous behavior can be restored by passing: {test-convergence region-simplify=aggressive} to the canonicalize pass. Full diff: https://github.com/llvm/llvm-project/pull/95057.diff 12 Files Affected:
diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index 763146aac15b9..eaff85804f6b3 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -29,6 +29,16 @@ enum class GreedyRewriteStrictness {
ExistingOps
};
+enum class GreedySimplifyRegionLevel {
+ /// Disable region control-flow simplification.
+ Disabled,
+ /// Run the normal simplification (e.g. dead args elimination).
+ Normal,
+ /// Run extra simplificiations (e.g. block merging), these can be
+ /// more costly or have some tradeoffs associated.
+ Aggressive
+};
+
/// This class allows control over how the GreedyPatternRewriteDriver works.
class GreedyRewriteConfig {
public:
@@ -45,7 +55,8 @@ class GreedyRewriteConfig {
/// patterns.
///
/// Note: Only applicable when simplifying entire regions.
- bool enableRegionSimplification = true;
+ GreedySimplifyRegionLevel enableRegionSimplification =
+ GreedySimplifyRegionLevel::Aggressive;
/// This specifies the maximum number of times the rewriter will iterate
/// between applying patterns and simplifying regions. Use `kNoLimit` to
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 58bd61b2ae8b8..8e4a43c3f2458 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -15,6 +15,7 @@
#define MLIR_TRANSFORMS_PASSES_H
#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LocationSnapshot.h"
#include "mlir/Transforms/ViewOpGraph.h"
#include "llvm/Support/Debug.h"
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 1b40a87c63f27..000d9f697618e 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -32,9 +32,17 @@ def Canonicalizer : Pass<"canonicalize"> {
Option<"topDownProcessingEnabled", "top-down", "bool",
/*default=*/"true",
"Seed the worklist in general top-down order">,
- Option<"enableRegionSimplification", "region-simplify", "bool",
- /*default=*/"true",
- "Perform control flow optimizations to the region tree">,
+ Option<"enableRegionSimplification", "region-simplify", "mlir::GreedySimplifyRegionLevel",
+ /*default=*/"mlir::GreedySimplifyRegionLevel::Normal",
+ "Perform control flow optimizations to the region tree",
+ [{::llvm::cl::values(
+ clEnumValN(mlir::GreedySimplifyRegionLevel::Disabled, "disabled",
+ "Don't run any control-flow simplification."),
+ clEnumValN(mlir::GreedySimplifyRegionLevel::Normal, "normal",
+ "Perform simple control-flow simplifications (e.g. dead args elimination)."),
+ clEnumValN(mlir::GreedySimplifyRegionLevel::Aggressive, "aggressive",
+ "Perform aggressive control-flow simplification (e.g. block merging).")
+ )}]>,
Option<"maxIterations", "max-iterations", "int64_t",
/*default=*/"10",
"Max. iterations between applying patterns / simplifying regions">,
diff --git a/mlir/include/mlir/Transforms/RegionUtils.h b/mlir/include/mlir/Transforms/RegionUtils.h
index 192ff71384059..86b22839f6335 100644
--- a/mlir/include/mlir/Transforms/RegionUtils.h
+++ b/mlir/include/mlir/Transforms/RegionUtils.h
@@ -74,8 +74,11 @@ SmallVector<Value> makeRegionIsolatedFromAbove(
/// elimination, as well as some other DCE. This function returns success if any
/// of the regions were simplified, failure otherwise. The provided rewriter is
/// used to notify callers of operation and block deletion.
+/// Structurally similar blocks will be merged if the `mergeBlock` argument is
+/// true. Note this can lead to merged blocks with extra arguments.
LogicalResult simplifyRegions(RewriterBase &rewriter,
- MutableArrayRef<Region> regions);
+ MutableArrayRef<Region> regions,
+ bool mergeBlocks = true);
/// Erase the unreachable blocks within the provided regions. Returns success
/// if any blocks were erased, failure otherwise.
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index cfd4f9c03aaff..d22b3d3672425 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -871,8 +871,13 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
// After applying patterns, make sure that the CFG of each of the
// regions is kept up to date.
- if (config.enableRegionSimplification)
- continueRewrites |= succeeded(simplifyRegions(*this, region));
+ if (config.enableRegionSimplification !=
+ GreedySimplifyRegionLevel::Disabled) {
+ continueRewrites |= succeeded(simplifyRegions(
+ *this, region,
+ /*mergeBlocks=*/config.enableRegionSimplification ==
+ GreedySimplifyRegionLevel::Aggressive));
+ }
},
{®ion}, iteration);
} while (continueRewrites);
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index e25867b527b71..a1bebc4809c45 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -828,11 +828,13 @@ static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
/// elimination, as well as some other DCE. This function returns success if any
/// of the regions were simplified, failure otherwise.
LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
- MutableArrayRef<Region> regions) {
+ MutableArrayRef<Region> regions,
+ bool mergeBlocks) {
bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
- bool mergedIdenticalBlocks =
- succeeded(mergeIdenticalBlocks(rewriter, regions));
+ bool mergedIdenticalBlocks = false;
+ if (mergeBlocks)
+ mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions));
return success(eliminatedBlocks || eliminatedOpsOrArgs ||
mergedIdenticalBlocks);
}
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 977d31a6bfe54..d07389d6822ce 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize{test-convergence}))' | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize{test-convergence region-simplify=aggressive}))' | FileCheck %s
//===----------------------------------------------------------------------===//
// spirv.AccessChain
diff --git a/mlir/test/Pass/run-reproducer.mlir b/mlir/test/Pass/run-reproducer.mlir
index 57a58dbaa5b96..bf3ab2dae2ff8 100644
--- a/mlir/test/Pass/run-reproducer.mlir
+++ b/mlir/test/Pass/run-reproducer.mlir
@@ -14,8 +14,8 @@ func.func @bar() {
external_resources: {
mlir_reproducer: {
verify_each: true,
- // CHECK: builtin.module(func.func(cse,canonicalize{ max-iterations=1 max-num-rewrites=-1 region-simplify=false test-convergence=false top-down=false}))
- pipeline: "builtin.module(func.func(cse,canonicalize{max-iterations=1 max-num-rewrites=-1 region-simplify=false top-down=false}))",
+ // CHECK: builtin.module(func.func(cse,canonicalize{ max-iterations=1 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=false}))
+ pipeline: "builtin.module(func.func(cse,canonicalize{max-iterations=1 max-num-rewrites=-1 region-simplify=normal top-down=false}))",
disable_threading: true
}
}
diff --git a/mlir/test/Transforms/canonicalize-block-merge.mlir b/mlir/test/Transforms/canonicalize-block-merge.mlir
index bf44973ab646c..122bfcca66a63 100644
--- a/mlir/test/Transforms/canonicalize-block-merge.mlir
+++ b/mlir/test/Transforms/canonicalize-block-merge.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(canonicalize))' -split-input-file | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(canonicalize{region-simplify=aggressive}))' -split-input-file | FileCheck %s
// Check the simple case of single operation blocks with a return.
@@ -275,3 +275,19 @@ func.func @mismatch_dominance() -> i32 {
^bb4(%3: i32):
return %3 : i32
}
+
+
+// CHECK-LABEL: func @dead_dealloc_fold_multi_use
+func.func @dead_dealloc_fold_multi_use(%cond : i1) {
+ // CHECK-NEXT: return
+ %a = memref.alloc() : memref<4xf32>
+ cf.cond_br %cond, ^bb1, ^bb2
+
+^bb1:
+ memref.dealloc %a: memref<4xf32>
+ return
+
+^bb2:
+ memref.dealloc %a: memref<4xf32>
+ return
+}
diff --git a/mlir/test/Transforms/canonicalize-dce.mlir b/mlir/test/Transforms/canonicalize-dce.mlir
index 3048a7fed636b..ac034d567a26a 100644
--- a/mlir/test/Transforms/canonicalize-dce.mlir
+++ b/mlir/test/Transforms/canonicalize-dce.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize{region-simplify=aggressive}))' | FileCheck %s
// Test case: Simple case of deleting a dead pure op.
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index d2c2c12d32389..6927189fc626f 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -387,16 +387,21 @@ func.func @dead_dealloc_fold() {
// CHECK-LABEL: func @dead_dealloc_fold_multi_use
func.func @dead_dealloc_fold_multi_use(%cond : i1) {
- // CHECK-NEXT: return
+ // CHECK-NOT: alloc
%a = memref.alloc() : memref<4xf32>
+ // CHECK: cond_br
cf.cond_br %cond, ^bb1, ^bb2
^bb1:
+ // CHECK-NOT: alloc
memref.dealloc %a: memref<4xf32>
+ // CHECK: return
return
^bb2:
+ // CHECK-NOT: alloc
memref.dealloc %a: memref<4xf32>
+ // CHECK: return
return
}
diff --git a/mlir/test/Transforms/test-canonicalize.mlir b/mlir/test/Transforms/test-canonicalize.mlir
index 4f0095ed7e8cf..0fc822b0a23ae 100644
--- a/mlir/test/Transforms/test-canonicalize.mlir
+++ b/mlir/test/Transforms/test-canonicalize.mlir
@@ -1,5 +1,5 @@
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s
-// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{region-simplify=false}))' | FileCheck %s --check-prefixes=CHECK,NO-RS
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{region-simplify=disabled}))' | FileCheck %s --check-prefixes=CHECK,NO-RS
// CHECK-LABEL: func @remove_op_with_inner_ops_pattern
func.func @remove_op_with_inner_ops_pattern() {
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah I noticed the explosion behaviour in #63230
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks good to me. Minor comments.
@@ -29,6 +29,16 @@ enum class GreedyRewriteStrictness { | |||
ExistingOps | |||
}; | |||
|
|||
enum class GreedySimplifyRegionLevel { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Doc comment for this.
bool mergedIdenticalBlocks = false; | ||
if (mergeBlocks) | ||
mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can be written in one line:
bool mergedIdenticalBlocks = mergeBlocks && succeeded(...);
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't really like the short circuit with side effects: this would obscure the control flow IMO.
@@ -275,3 +275,19 @@ func.func @mismatch_dominance() -> i32 { | |||
^bb4(%3: i32): | |||
return %3 : i32 | |||
} | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Drop extra line.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR, I had one ready to introduce a numBlocksArgs
threshold. I am not sure which one is better, but my idea was that reducing the number of blocks, up to a point, was something useful (especially on GPUs where branches are so bad)
What is the problem you're seeing with the number of block arguments? (in terms of code generation they are lowered to "nothing", basically just a register constraint, worst case a register copy) |
The problem is that blocks with >32k arguments are very slow to manage and hang the compiler (with 32k it takes few minutes, but the original IR I had did never stop compiling). This PR is fixing this problem, but I thought that partially enabling the merger would make a better code-gen (especially on GPUs where branching is so bad) |
It'd be nice to check if we have some obvious issues with our block argument scaling! Do you have some IR reproducer you could share? |
Yes, I am attaching the IRs I used to investigate this issue to this comment. If you do:
The compiler will hang for about 4 seconds and then complete (this is the 32k parameters case I was mentioning). But if you run:
The compiler will hang for a long time (I never got it to finish). If you pass |
✅ With the latest revision this PR passed the C/C++ code formatter. |
This is a heavy process, and it can trigger a massive explosion in adding block arguments. While potentially reducing the code size, the resulting merged blocks with arguments are hiding some of the def-use chain and can even hinder some further analyses/optimizations: a merge block does not have it's own path-sensitive context, instead the context is merged from all the predecessors. Previous behavior can be restored by passing: {test-convergence region-simplify=aggressive} to the canonicalize pass.
Hi @joker-eph , should we merge this in? |
I was hoping to hear from @jpienaar actually :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This makes sense to me. The aggressive configuration could get a param later (if needed) but almost feels like it could be a separate pass too. Not a blocker here though.
This enables the current state while also allowing folks (probably majority but we'll probably only find out in post if I'm wrong :-)) to avoid the cost.
Upgrading LLVM repo again, because we need a feature that has been recently submitted in llvm/llvm-project#95057 Changes made: - MathExtras has been merged with its LLVM version. So I had to replace mlir::ceilDiv with llvm:divideCeilSigned
Upgrading LLVM repo again, because we need a feature that has been recently submitted in llvm/llvm-project#95057 Changes made: - MathExtras has been merged with its LLVM version. So I had to replace mlir::ceilDiv with llvm:divideCeilSigned
Upgrading LLVM repo again, because we need a feature that has been recently submitted in llvm/llvm-project#95057 Changes made: - `MathExtras` has been merged with its LLVM version. So I had to replace `mlir::ceilDiv` with `llvm:divideCeilSigned`
…4147) Upgrading LLVM repo again, because we need a feature that has been recently submitted in llvm/llvm-project#95057 Changes made: - `MathExtras` has been merged with its LLVM version. So I had to replace `mlir::ceilDiv` with `llvm:divideCeilSigned`
This is a heavy process, and it can trigger a massive explosion in adding block arguments. While potentially reducing the code size, the resulting merged blocks with arguments are hiding some of the def-use chain and can even hinder some further analyses/optimizations: a merge block does not have it's own path-sensitive context, instead the context is merged from all the predecessors.
Previous behavior can be restored by passing:
{test-convergence region-simplify=aggressive}
to the canonicalize pass.