Skip to content

Commit ba3a9f5

Browse files
committed
[mlir:MultiOpDriver] Add operands to worklist should be checked
Operand's defining op may not be valid for adding to the worklist under stict mode Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D127180
1 parent ff80dc8 commit ba3a9f5

File tree

3 files changed

+124
-41
lines changed

3 files changed

+124
-41
lines changed

mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Lines changed: 15 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
4343
bool simplify(MutableArrayRef<Region> regions);
4444

4545
/// Add the given operation to the worklist.
46-
void addToWorklist(Operation *op);
46+
virtual void addToWorklist(Operation *op);
4747

4848
/// Pop the next operation from the worklist.
4949
Operation *popFromWorklist();
@@ -60,8 +60,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
6060
// be re-added to the worklist. This function should be called when an
6161
// operation is modified or removed, as it may trigger further
6262
// simplifications.
63-
template <typename Operands>
64-
void addToWorklist(Operands &&operands);
63+
void addOperandsToWorklist(ValueRange operands);
6564

6665
// If an operation is about to be removed, make sure it is not in our
6766
// worklist anymore because we'd get dangling references to it.
@@ -219,7 +218,7 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
219218
originalOperands.assign(op->operand_begin(), op->operand_end());
220219
auto preReplaceAction = [&](Operation *op) {
221220
// Add the operands to the worklist for visitation.
222-
addToWorklist(originalOperands);
221+
addOperandsToWorklist(originalOperands);
223222

224223
// Add all the users of the result to the worklist so we make sure
225224
// to revisit them.
@@ -327,8 +326,7 @@ void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
327326
addToWorklist(op);
328327
}
329328

330-
template <typename Operands>
331-
void GreedyPatternRewriteDriver::addToWorklist(Operands &&operands) {
329+
void GreedyPatternRewriteDriver::addOperandsToWorklist(ValueRange operands) {
332330
for (Value operand : operands) {
333331
// If the use count of this operand is now < 2, we re-add the defining
334332
// operation to the worklist.
@@ -343,7 +341,7 @@ void GreedyPatternRewriteDriver::addToWorklist(Operands &&operands) {
343341
}
344342

345343
void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
346-
addToWorklist(op->getOperands());
344+
addOperandsToWorklist(op->getOperands());
347345
op->walk([this](Operation *operation) {
348346
removeFromWorklist(operation);
349347
folder.notifyRemoval(operation);
@@ -523,22 +521,12 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
523521

524522
bool simplifyLocally(ArrayRef<Operation *> op);
525523

526-
private:
527-
// Look over the provided operands for any defining operations that should
528-
// be re-added to the worklist. This function should be called when an
529-
// operation is modified or removed, as it may trigger further
530-
// simplifications. If `strict` is set to true, only ops in
531-
// `strictModeFilteredOps` are considered.
532-
template <typename Operands>
533-
void addOperandsToWorklist(Operands &&operands) {
534-
for (Value operand : operands) {
535-
if (auto *defOp = operand.getDefiningOp()) {
536-
if (!strictMode || strictModeFilteredOps.contains(defOp))
537-
addToWorklist(defOp);
538-
}
539-
}
524+
void addToWorklist(Operation *op) override {
525+
if (!strictMode || strictModeFilteredOps.contains(op))
526+
GreedyPatternRewriteDriver::addToWorklist(op);
540527
}
541528

529+
private:
542530
void notifyOperationInserted(Operation *op) override {
543531
GreedyPatternRewriteDriver::notifyOperationInserted(op);
544532
if (strictMode)
@@ -551,15 +539,6 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
551539
strictModeFilteredOps.erase(op);
552540
}
553541

554-
void notifyRootReplaced(Operation *op) override {
555-
for (auto result : op->getResults()) {
556-
for (auto *user : result.getUsers()) {
557-
if (!strictMode || strictModeFilteredOps.contains(user))
558-
addToWorklist(user);
559-
}
560-
}
561-
}
562-
563542
/// If `strictMode` is true, any pre-existing ops outside of
564543
/// `strictModeFilteredOps` remain completely untouched by the rewrite driver.
565544
/// If `strictMode` is false, operations that use results of (or supply
@@ -633,22 +612,17 @@ bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops) {
633612

634613
// Add all the users of the result to the worklist so we make sure
635614
// to revisit them.
636-
for (Value result : op->getResults())
637-
for (Operation *userOp : result.getUsers()) {
638-
if (!strictMode || strictModeFilteredOps.contains(userOp))
639-
addToWorklist(userOp);
640-
}
615+
for (Value result : op->getResults()) {
616+
for (Operation *userOp : result.getUsers())
617+
addToWorklist(userOp);
618+
}
619+
641620
notifyOperationRemoved(op);
642621
};
643622

644623
// Add the given operation generated by the folder to the worklist.
645624
auto processGeneratedConstants = [this](Operation *op) {
646-
// Newly created ops are also simplified -- these are also "local".
647-
addToWorklist(op);
648-
// When strict mode is off, we don't need to maintain
649-
// strictModeFilteredOps.
650-
if (strictMode)
651-
strictModeFilteredOps.insert(op);
625+
notifyOperationInserted(op);
652626
};
653627

654628
// Try to fold this op.
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// RUN: mlir-opt -allow-unregistered-dialect -test-strict-pattern-driver %s | FileCheck %s
2+
3+
// CHECK-LABEL: @test_erase
4+
func.func @test_erase() {
5+
%0 = "test.arg0"() : () -> (i32)
6+
%1 = "test.arg1"() : () -> (i32)
7+
%erase = "test.erase_op"(%0, %1) : (i32, i32) -> (i32)
8+
return
9+
}
10+
11+
// CHECK-LABEL: @test_insert_same_op
12+
func.func @test_insert_same_op() {
13+
%0 = "test.insert_same_op"() : () -> (i32)
14+
return
15+
}
16+
17+
// CHECK-LABEL: @test_replace_with_same_op
18+
func.func @test_replace_with_same_op() {
19+
%0 = "test.replace_with_same_op"() : () -> (i32)
20+
%1 = "test.dummy_user"(%0) : (i32) -> (i32)
21+
%2 = "test.dummy_user"(%0) : (i32) -> (i32)
22+
return
23+
}

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,91 @@ struct TestPatternDriver
176176
llvm::cl::desc("Seed the worklist in general top-down order"),
177177
llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)};
178178
};
179+
180+
struct TestStrictPatternDriver
181+
: public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> {
182+
public:
183+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver)
184+
185+
TestStrictPatternDriver() = default;
186+
TestStrictPatternDriver(const TestStrictPatternDriver &other)
187+
: PassWrapper(other) {}
188+
189+
StringRef getArgument() const final { return "test-strict-pattern-driver"; }
190+
StringRef getDescription() const final {
191+
return "Run strict mode of pattern driver";
192+
}
193+
194+
void runOnOperation() override {
195+
mlir::RewritePatternSet patterns(&getContext());
196+
patterns.add<InsertSameOp, ReplaceWithSameOp, EraseOp>(&getContext());
197+
SmallVector<Operation *> ops;
198+
getOperation()->walk([&](Operation *op) {
199+
StringRef opName = op->getName().getStringRef();
200+
if (opName == "test.insert_same_op" ||
201+
opName == "test.replace_with_same_op" || opName == "test.erase_op") {
202+
ops.push_back(op);
203+
}
204+
});
205+
206+
// Check if these transformations introduce visiting of operations that
207+
// are not in the `ops` set (The new created ops are valid). An invalid
208+
// operation will trigger the assertion while processing.
209+
(void)applyOpPatternsAndFold(makeArrayRef(ops), std::move(patterns),
210+
/*strict=*/true);
211+
}
212+
213+
private:
214+
// New inserted operation is valid for further transformation.
215+
class InsertSameOp : public RewritePattern {
216+
public:
217+
InsertSameOp(MLIRContext *context)
218+
: RewritePattern("test.insert_same_op", /*benefit=*/1, context) {}
219+
220+
LogicalResult matchAndRewrite(Operation *op,
221+
PatternRewriter &rewriter) const override {
222+
if (op->hasAttr("skip"))
223+
return failure();
224+
225+
Operation *newOp =
226+
rewriter.create(op->getLoc(), op->getName().getIdentifier(),
227+
op->getOperands(), op->getResultTypes());
228+
op->setAttr("skip", rewriter.getBoolAttr(true));
229+
newOp->setAttr("skip", rewriter.getBoolAttr(true));
230+
231+
return success();
232+
}
233+
};
234+
235+
// Replace an operation may introduce the re-visiting of its users.
236+
class ReplaceWithSameOp : public RewritePattern {
237+
public:
238+
ReplaceWithSameOp(MLIRContext *context)
239+
: RewritePattern("test.replace_with_same_op", /*benefit=*/1, context) {}
240+
241+
LogicalResult matchAndRewrite(Operation *op,
242+
PatternRewriter &rewriter) const override {
243+
Operation *newOp =
244+
rewriter.create(op->getLoc(), op->getName().getIdentifier(),
245+
op->getOperands(), op->getResultTypes());
246+
rewriter.replaceOp(op, newOp->getResults());
247+
return success();
248+
}
249+
};
250+
251+
// Remove an operation may introduce the re-visiting of its opreands.
252+
class EraseOp : public RewritePattern {
253+
public:
254+
EraseOp(MLIRContext *context)
255+
: RewritePattern("test.erase_op", /*benefit=*/1, context) {}
256+
LogicalResult matchAndRewrite(Operation *op,
257+
PatternRewriter &rewriter) const override {
258+
rewriter.eraseOp(op);
259+
return success();
260+
}
261+
};
262+
};
263+
179264
} // namespace
180265

181266
//===----------------------------------------------------------------------===//
@@ -1471,6 +1556,7 @@ void registerPatternsTestPass() {
14711556
PassRegistration<TestDerivedAttributeDriver>();
14721557

14731558
PassRegistration<TestPatternDriver>();
1559+
PassRegistration<TestStrictPatternDriver>();
14741560

14751561
PassRegistration<TestLegalizePatternDriver>([] {
14761562
return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode);

0 commit comments

Comments
 (0)