Skip to content

[mlir][IR] Trigger notifyOperationRemoved callback for nested ops #66771

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mlir/include/mlir/IR/RegionKindInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ class HasOnlyGraphRegion : public TraitBase<ConcreteType, HasOnlyGraphRegion> {
/// not implement the RegionKindInterface.
bool mayHaveSSADominance(Region &region);

/// Return "true" if the given region may be a graph region without SSA
/// dominance. This function returns "true" in case the owner op is an
/// unregistered op. It returns "false" if it is a registered op that does not
/// implement the RegionKindInterface.
bool mayBeGraphRegion(Region &region);

} // namespace mlir

#include "mlir/IR/RegionKindInterface.h.inc"
Expand Down
9 changes: 3 additions & 6 deletions mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -394,12 +394,9 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {

protected:
void notifyOperationRemoved(Operation *op) override {
// TODO: Walk can be removed when D144193 has landed.
op->walk([&](Operation *op) {
erasedOps.insert(op);
// Erase if present.
toMemrefOps.erase(op);
});
erasedOps.insert(op);
// Erase if present.
toMemrefOps.erase(op);
}

void notifyOperationInserted(Operation *op) override {
Expand Down
72 changes: 68 additions & 4 deletions mlir/lib/IR/PatternMatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Iterators.h"
#include "mlir/IR/RegionKindInterface.h"

using namespace mlir;

Expand Down Expand Up @@ -275,7 +277,7 @@ void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
for (auto it : llvm::zip(op->getResults(), newValues))
replaceAllUsesWith(std::get<0>(it), std::get<1>(it));

// Erase the op.
// Erase op and notify listener.
eraseOp(op);
}

Expand All @@ -295,17 +297,79 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
for (auto it : llvm::zip(op->getResults(), newOp->getResults()))
replaceAllUsesWith(std::get<0>(it), std::get<1>(it));

// Erase the old op.
// Erase op and notify listener.
eraseOp(op);
}

/// This method erases an operation that is known to have no uses. The uses of
/// the given operation *must* be known to be dead.
void RewriterBase::eraseOp(Operation *op) {
assert(op->use_empty() && "expected 'op' to have no uses");
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
auto *rewriteListener = dyn_cast_if_present<Listener>(listener);

// Fast path: If no listener is attached, the op can be dropped in one go.
if (!rewriteListener) {
op->erase();
return;
}

// Helper function that erases a single op.
auto eraseSingleOp = [&](Operation *op) {
#ifndef NDEBUG
// All nested ops should have been erased already.
assert(
llvm::all_of(op->getRegions(), [&](Region &r) { return r.empty(); }) &&
"expected empty regions");
// All users should have been erased already if the op is in a region with
// SSA dominance.
if (!op->use_empty() && op->getParentOp())
assert(mayBeGraphRegion(*op->getParentRegion()) &&
"expected that op has no uses");
#endif // NDEBUG
rewriteListener->notifyOperationRemoved(op);
op->erase();

// Explicitly drop all uses in case the op is in a graph region.
op->dropAllUses();
op->erase();
};

// Nested ops must be erased one-by-one, so that listeners have a consistent
// view of the IR every time a notification is triggered. Users must be
// erased before definitions. I.e., post-order, reverse dominance.
std::function<void(Operation *)> eraseTree = [&](Operation *op) {
// Erase nested ops.
for (Region &r : llvm::reverse(op->getRegions())) {
// Erase all blocks in the right order. Successors should be erased
// before predecessors because successor blocks may use values defined
// in predecessor blocks. A post-order traversal of blocks within a
// region visits successors before predecessors. Repeat the traversal
// until the region is empty. (The block graph could be disconnected.)
while (!r.empty()) {
SmallVector<Block *> erasedBlocks;
for (Block *b : llvm::post_order(&r.front())) {
// Visit ops in reverse order.
for (Operation &op :
llvm::make_early_inc_range(ReverseIterator::makeIterable(*b)))
eraseTree(&op);
// Do not erase the block immediately. This is not supprted by the
// post_order iterator.
erasedBlocks.push_back(b);
}
for (Block *b : erasedBlocks) {
// Explicitly drop all uses in case there is a cycle in the block
// graph.
for (BlockArgument bbArg : b->getArguments())
bbArg.dropAllUses();
b->dropAllUses();
b->erase();
}
}
}
// Then erase the enclosing op.
eraseSingleOp(op);
};

eraseTree(op);
}

void RewriterBase::eraseBlock(Block *block) {
Expand Down
12 changes: 10 additions & 2 deletions mlir/lib/IR/RegionKindInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,17 @@ using namespace mlir;
#include "mlir/IR/RegionKindInterface.cpp.inc"

bool mlir::mayHaveSSADominance(Region &region) {
auto regionKindOp =
dyn_cast_if_present<RegionKindInterface>(region.getParentOp());
auto regionKindOp = dyn_cast<RegionKindInterface>(region.getParentOp());
if (!regionKindOp)
return true;
return regionKindOp.hasSSADominance(region.getRegionNumber());
}

bool mlir::mayBeGraphRegion(Region &region) {
if (!region.getParentOp()->isRegistered())
return true;
auto regionKindOp = dyn_cast<RegionKindInterface>(region.getParentOp());
if (!regionKindOp)
return false;
return !regionKindOp.hasSSADominance(region.getRegionNumber());
}
9 changes: 3 additions & 6 deletions mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -421,8 +421,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {

// If the operation is trivially dead - remove it.
if (isOpTriviallyDead(op)) {
notifyOperationRemoved(op);
op->erase();
eraseOp(op);
changed = true;

LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
Expand Down Expand Up @@ -567,10 +566,8 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
config.listener->notifyOperationRemoved(op);

addOperandsToWorklist(op->getOperands());
op->walk([this](Operation *operation) {
worklist.remove(operation);
folder.notifyRemoval(operation);
});
worklist.remove(op);
folder.notifyRemoval(op);

if (config.strictMode != GreedyRewriteStrictness::AnyOp)
strictModeFilteredOps.erase(op);
Expand Down
160 changes: 153 additions & 7 deletions mlir/test/Transforms/test-strict-pattern-driver.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@

// CHECK-EN-LABEL: func @test_erase
// CHECK-EN-SAME: pattern_driver_all_erased = true, pattern_driver_changed = true}
// CHECK-EN: test.arg0
// CHECK-EN: test.arg1
// CHECK-EN-NOT: test.erase_op
// CHECK-EN: "test.arg0"
// CHECK-EN: "test.arg1"
// CHECK-EN-NOT: "test.erase_op"
func.func @test_erase() {
%0 = "test.arg0"() : () -> (i32)
%1 = "test.arg1"() : () -> (i32)
Expand Down Expand Up @@ -51,13 +51,13 @@ func.func @test_replace_with_new_op() {

// CHECK-EN-LABEL: func @test_replace_with_erase_op
// CHECK-EN-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true}
// CHECK-EN-NOT: test.replace_with_new_op
// CHECK-EN-NOT: test.erase_op
// CHECK-EN-NOT: "test.replace_with_new_op"
// CHECK-EN-NOT: "test.erase_op"

// CHECK-EX-LABEL: func @test_replace_with_erase_op
// CHECK-EX-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true}
// CHECK-EX-NOT: test.replace_with_new_op
// CHECK-EX: test.erase_op
// CHECK-EX-NOT: "test.replace_with_new_op"
// CHECK-EX: "test.erase_op"
func.func @test_replace_with_erase_op() {
"test.replace_with_new_op"() {create_erase_op} : () -> ()
return
Expand All @@ -83,3 +83,149 @@ func.func @test_trigger_rewrite_through_block() {
// in turn, replaces the successor with bb3.
"test.implicit_change_op"() [^bb1] : () -> ()
}

// -----

// CHECK-AN: notifyOperationRemoved: test.foo_b
// CHECK-AN: notifyOperationRemoved: test.foo_a
// CHECK-AN: notifyOperationRemoved: test.graph_region
// CHECK-AN: notifyOperationRemoved: test.erase_op
// CHECK-AN-LABEL: func @test_remove_graph_region()
// CHECK-AN-NEXT: return
func.func @test_remove_graph_region() {
"test.erase_op"() ({
test.graph_region {
%0 = "test.foo_a"(%1) : (i1) -> (i1)
%1 = "test.foo_b"(%0) : (i1) -> (i1)
}
}) : () -> ()
return
}

// -----

// CHECK-AN: notifyOperationRemoved: cf.br
// CHECK-AN: notifyOperationRemoved: test.bar
// CHECK-AN: notifyOperationRemoved: cf.br
// CHECK-AN: notifyOperationRemoved: test.foo
// CHECK-AN: notifyOperationRemoved: cf.br
// CHECK-AN: notifyOperationRemoved: test.dummy_op
// CHECK-AN: notifyOperationRemoved: test.erase_op
// CHECK-AN-LABEL: func @test_remove_cyclic_blocks()
// CHECK-AN-NEXT: return
func.func @test_remove_cyclic_blocks() {
"test.erase_op"() ({
%x = "test.dummy_op"() : () -> (i1)
cf.br ^bb1(%x: i1)
^bb1(%arg0: i1):
"test.foo"(%x) : (i1) -> ()
cf.br ^bb2(%arg0: i1)
^bb2(%arg1: i1):
"test.bar"(%x) : (i1) -> ()
cf.br ^bb1(%arg1: i1)
}) : () -> ()
return
}

// -----

// CHECK-AN: notifyOperationRemoved: test.dummy_op
// CHECK-AN: notifyOperationRemoved: test.bar
// CHECK-AN: notifyOperationRemoved: test.qux
// CHECK-AN: notifyOperationRemoved: test.qux_unreachable
// CHECK-AN: notifyOperationRemoved: test.nested_dummy
// CHECK-AN: notifyOperationRemoved: cf.br
// CHECK-AN: notifyOperationRemoved: test.foo
// CHECK-AN: notifyOperationRemoved: test.erase_op
// CHECK-AN-LABEL: func @test_remove_dead_blocks()
// CHECK-AN-NEXT: return
func.func @test_remove_dead_blocks() {
"test.erase_op"() ({
"test.dummy_op"() : () -> (i1)
// The following blocks are not reachable. Still, ^bb2 should be deleted
// befire ^bb1.
^bb1(%arg0: i1):
"test.foo"() : () -> ()
cf.br ^bb2(%arg0: i1)
^bb2(%arg1: i1):
"test.nested_dummy"() ({
"test.qux"() : () -> ()
// The following block is unreachable.
^bb3:
"test.qux_unreachable"() : () -> ()
}) : () -> ()
"test.bar"() : () -> ()
}) : () -> ()
return
}

// -----

// test.nested_* must be deleted before test.foo.
// test.bar must be deleted before test.foo.

// CHECK-AN: notifyOperationRemoved: cf.br
// CHECK-AN: notifyOperationRemoved: test.bar
// CHECK-AN: notifyOperationRemoved: cf.br
// CHECK-AN: notifyOperationRemoved: test.nested_b
// CHECK-AN: notifyOperationRemoved: test.nested_a
// CHECK-AN: notifyOperationRemoved: test.nested_d
// CHECK-AN: notifyOperationRemoved: cf.br
// CHECK-AN: notifyOperationRemoved: test.nested_e
// CHECK-AN: notifyOperationRemoved: cf.br
// CHECK-AN: notifyOperationRemoved: test.nested_c
// CHECK-AN: notifyOperationRemoved: test.foo
// CHECK-AN: notifyOperationRemoved: cf.br
// CHECK-AN: notifyOperationRemoved: test.dummy_op
// CHECK-AN: notifyOperationRemoved: test.erase_op
// CHECK-AN-LABEL: func @test_remove_nested_ops()
// CHECK-AN-NEXT: return
func.func @test_remove_nested_ops() {
"test.erase_op"() ({
%x = "test.dummy_op"() : () -> (i1)
cf.br ^bb1(%x: i1)
^bb1(%arg0: i1):
"test.foo"() ({
"test.nested_a"() : () -> ()
"test.nested_b"() : () -> ()
^dead1:
"test.nested_c"() : () -> ()
cf.br ^dead3
^dead2:
"test.nested_d"() : () -> ()
^dead3:
"test.nested_e"() : () -> ()
cf.br ^dead2
}) : () -> ()
cf.br ^bb2(%arg0: i1)
^bb2(%arg1: i1):
"test.bar"(%x) : (i1) -> ()
cf.br ^bb1(%arg1: i1)
}) : () -> ()
return
}

// -----

// CHECK-AN: notifyOperationRemoved: test.qux
// CHECK-AN: notifyOperationRemoved: cf.br
// CHECK-AN: notifyOperationRemoved: test.foo
// CHECK-AN: notifyOperationRemoved: cf.br
// CHECK-AN: notifyOperationRemoved: test.bar
// CHECK-AN: notifyOperationRemoved: cf.cond_br
// CHECK-AN-LABEL: func @test_remove_diamond(
// CHECK-AN-NEXT: return
func.func @test_remove_diamond(%c: i1) {
"test.erase_op"() ({
cf.cond_br %c, ^bb1, ^bb2
^bb1:
"test.foo"() : () -> ()
cf.br ^bb3
^bb2:
"test.bar"() : () -> ()
cf.br ^bb3
^bb3:
"test.qux"() : () -> ()
}) : () -> ()
return
}
8 changes: 8 additions & 0 deletions mlir/test/lib/Dialect/Test/TestPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,12 @@ struct TestPatternDriver
llvm::cl::init(GreedyRewriteConfig().maxIterations)};
};

struct DumpNotifications : public RewriterBase::Listener {
void notifyOperationRemoved(Operation *op) override {
llvm::outs() << "notifyOperationRemoved: " << op->getName() << "\n";
}
};

struct TestStrictPatternDriver
: public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> {
public:
Expand Down Expand Up @@ -275,7 +281,9 @@ struct TestStrictPatternDriver
}
});

DumpNotifications dumpNotifications;
GreedyRewriteConfig config;
config.listener = &dumpNotifications;
if (strictMode == "AnyOp") {
config.strictMode = GreedyRewriteStrictness::AnyOp;
} else if (strictMode == "ExistingAndNewOps") {
Expand Down