@@ -43,7 +43,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
43
43
bool simplify (MutableArrayRef<Region> regions);
44
44
45
45
// / Add the given operation to the worklist.
46
- void addToWorklist (Operation *op);
46
+ virtual void addToWorklist (Operation *op);
47
47
48
48
// / Pop the next operation from the worklist.
49
49
Operation *popFromWorklist ();
@@ -60,8 +60,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
60
60
// be re-added to the worklist. This function should be called when an
61
61
// operation is modified or removed, as it may trigger further
62
62
// simplifications.
63
- template <typename Operands>
64
- void addToWorklist (Operands &&operands);
63
+ void addOperandsToWorklist (ValueRange operands);
65
64
66
65
// If an operation is about to be removed, make sure it is not in our
67
66
// worklist anymore because we'd get dangling references to it.
@@ -219,7 +218,7 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
219
218
originalOperands.assign (op->operand_begin (), op->operand_end ());
220
219
auto preReplaceAction = [&](Operation *op) {
221
220
// Add the operands to the worklist for visitation.
222
- addToWorklist (originalOperands);
221
+ addOperandsToWorklist (originalOperands);
223
222
224
223
// Add all the users of the result to the worklist so we make sure
225
224
// to revisit them.
@@ -327,8 +326,7 @@ void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
327
326
addToWorklist (op);
328
327
}
329
328
330
- template <typename Operands>
331
- void GreedyPatternRewriteDriver::addToWorklist (Operands &&operands) {
329
+ void GreedyPatternRewriteDriver::addOperandsToWorklist (ValueRange operands) {
332
330
for (Value operand : operands) {
333
331
// If the use count of this operand is now < 2, we re-add the defining
334
332
// operation to the worklist.
@@ -343,7 +341,7 @@ void GreedyPatternRewriteDriver::addToWorklist(Operands &&operands) {
343
341
}
344
342
345
343
void GreedyPatternRewriteDriver::notifyOperationRemoved (Operation *op) {
346
- addToWorklist (op->getOperands ());
344
+ addOperandsToWorklist (op->getOperands ());
347
345
op->walk ([this ](Operation *operation) {
348
346
removeFromWorklist (operation);
349
347
folder.notifyRemoval (operation);
@@ -523,22 +521,12 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
523
521
524
522
bool simplifyLocally (ArrayRef<Operation *> op);
525
523
526
- private:
527
- // Look over the provided operands for any defining operations that should
528
- // be re-added to the worklist. This function should be called when an
529
- // operation is modified or removed, as it may trigger further
530
- // simplifications. If `strict` is set to true, only ops in
531
- // `strictModeFilteredOps` are considered.
532
- template <typename Operands>
533
- void addOperandsToWorklist (Operands &&operands) {
534
- for (Value operand : operands) {
535
- if (auto *defOp = operand.getDefiningOp ()) {
536
- if (!strictMode || strictModeFilteredOps.contains (defOp))
537
- addToWorklist (defOp);
538
- }
539
- }
524
+ void addToWorklist (Operation *op) override {
525
+ if (!strictMode || strictModeFilteredOps.contains (op))
526
+ GreedyPatternRewriteDriver::addToWorklist (op);
540
527
}
541
528
529
+ private:
542
530
void notifyOperationInserted (Operation *op) override {
543
531
GreedyPatternRewriteDriver::notifyOperationInserted (op);
544
532
if (strictMode)
@@ -551,15 +539,6 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
551
539
strictModeFilteredOps.erase (op);
552
540
}
553
541
554
- void notifyRootReplaced (Operation *op) override {
555
- for (auto result : op->getResults ()) {
556
- for (auto *user : result.getUsers ()) {
557
- if (!strictMode || strictModeFilteredOps.contains (user))
558
- addToWorklist (user);
559
- }
560
- }
561
- }
562
-
563
542
// / If `strictMode` is true, any pre-existing ops outside of
564
543
// / `strictModeFilteredOps` remain completely untouched by the rewrite driver.
565
544
// / If `strictMode` is false, operations that use results of (or supply
@@ -633,22 +612,17 @@ bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops) {
633
612
634
613
// Add all the users of the result to the worklist so we make sure
635
614
// to revisit them.
636
- for (Value result : op->getResults ())
637
- for (Operation *userOp : result.getUsers ()) {
638
- if (!strictMode || strictModeFilteredOps. contains ( userOp))
639
- addToWorklist (userOp);
640
- }
615
+ for (Value result : op->getResults ()) {
616
+ for (Operation *userOp : result.getUsers ())
617
+ addToWorklist ( userOp);
618
+ }
619
+
641
620
notifyOperationRemoved (op);
642
621
};
643
622
644
623
// Add the given operation generated by the folder to the worklist.
645
624
auto processGeneratedConstants = [this ](Operation *op) {
646
- // Newly created ops are also simplified -- these are also "local".
647
- addToWorklist (op);
648
- // When strict mode is off, we don't need to maintain
649
- // strictModeFilteredOps.
650
- if (strictMode)
651
- strictModeFilteredOps.insert (op);
625
+ notifyOperationInserted (op);
652
626
};
653
627
654
628
// Try to fold this op.
0 commit comments