Skip to content

[mlir] Do not merge blocks during canonicalization by default #95057

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions flang/include/flang/Tools/CLOptions.inc
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ namespace fir {
static void addCanonicalizerPassWithoutRegionSimplification(
mlir::OpPassManager &pm) {
mlir::GreedyRewriteConfig config;
config.enableRegionSimplification = false;
config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
pm.addPass(mlir::createCanonicalizerPass(config));
}

Expand Down Expand Up @@ -260,7 +260,7 @@ inline void createDefaultFIROptimizerPassPipeline(

// simplify the IR
mlir::GreedyRewriteConfig config;
config.enableRegionSimplification = false;
config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
pm.addPass(mlir::createCSEPass());
fir::addAVC(pm, pc.OptLevel);
addNestedPassToAllTopLevelOperations(pm, fir::createCharacterConversion);
Expand Down
3 changes: 2 additions & 1 deletion flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ class InlineElementalsPass

mlir::GreedyRewriteConfig config;
// Prevent the pattern driver from merging blocks.
config.enableRegionSimplification = false;
config.enableRegionSimplification =
mlir::GreedySimplifyRegionLevel::Disabled;

mlir::RewritePatternSet patterns(context);
patterns.insert<InlineElementalConversion>(context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,8 @@ class LowerHLFIRIntrinsics
// Pattern rewriting only requires that the resulting IR is still valid
mlir::GreedyRewriteConfig config;
// Prevent the pattern driver from merging blocks
config.enableRegionSimplification = false;
config.enableRegionSimplification =
mlir::GreedySimplifyRegionLevel::Disabled;

if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
module, std::move(patterns), config))) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1042,7 +1042,8 @@ class OptimizedBufferizationPass

mlir::GreedyRewriteConfig config;
// Prevent the pattern driver from merging blocks
config.enableRegionSimplification = false;
config.enableRegionSimplification =
mlir::GreedySimplifyRegionLevel::Disabled;

mlir::RewritePatternSet patterns(context);
// TODO: right now the patterns are non-conflicting,
Expand Down
3 changes: 2 additions & 1 deletion flang/lib/Optimizer/Transforms/AssumedRankOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,8 @@ class AssumedRankOpConversion
patterns.insert<ReboxAssumedRankConv>(context, &symbolTable, kindMap);
patterns.insert<IsAssumedSizeConv>(context, &symbolTable, kindMap);
mlir::GreedyRewriteConfig config;
config.enableRegionSimplification = false;
config.enableRegionSimplification =
mlir::GreedySimplifyRegionLevel::Disabled;
(void)applyPatternsAndFoldGreedily(mod, std::move(patterns), config);
}
};
Expand Down
2 changes: 1 addition & 1 deletion flang/lib/Optimizer/Transforms/StackArrays.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -767,7 +767,7 @@ void StackArraysPass::runOnFunc(mlir::Operation *func) {
mlir::RewritePatternSet patterns(&context);
mlir::GreedyRewriteConfig config;
// prevent the pattern driver form merging blocks
config.enableRegionSimplification = false;
config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;

patterns.insert<AllocMemConversion>(&context, *candidateOps);
if (mlir::failed(mlir::applyOpPatternsAndFold(opsToConvert,
Expand Down
13 changes: 12 additions & 1 deletion mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,16 @@ enum class GreedyRewriteStrictness {
ExistingOps
};

enum class GreedySimplifyRegionLevel {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doc comment for this.

/// Disable region control-flow simplification.
Disabled,
/// Run the normal simplification (e.g. dead args elimination).
Normal,
/// Run extra simplificiations (e.g. block merging), these can be
/// more costly or have some tradeoffs associated.
Aggressive
};

/// This class allows control over how the GreedyPatternRewriteDriver works.
class GreedyRewriteConfig {
public:
Expand All @@ -45,7 +55,8 @@ class GreedyRewriteConfig {
/// patterns.
///
/// Note: Only applicable when simplifying entire regions.
bool enableRegionSimplification = true;
GreedySimplifyRegionLevel enableRegionSimplification =
GreedySimplifyRegionLevel::Aggressive;

/// This specifies the maximum number of times the rewriter will iterate
/// between applying patterns and simplifying regions. Use `kNoLimit` to
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#define MLIR_TRANSFORMS_PASSES_H

#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LocationSnapshot.h"
#include "mlir/Transforms/ViewOpGraph.h"
#include "llvm/Support/Debug.h"
Expand Down
14 changes: 11 additions & 3 deletions mlir/include/mlir/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,17 @@ def Canonicalizer : Pass<"canonicalize"> {
Option<"topDownProcessingEnabled", "top-down", "bool",
/*default=*/"true",
"Seed the worklist in general top-down order">,
Option<"enableRegionSimplification", "region-simplify", "bool",
/*default=*/"true",
"Perform control flow optimizations to the region tree">,
Option<"enableRegionSimplification", "region-simplify", "mlir::GreedySimplifyRegionLevel",
/*default=*/"mlir::GreedySimplifyRegionLevel::Normal",
"Perform control flow optimizations to the region tree",
[{::llvm::cl::values(
clEnumValN(mlir::GreedySimplifyRegionLevel::Disabled, "disabled",
"Don't run any control-flow simplification."),
clEnumValN(mlir::GreedySimplifyRegionLevel::Normal, "normal",
"Perform simple control-flow simplifications (e.g. dead args elimination)."),
clEnumValN(mlir::GreedySimplifyRegionLevel::Aggressive, "aggressive",
"Perform aggressive control-flow simplification (e.g. block merging).")
)}]>,
Option<"maxIterations", "max-iterations", "int64_t",
/*default=*/"10",
"Max. iterations between applying patterns / simplifying regions">,
Expand Down
5 changes: 4 additions & 1 deletion mlir/include/mlir/Transforms/RegionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,11 @@ SmallVector<Value> makeRegionIsolatedFromAbove(
/// elimination, as well as some other DCE. This function returns success if any
/// of the regions were simplified, failure otherwise. The provided rewriter is
/// used to notify callers of operation and block deletion.
/// Structurally similar blocks will be merged if the `mergeBlock` argument is
/// true. Note this can lead to merged blocks with extra arguments.
LogicalResult simplifyRegions(RewriterBase &rewriter,
MutableArrayRef<Region> regions);
MutableArrayRef<Region> regions,
bool mergeBlocks = true);

/// Erase the unreachable blocks within the provided regions. Returns success
/// if any blocks were erased, failure otherwise.
Expand Down
9 changes: 7 additions & 2 deletions mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -875,8 +875,13 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {

// After applying patterns, make sure that the CFG of each of the
// regions is kept up to date.
if (config.enableRegionSimplification)
continueRewrites |= succeeded(simplifyRegions(rewriter, region));
if (config.enableRegionSimplification !=
GreedySimplifyRegionLevel::Disabled) {
continueRewrites |= succeeded(simplifyRegions(
rewriter, region,
/*mergeBlocks=*/config.enableRegionSimplification ==
GreedySimplifyRegionLevel::Aggressive));
}
},
{&region}, iteration);
} while (continueRewrites);
Expand Down
8 changes: 5 additions & 3 deletions mlir/lib/Transforms/Utils/RegionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -827,11 +827,13 @@ static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
/// elimination, as well as some other DCE. This function returns success if any
/// of the regions were simplified, failure otherwise.
LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
MutableArrayRef<Region> regions) {
MutableArrayRef<Region> regions,
bool mergeBlocks) {
bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
bool mergedIdenticalBlocks =
succeeded(mergeIdenticalBlocks(rewriter, regions));
bool mergedIdenticalBlocks = false;
if (mergeBlocks)
mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions));
Comment on lines +834 to +836
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can be written in one line:

bool mergedIdenticalBlocks = mergeBlocks && succeeded(...);

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really like the short circuit with side effects: this would obscure the control flow IMO.

return success(eliminatedBlocks || eliminatedOpsOrArgs ||
mergedIdenticalBlocks);
}
2 changes: 1 addition & 1 deletion mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize{test-convergence}))' | FileCheck %s
// RUN: mlir-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize{test-convergence region-simplify=aggressive}))' | FileCheck %s

//===----------------------------------------------------------------------===//
// spirv.AccessChain
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Pass/run-reproducer.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ func.func @bar() {
external_resources: {
mlir_reproducer: {
verify_each: true,
// CHECK: builtin.module(func.func(cse,canonicalize{ max-iterations=1 max-num-rewrites=-1 region-simplify=false test-convergence=false top-down=false}))
pipeline: "builtin.module(func.func(cse,canonicalize{max-iterations=1 max-num-rewrites=-1 region-simplify=false top-down=false}))",
// CHECK: builtin.module(func.func(cse,canonicalize{ max-iterations=1 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=false}))
pipeline: "builtin.module(func.func(cse,canonicalize{max-iterations=1 max-num-rewrites=-1 region-simplify=normal top-down=false}))",
disable_threading: true
}
}
Expand Down
17 changes: 16 additions & 1 deletion mlir/test/Transforms/canonicalize-block-merge.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(canonicalize))' -split-input-file | FileCheck %s
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(canonicalize{region-simplify=aggressive}))' -split-input-file | FileCheck %s

// Check the simple case of single operation blocks with a return.

Expand Down Expand Up @@ -275,3 +275,18 @@ func.func @mismatch_dominance() -> i32 {
^bb4(%3: i32):
return %3 : i32
}

// CHECK-LABEL: func @dead_dealloc_fold_multi_use
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drop extra line.

func.func @dead_dealloc_fold_multi_use(%cond : i1) {
// CHECK-NEXT: return
%a = memref.alloc() : memref<4xf32>
cf.cond_br %cond, ^bb1, ^bb2

^bb1:
memref.dealloc %a: memref<4xf32>
return

^bb2:
memref.dealloc %a: memref<4xf32>
return
}
2 changes: 1 addition & 1 deletion mlir/test/Transforms/canonicalize-dce.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s
// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize{region-simplify=aggressive}))' | FileCheck %s

// Test case: Simple case of deleting a dead pure op.

Expand Down
7 changes: 6 additions & 1 deletion mlir/test/Transforms/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -387,16 +387,21 @@ func.func @dead_dealloc_fold() {

// CHECK-LABEL: func @dead_dealloc_fold_multi_use
func.func @dead_dealloc_fold_multi_use(%cond : i1) {
// CHECK-NEXT: return
// CHECK-NOT: alloc
%a = memref.alloc() : memref<4xf32>
// CHECK: cond_br
cf.cond_br %cond, ^bb1, ^bb2

^bb1:
// CHECK-NOT: alloc
memref.dealloc %a: memref<4xf32>
// CHECK: return
return

^bb2:
// CHECK-NOT: alloc
memref.dealloc %a: memref<4xf32>
// CHECK: return
return
}

Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Transforms/test-canonicalize.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{region-simplify=false}))' | FileCheck %s --check-prefixes=CHECK,NO-RS
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{region-simplify=disabled}))' | FileCheck %s --check-prefixes=CHECK,NO-RS

// CHECK-LABEL: func @remove_op_with_inner_ops_pattern
func.func @remove_op_with_inner_ops_pattern() {
Expand Down
Loading