Skip to content

Commit ce33a69

Browse files
[mlir][Transforms] GreedyPatternRewriteDriver: Check for out-of-scope IR modifications
This commit adds an additional "expensive check" (only enabled with `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS`) that looks for out-of-scope IR modifications. `GreedyRewriteConfig::scope` specifies the `Region *` within which the greedy pattern rewrite operates. Operations that are out-of-scope are not added to the worklist. The new expensive check triggers an fatal error if: * Op is inserted into out-of-scope region. * Op is removed from out-of-scope region. * Op is modified in out-of-scope region. This change also tightens the greedy pattern rewriter entry points and makes sure that the specified `scope` is an `IsolatedFromAbove` region. Note: `TileAllocation` (`ArmSME` dialect) must now be a module pass because it modifies `func.func` ops (adds attributes). This is forbidden for function passes (in which the scope of the greedy rewrite is set to the region of the function by default) because only function bodies are allowed to be modified. (TODO: Should we allow this? Is there something special about functions?)
1 parent bd2a6ef commit ce33a69

File tree

2 files changed

+111
-27
lines changed

2 files changed

+111
-27
lines changed

mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,29 @@ class GreedyRewriteConfig {
6060

6161
static constexpr int64_t kNoLimit = -1;
6262

63-
/// Only ops within the scope are added to the worklist. If no scope is
64-
/// specified, the closest enclosing region around the initial list of ops
65-
/// (or the specified region, depending on which greedy rewrite entry point
66-
/// is used) is used as a scope.
63+
/// Only ops within the scope are allowed to be modified and are added to the
64+
/// worklist.
65+
///
66+
/// If out-of-scope IR is modified, an assertion will fail inside the greedy
67+
/// pattern rewrite driver if expensive checks are enabled (as long as rewrite
68+
/// patterns use the rewriter API correctly). We also allow attribute
69+
/// modifications of the op that owns the scope region. (This is consistent
70+
/// with the fact that passes are allowed to modify attributes of the
71+
/// operation that they operate on.)
72+
///
73+
/// The scope region must be isolated from above. This ensures that
74+
/// out-of-scope ops are not affected by rewrites.
75+
///
76+
/// If no scope is specified, it is set as follows:
77+
/// * Single op greedy rewrite: a greedy rewrite is performed for every region
78+
/// of the op. (See below.) The scope is set to the respective region of
79+
/// each greedy write.
80+
/// * Multi op greedy rewrite: the closest enclosing IsolatedFromAbove region
81+
/// around the initial list of ops. If there is no such region, the scope
82+
/// is `nullptr`. This is because multi-op greedy rewrites are allowed to
83+
/// modify top-level ops. (They are not allowed to erase top-level ops.)
84+
/// * Single region greedy rewrite: the specified region. (The op that owns
85+
/// the region must be isolated from above.)
6786
Region *scope = nullptr;
6887

6988
/// Strict mode can restrict the ops that are added to the worklist during
@@ -124,11 +143,9 @@ applyPatternsAndFoldGreedily(Region &region,
124143
/// This overload runs a separate greedy rewrite for each region of the
125144
/// specified op. A region scope can be set in the configuration parameter. By
126145
/// default, the scope is set to the region of the current greedy rewrite. Only
127-
/// in-scope ops are added to the worklist and only in-scope ops and the
128-
/// specified op itself are allowed to be modified by the patterns.
129-
///
130-
/// Note: The specified op may be modified, but it may not be removed by the
131-
/// patterns.
146+
/// in-scope ops are added to the worklist and only in-scope ops are allowed to
147+
/// be modified by the patterns. In addition, the attributes of the op that
148+
/// owns the scope region may also be modified.
132149
///
133150
/// Returns "success" if the iterative process converged (i.e., fixpoint was
134151
/// reached) and no more patterns can be matched within the region. `changed`

mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Lines changed: 85 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,11 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
324324
llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
325325

326326
private:
327+
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
328+
/// Return "true" if the given op is guaranteed to be out of scope.
329+
bool isOutOfScope(Operation *op) const;
330+
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
331+
327332
/// Look over the provided operands for any defining operations that should
328333
/// be re-added to the worklist. This function should be called when an
329334
/// operation is modified or removed, as it may trigger further
@@ -375,6 +380,28 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
375380
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
376381
}
377382

383+
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
384+
bool GreedyPatternRewriteDriver::isOutOfScope(Operation *op) const {
385+
// No op is out of scope if no scope was set.
386+
if (!config.scope)
387+
return false;
388+
// Check if the given op and the scope region are part of the same IR tree.
389+
// The parent op into which the given op was inserted may be unlinked, in
390+
// which case we do not consider the given op to be out of scope. (That parent
391+
// op will likely be inserted later, together with all its nested ops.)
392+
Region *r = config.scope;
393+
while (r) {
394+
if (r->findAncestorOpInRegion(*op) || r->getParentOp() == op)
395+
break;
396+
r = r->getParentRegion();
397+
}
398+
if (!r)
399+
return false;
400+
// Op is out of scope if it is not within the scope region.
401+
return !config.scope->findAncestorOpInRegion(*op);
402+
}
403+
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
404+
378405
bool GreedyPatternRewriteDriver::processWorklist() {
379406
#ifndef NDEBUG
380407
const char *logLineComment =
@@ -579,6 +606,8 @@ void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
579606
addSingleOpToWorklist(op);
580607
return;
581608
}
609+
// TODO: Unlinked ops are currently not added to the worklist if a `scope`
610+
// is specified.
582611
if (region == nullptr)
583612
return;
584613
} while ((op = region->getParentOp()));
@@ -600,6 +629,13 @@ void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
600629
logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
601630
<< ")\n";
602631
});
632+
633+
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
634+
if (config.scope && isOutOfScope(op))
635+
llvm::report_fatal_error(
636+
"greedy pattern rewrite inserted op into region that is out of scope");
637+
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
638+
603639
if (config.listener)
604640
config.listener->notifyOperationInserted(op);
605641
if (config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
@@ -608,10 +644,24 @@ void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
608644
}
609645

610646
void GreedyPatternRewriteDriver::notifyOperationModified(Operation *op) {
647+
// TODO: This notification should also be triggered when moving an op into
648+
// this op.
611649
LLVM_DEBUG({
612650
logger.startLine() << "** Modified: '" << op->getName() << "'(" << op
613651
<< ")\n";
614652
});
653+
654+
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
655+
if (config.scope) {
656+
// Modifying attributes of the op that owns the scope region is allowed
657+
// when using the applyPatternsAndFoldGreedily(Operation *) entry point.
658+
if (op != config.scope->getParentOp() && isOutOfScope(op)) {
659+
llvm::report_fatal_error("greedy pattern rewrite modified op within "
660+
"region that is out of scope");
661+
}
662+
}
663+
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
664+
615665
if (config.listener)
616666
config.listener->notifyOperationModified(op);
617667
addToWorklist(op);
@@ -637,16 +687,11 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
637687
<< ")\n";
638688
});
639689

640-
#ifndef NDEBUG
641-
// Only ops that are within the configured scope are added to the worklist of
642-
// the greedy pattern rewriter. Moreover, the parent op of the scope region is
643-
// the part of the IR that is taken into account for the "expensive checks".
644-
// A greedy pattern rewrite is not allowed to erase the parent op of the scope
645-
// region, as that would break the worklist handling and the expensive checks.
646-
if (config.scope && config.scope->getParentOp() == op)
647-
llvm_unreachable(
648-
"scope region must not be erased during greedy pattern rewrite");
649-
#endif // NDEBUG
690+
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
691+
if (config.scope && isOutOfScope(op))
692+
llvm::report_fatal_error(
693+
"greedy pattern rewrite removed op from region that is out of scope");
694+
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
650695

651696
if (config.listener)
652697
config.listener->notifyOperationRemoved(op);
@@ -800,16 +845,22 @@ LogicalResult
800845
mlir::applyPatternsAndFoldGreedily(Region &region,
801846
const FrozenRewritePatternSet &patterns,
802847
GreedyRewriteConfig config, bool *changed) {
803-
// The top-level operation must be known to be isolated from above to
804-
// prevent performing canonicalizations on operations defined at or above
805-
// the region containing 'op'.
806-
assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
807-
"patterns can only be applied to operations IsolatedFromAbove");
808-
809848
// Set scope if not specified.
810849
if (!config.scope)
811850
config.scope = &region;
812851

852+
// Make sure that the specified region on which the greedy rewrite should
853+
// operate is in scope.
854+
assert(config.scope->isAncestor(&region) && "input region must be in scope");
855+
856+
// The scope of a greedy pattern rewrite must be IsolatedFromAbove. Ops that
857+
// are out of scope are never added to the worklist and any out-of-scope IR
858+
// modifications trigger an assertion when expensive expensive checks are
859+
// enabled (as long as the rewriter API is used correctly).
860+
assert(
861+
config.scope->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
862+
"greedy pattern rewrite scope must be IsolatedFromAbove");
863+
813864
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
814865
if (failed(verify(config.scope->getParentOp())))
815866
llvm::report_fatal_error(
@@ -886,7 +937,8 @@ LogicalResult MultiOpPatternRewriteDriver::simplify(ArrayRef<Operation *> ops,
886937
return success(worklist.empty());
887938
}
888939

889-
/// Find the region that is the closest common ancestor of all given ops.
940+
/// Find the IsolateFromAbove region that is the closest common ancestor of all
941+
/// given ops.
890942
///
891943
/// Note: This function returns `nullptr` if there is a top-level op among the
892944
/// given list of ops.
@@ -896,6 +948,7 @@ static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
896948
if (ops.size() == 1)
897949
return ops.front()->getParentRegion();
898950

951+
// Find the closest region that contains all ops.
899952
Region *region = ops.front()->getParentRegion();
900953
ops = ops.drop_front();
901954
int sz = ops.size();
@@ -912,6 +965,12 @@ static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
912965
break;
913966
region = region->getParentRegion();
914967
}
968+
969+
// Find the closest IsolatedFromAbove region.
970+
while (region &&
971+
!region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
972+
region = region->getParentRegion();
973+
915974
return region;
916975
}
917976

@@ -932,8 +991,16 @@ LogicalResult mlir::applyOpPatternsAndFold(
932991
// there is a top-level op among `ops`.
933992
config.scope = findCommonAncestor(ops);
934993
} else {
935-
// If a scope was provided, make sure that all ops are in scope.
994+
// If a scope was provided, make sure that it is IsolatedFromAbove and that
995+
// all ops are in scope.
936996
#ifndef NDEBUG
997+
// The scope of a greedy pattern rewrite must be IsolatedFromAbove. Ops that
998+
// are out of scope are never added to the worklist and any out-of-scope IR
999+
// modifications trigger an assertion when expensive expensive checks are
1000+
// enabled (as long as the rewriter API is used correctly).
1001+
assert(
1002+
config.scope->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
1003+
"greedy pattern rewrite scope must be IsolatedFromAbove");
9371004
bool allOpsInScope = llvm::all_of(ops, [&](Operation *op) {
9381005
return static_cast<bool>(config.scope->findAncestorOpInRegion(*op));
9391006
});

0 commit comments

Comments
 (0)