Skip to content

[mlir][Transforms] GreedyPatternRewriteDriver: Check for out-of-scope IR modifications #76219

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 26 additions & 9 deletions mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,29 @@ class GreedyRewriteConfig {

static constexpr int64_t kNoLimit = -1;

/// Only ops within the scope are added to the worklist. If no scope is
/// specified, the closest enclosing region around the initial list of ops
/// (or the specified region, depending on which greedy rewrite entry point
/// is used) is used as a scope.
/// Only ops within the scope are allowed to be modified and are added to the
/// worklist.
///
/// If out-of-scope IR is modified, an assertion will fail inside the greedy
/// pattern rewrite driver if expensive checks are enabled (as long as rewrite
/// patterns use the rewriter API correctly). We also allow attribute
/// modifications of the op that owns the scope region. (This is consistent
/// with the fact that passes are allowed to modify attributes of the
/// operation that they operate on.)
///
/// The scope region must be isolated from above. This ensures that
/// out-of-scope ops are not affected by rewrites.
///
/// If no scope is specified, it is set as follows:
/// * Single op greedy rewrite: a greedy rewrite is performed for every region
/// of the op. (See below.) The scope is set to the respective region of
/// each greedy write.
/// * Multi op greedy rewrite: the closest enclosing IsolatedFromAbove region
/// around the initial list of ops. If there is no such region, the scope
/// is `nullptr`. This is because multi-op greedy rewrites are allowed to
/// modify top-level ops. (They are not allowed to erase top-level ops.)
/// * Single region greedy rewrite: the specified region. (The op that owns
/// the region must be isolated from above.)
Region *scope = nullptr;

/// Strict mode can restrict the ops that are added to the worklist during
Expand Down Expand Up @@ -124,11 +143,9 @@ applyPatternsAndFoldGreedily(Region &region,
/// This overload runs a separate greedy rewrite for each region of the
/// specified op. A region scope can be set in the configuration parameter. By
/// default, the scope is set to the region of the current greedy rewrite. Only
/// in-scope ops are added to the worklist and only in-scope ops and the
/// specified op itself are allowed to be modified by the patterns.
///
/// Note: The specified op may be modified, but it may not be removed by the
/// patterns.
/// in-scope ops are added to the worklist and only in-scope ops are allowed to
/// be modified by the patterns. In addition, the attributes of the op that
/// owns the scope region may also be modified.
///
/// Returns "success" if the iterative process converged (i.e., fixpoint was
/// reached) and no more patterns can be matched within the region. `changed`
Expand Down
103 changes: 85 additions & 18 deletions mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,11 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;

private:
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
/// Return "true" if the given op is guaranteed to be out of scope.
bool isOutOfScope(Operation *op) const;
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS

/// Look over the provided operands for any defining operations that should
/// be re-added to the worklist. This function should be called when an
/// operation is modified or removed, as it may trigger further
Expand Down Expand Up @@ -375,6 +380,28 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
}

#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
bool GreedyPatternRewriteDriver::isOutOfScope(Operation *op) const {
// No op is out of scope if no scope was set.
if (!config.scope)
return false;
// Check if the given op and the scope region are part of the same IR tree.
// The parent op into which the given op was inserted may be unlinked, in
// which case we do not consider the given op to be out of scope. (That parent
// op will likely be inserted later, together with all its nested ops.)
Region *r = config.scope;
while (r) {
if (r->findAncestorOpInRegion(*op) || r->getParentOp() == op)
break;
r = r->getParentRegion();
}
if (!r)
return false;
// Op is out of scope if it is not within the scope region.
return !config.scope->findAncestorOpInRegion(*op);
}
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS

bool GreedyPatternRewriteDriver::processWorklist() {
#ifndef NDEBUG
const char *logLineComment =
Expand Down Expand Up @@ -579,6 +606,8 @@ void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
addSingleOpToWorklist(op);
return;
}
// TODO: Unlinked ops are currently not added to the worklist if a `scope`
// is specified.
if (region == nullptr)
return;
} while ((op = region->getParentOp()));
Expand All @@ -600,6 +629,13 @@ void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
<< ")\n";
});

#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (config.scope && isOutOfScope(op))
llvm::report_fatal_error(
"greedy pattern rewrite inserted op into region that is out of scope");
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS

if (config.listener)
config.listener->notifyOperationInserted(op);
if (config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
Expand All @@ -608,10 +644,24 @@ void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
}

void GreedyPatternRewriteDriver::notifyOperationModified(Operation *op) {
// TODO: This notification should also be triggered when moving an op into
// this op.
LLVM_DEBUG({
logger.startLine() << "** Modified: '" << op->getName() << "'(" << op
<< ")\n";
});

#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (config.scope) {
// Modifying attributes of the op that owns the scope region is allowed
// when using the applyPatternsAndFoldGreedily(Operation *) entry point.
if (op != config.scope->getParentOp() && isOutOfScope(op)) {
llvm::report_fatal_error("greedy pattern rewrite modified op within "
"region that is out of scope");
}
}
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS

if (config.listener)
config.listener->notifyOperationModified(op);
addToWorklist(op);
Expand All @@ -637,16 +687,11 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
<< ")\n";
});

#ifndef NDEBUG
// Only ops that are within the configured scope are added to the worklist of
// the greedy pattern rewriter. Moreover, the parent op of the scope region is
// the part of the IR that is taken into account for the "expensive checks".
// A greedy pattern rewrite is not allowed to erase the parent op of the scope
// region, as that would break the worklist handling and the expensive checks.
if (config.scope && config.scope->getParentOp() == op)
llvm_unreachable(
"scope region must not be erased during greedy pattern rewrite");
#endif // NDEBUG
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (config.scope && isOutOfScope(op))
llvm::report_fatal_error(
"greedy pattern rewrite removed op from region that is out of scope");
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS

if (config.listener)
config.listener->notifyOperationRemoved(op);
Expand Down Expand Up @@ -800,16 +845,22 @@ LogicalResult
mlir::applyPatternsAndFoldGreedily(Region &region,
const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config, bool *changed) {
// The top-level operation must be known to be isolated from above to
// prevent performing canonicalizations on operations defined at or above
// the region containing 'op'.
assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
"patterns can only be applied to operations IsolatedFromAbove");

// Set scope if not specified.
if (!config.scope)
config.scope = &region;

// Make sure that the specified region on which the greedy rewrite should
// operate is in scope.
assert(config.scope->isAncestor(&region) && "input region must be in scope");

// The scope of a greedy pattern rewrite must be IsolatedFromAbove. Ops that
// are out of scope are never added to the worklist and any out-of-scope IR
// modifications trigger an assertion when expensive expensive checks are
// enabled (as long as the rewriter API is used correctly).
assert(
config.scope->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
"greedy pattern rewrite scope must be IsolatedFromAbove");

#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (failed(verify(config.scope->getParentOp())))
llvm::report_fatal_error(
Expand Down Expand Up @@ -886,7 +937,8 @@ LogicalResult MultiOpPatternRewriteDriver::simplify(ArrayRef<Operation *> ops,
return success(worklist.empty());
}

/// Find the region that is the closest common ancestor of all given ops.
/// Find the IsolateFromAbove region that is the closest common ancestor of all
/// given ops.
///
/// Note: This function returns `nullptr` if there is a top-level op among the
/// given list of ops.
Expand All @@ -896,6 +948,7 @@ static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
if (ops.size() == 1)
return ops.front()->getParentRegion();

// Find the closest region that contains all ops.
Region *region = ops.front()->getParentRegion();
ops = ops.drop_front();
int sz = ops.size();
Expand All @@ -912,6 +965,12 @@ static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
break;
region = region->getParentRegion();
}

// Find the closest IsolatedFromAbove region.
while (region &&
!region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
region = region->getParentRegion();

return region;
}

Expand All @@ -932,8 +991,16 @@ LogicalResult mlir::applyOpPatternsAndFold(
// there is a top-level op among `ops`.
config.scope = findCommonAncestor(ops);
} else {
// If a scope was provided, make sure that all ops are in scope.
// If a scope was provided, make sure that it is IsolatedFromAbove and that
// all ops are in scope.
#ifndef NDEBUG
// The scope of a greedy pattern rewrite must be IsolatedFromAbove. Ops that
// are out of scope are never added to the worklist and any out-of-scope IR
// modifications trigger an assertion when expensive expensive checks are
// enabled (as long as the rewriter API is used correctly).
assert(
config.scope->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
"greedy pattern rewrite scope must be IsolatedFromAbove");
bool allOpsInScope = llvm::all_of(ops, [&](Operation *op) {
return static_cast<bool>(config.scope->findAncestorOpInRegion(*op));
});
Expand Down