Skip to content

Commit a506279

Browse files
authored
[mlir] Do not merge blocks during canonicalization by default (#95057)
This is a heavy process, and it can trigger a massive explosion in adding block arguments. While potentially reducing the code size, the resulting merged blocks with arguments are hiding some of the def-use chain and can even hinder some further analyses/optimizations: a merge block does not have it's own path-sensitive context, instead the context is merged from all the predecessors. Previous behavior can be restored by passing: {test-convergence region-simplify=aggressive} to the canonicalize pass.
1 parent 41fecca commit a506279

File tree

18 files changed

+78
-24
lines changed

18 files changed

+78
-24
lines changed

flang/include/flang/Tools/CLOptions.inc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ namespace fir {
144144
static void addCanonicalizerPassWithoutRegionSimplification(
145145
mlir::OpPassManager &pm) {
146146
mlir::GreedyRewriteConfig config;
147-
config.enableRegionSimplification = false;
147+
config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
148148
pm.addPass(mlir::createCanonicalizerPass(config));
149149
}
150150

@@ -260,7 +260,7 @@ inline void createDefaultFIROptimizerPassPipeline(
260260

261261
// simplify the IR
262262
mlir::GreedyRewriteConfig config;
263-
config.enableRegionSimplification = false;
263+
config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
264264
pm.addPass(mlir::createCSEPass());
265265
fir::addAVC(pm, pc.OptLevel);
266266
addNestedPassToAllTopLevelOperations(pm, fir::createCharacterConversion);

flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,8 @@ class InlineElementalsPass
119119

120120
mlir::GreedyRewriteConfig config;
121121
// Prevent the pattern driver from merging blocks.
122-
config.enableRegionSimplification = false;
122+
config.enableRegionSimplification =
123+
mlir::GreedySimplifyRegionLevel::Disabled;
123124

124125
mlir::RewritePatternSet patterns(context);
125126
patterns.insert<InlineElementalConversion>(context);

flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,8 @@ class LowerHLFIRIntrinsics
486486
// Pattern rewriting only requires that the resulting IR is still valid
487487
mlir::GreedyRewriteConfig config;
488488
// Prevent the pattern driver from merging blocks
489-
config.enableRegionSimplification = false;
489+
config.enableRegionSimplification =
490+
mlir::GreedySimplifyRegionLevel::Disabled;
490491

491492
if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
492493
module, std::move(patterns), config))) {

flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1042,7 +1042,8 @@ class OptimizedBufferizationPass
10421042

10431043
mlir::GreedyRewriteConfig config;
10441044
// Prevent the pattern driver from merging blocks
1045-
config.enableRegionSimplification = false;
1045+
config.enableRegionSimplification =
1046+
mlir::GreedySimplifyRegionLevel::Disabled;
10461047

10471048
mlir::RewritePatternSet patterns(context);
10481049
// TODO: right now the patterns are non-conflicting,

flang/lib/Optimizer/Transforms/AssumedRankOpConversion.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,8 @@ class AssumedRankOpConversion
152152
patterns.insert<ReboxAssumedRankConv>(context, &symbolTable, kindMap);
153153
patterns.insert<IsAssumedSizeConv>(context, &symbolTable, kindMap);
154154
mlir::GreedyRewriteConfig config;
155-
config.enableRegionSimplification = false;
155+
config.enableRegionSimplification =
156+
mlir::GreedySimplifyRegionLevel::Disabled;
156157
(void)applyPatternsAndFoldGreedily(mod, std::move(patterns), config);
157158
}
158159
};

flang/lib/Optimizer/Transforms/StackArrays.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -767,7 +767,7 @@ void StackArraysPass::runOnFunc(mlir::Operation *func) {
767767
mlir::RewritePatternSet patterns(&context);
768768
mlir::GreedyRewriteConfig config;
769769
// prevent the pattern driver form merging blocks
770-
config.enableRegionSimplification = false;
770+
config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
771771

772772
patterns.insert<AllocMemConversion>(&context, *candidateOps);
773773
if (mlir::failed(mlir::applyOpPatternsAndFold(opsToConvert,

mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,16 @@ enum class GreedyRewriteStrictness {
2929
ExistingOps
3030
};
3131

32+
enum class GreedySimplifyRegionLevel {
33+
/// Disable region control-flow simplification.
34+
Disabled,
35+
/// Run the normal simplification (e.g. dead args elimination).
36+
Normal,
37+
/// Run extra simplificiations (e.g. block merging), these can be
38+
/// more costly or have some tradeoffs associated.
39+
Aggressive
40+
};
41+
3242
/// This class allows control over how the GreedyPatternRewriteDriver works.
3343
class GreedyRewriteConfig {
3444
public:
@@ -45,7 +55,8 @@ class GreedyRewriteConfig {
4555
/// patterns.
4656
///
4757
/// Note: Only applicable when simplifying entire regions.
48-
bool enableRegionSimplification = true;
58+
GreedySimplifyRegionLevel enableRegionSimplification =
59+
GreedySimplifyRegionLevel::Aggressive;
4960

5061
/// This specifies the maximum number of times the rewriter will iterate
5162
/// between applying patterns and simplifying regions. Use `kNoLimit` to

mlir/include/mlir/Transforms/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#define MLIR_TRANSFORMS_PASSES_H
1616

1717
#include "mlir/Pass/Pass.h"
18+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1819
#include "mlir/Transforms/LocationSnapshot.h"
1920
#include "mlir/Transforms/ViewOpGraph.h"
2021
#include "llvm/Support/Debug.h"

mlir/include/mlir/Transforms/Passes.td

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,17 @@ def Canonicalizer : Pass<"canonicalize"> {
3232
Option<"topDownProcessingEnabled", "top-down", "bool",
3333
/*default=*/"true",
3434
"Seed the worklist in general top-down order">,
35-
Option<"enableRegionSimplification", "region-simplify", "bool",
36-
/*default=*/"true",
37-
"Perform control flow optimizations to the region tree">,
35+
Option<"enableRegionSimplification", "region-simplify", "mlir::GreedySimplifyRegionLevel",
36+
/*default=*/"mlir::GreedySimplifyRegionLevel::Normal",
37+
"Perform control flow optimizations to the region tree",
38+
[{::llvm::cl::values(
39+
clEnumValN(mlir::GreedySimplifyRegionLevel::Disabled, "disabled",
40+
"Don't run any control-flow simplification."),
41+
clEnumValN(mlir::GreedySimplifyRegionLevel::Normal, "normal",
42+
"Perform simple control-flow simplifications (e.g. dead args elimination)."),
43+
clEnumValN(mlir::GreedySimplifyRegionLevel::Aggressive, "aggressive",
44+
"Perform aggressive control-flow simplification (e.g. block merging).")
45+
)}]>,
3846
Option<"maxIterations", "max-iterations", "int64_t",
3947
/*default=*/"10",
4048
"Max. iterations between applying patterns / simplifying regions">,

mlir/include/mlir/Transforms/RegionUtils.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,11 @@ SmallVector<Value> makeRegionIsolatedFromAbove(
7474
/// elimination, as well as some other DCE. This function returns success if any
7575
/// of the regions were simplified, failure otherwise. The provided rewriter is
7676
/// used to notify callers of operation and block deletion.
77+
/// Structurally similar blocks will be merged if the `mergeBlock` argument is
78+
/// true. Note this can lead to merged blocks with extra arguments.
7779
LogicalResult simplifyRegions(RewriterBase &rewriter,
78-
MutableArrayRef<Region> regions);
80+
MutableArrayRef<Region> regions,
81+
bool mergeBlocks = true);
7982

8083
/// Erase the unreachable blocks within the provided regions. Returns success
8184
/// if any blocks were erased, failure otherwise.

mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -875,8 +875,13 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
875875

876876
// After applying patterns, make sure that the CFG of each of the
877877
// regions is kept up to date.
878-
if (config.enableRegionSimplification)
879-
continueRewrites |= succeeded(simplifyRegions(rewriter, region));
878+
if (config.enableRegionSimplification !=
879+
GreedySimplifyRegionLevel::Disabled) {
880+
continueRewrites |= succeeded(simplifyRegions(
881+
rewriter, region,
882+
/*mergeBlocks=*/config.enableRegionSimplification ==
883+
GreedySimplifyRegionLevel::Aggressive));
884+
}
880885
},
881886
{&region}, iteration);
882887
} while (continueRewrites);

mlir/lib/Transforms/Utils/RegionUtils.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -827,11 +827,13 @@ static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
827827
/// elimination, as well as some other DCE. This function returns success if any
828828
/// of the regions were simplified, failure otherwise.
829829
LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
830-
MutableArrayRef<Region> regions) {
830+
MutableArrayRef<Region> regions,
831+
bool mergeBlocks) {
831832
bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
832833
bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
833-
bool mergedIdenticalBlocks =
834-
succeeded(mergeIdenticalBlocks(rewriter, regions));
834+
bool mergedIdenticalBlocks = false;
835+
if (mergeBlocks)
836+
mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions));
835837
return success(eliminatedBlocks || eliminatedOpsOrArgs ||
836838
mergedIdenticalBlocks);
837839
}

mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize{test-convergence}))' | FileCheck %s
1+
// RUN: mlir-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize{test-convergence region-simplify=aggressive}))' | FileCheck %s
22

33
//===----------------------------------------------------------------------===//
44
// spirv.AccessChain

mlir/test/Pass/run-reproducer.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ func.func @bar() {
1414
external_resources: {
1515
mlir_reproducer: {
1616
verify_each: true,
17-
// CHECK: builtin.module(func.func(cse,canonicalize{ max-iterations=1 max-num-rewrites=-1 region-simplify=false test-convergence=false top-down=false}))
18-
pipeline: "builtin.module(func.func(cse,canonicalize{max-iterations=1 max-num-rewrites=-1 region-simplify=false top-down=false}))",
17+
// CHECK: builtin.module(func.func(cse,canonicalize{ max-iterations=1 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=false}))
18+
pipeline: "builtin.module(func.func(cse,canonicalize{max-iterations=1 max-num-rewrites=-1 region-simplify=normal top-down=false}))",
1919
disable_threading: true
2020
}
2121
}

mlir/test/Transforms/canonicalize-block-merge.mlir

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(canonicalize))' -split-input-file | FileCheck %s
1+
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(canonicalize{region-simplify=aggressive}))' -split-input-file | FileCheck %s
22

33
// Check the simple case of single operation blocks with a return.
44

@@ -275,3 +275,18 @@ func.func @mismatch_dominance() -> i32 {
275275
^bb4(%3: i32):
276276
return %3 : i32
277277
}
278+
279+
// CHECK-LABEL: func @dead_dealloc_fold_multi_use
280+
func.func @dead_dealloc_fold_multi_use(%cond : i1) {
281+
// CHECK-NEXT: return
282+
%a = memref.alloc() : memref<4xf32>
283+
cf.cond_br %cond, ^bb1, ^bb2
284+
285+
^bb1:
286+
memref.dealloc %a: memref<4xf32>
287+
return
288+
289+
^bb2:
290+
memref.dealloc %a: memref<4xf32>
291+
return
292+
}

mlir/test/Transforms/canonicalize-dce.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s
1+
// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize{region-simplify=aggressive}))' | FileCheck %s
22

33
// Test case: Simple case of deleting a dead pure op.
44

mlir/test/Transforms/canonicalize.mlir

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -387,16 +387,21 @@ func.func @dead_dealloc_fold() {
387387

388388
// CHECK-LABEL: func @dead_dealloc_fold_multi_use
389389
func.func @dead_dealloc_fold_multi_use(%cond : i1) {
390-
// CHECK-NEXT: return
390+
// CHECK-NOT: alloc
391391
%a = memref.alloc() : memref<4xf32>
392+
// CHECK: cond_br
392393
cf.cond_br %cond, ^bb1, ^bb2
393394

394395
^bb1:
396+
// CHECK-NOT: alloc
395397
memref.dealloc %a: memref<4xf32>
398+
// CHECK: return
396399
return
397400

398401
^bb2:
402+
// CHECK-NOT: alloc
399403
memref.dealloc %a: memref<4xf32>
404+
// CHECK: return
400405
return
401406
}
402407

mlir/test/Transforms/test-canonicalize.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s
2-
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{region-simplify=false}))' | FileCheck %s --check-prefixes=CHECK,NO-RS
2+
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{region-simplify=disabled}))' | FileCheck %s --check-prefixes=CHECK,NO-RS
33

44
// CHECK-LABEL: func @remove_op_with_inner_ops_pattern
55
func.func @remove_op_with_inner_ops_pattern() {

0 commit comments

Comments
 (0)