Skip to content

Commit e8ae0e7

Browse files
[mlir][transform] TrackingListener: Improve dead handles detection (#74290)
The tracking listener should not report op replacement errors for payload ops that are not mapped to any live handles. The handle liveless analysis did not work properly with transform IR that has named sequences. A handle is live if it has a user after the transform op that is currently being applied. With named sequences, we need to maintain a stack of currently applied transform ops. That stack already exists (`regionStack`), the only thing that's missing is the current transform op for each stack frame. This commit fixes #72931.
1 parent c630f95 commit e8ae0e7

File tree

3 files changed

+112
-55
lines changed

3 files changed

+112
-55
lines changed

mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -310,10 +310,8 @@ class TransformState {
310310
/// with the type of the handle value.
311311
LogicalResult mapBlockArguments(BlockArgument argument,
312312
ArrayRef<Operation *> operations) {
313-
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
314-
assert(argument.getParentRegion() == regionStack.back() &&
313+
assert(argument.getParentRegion() == regionStack.back()->region &&
315314
"mapping block arguments from a region other than the active one");
316-
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
317315
return setPayloadOps(argument, operations);
318316
}
319317
LogicalResult mapBlockArgument(BlockArgument argument,
@@ -350,9 +348,7 @@ class TransformState {
350348
std::make_pair(&region, std::make_unique<Mappings>()));
351349
assert(res.second && "the region scope is already present");
352350
(void)res;
353-
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
354-
state.regionStack.push_back(&region);
355-
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
351+
state.regionStack.push_back(this);
356352
}
357353

358354
/// Back-reference to the transform state.
@@ -361,7 +357,10 @@ class TransformState {
361357
/// The region this scope is associated with.
362358
Region *region;
363359

364-
friend RegionScope TransformState::make_region_scope(Region &);
360+
/// The transform op within this region that is currently being applied.
361+
TransformOpInterface currentTransform;
362+
363+
friend class transform::TransformState;
365364
};
366365
friend class RegionScope;
367366

@@ -784,12 +783,14 @@ class TransformState {
784783
/// location.
785784
InvalidatedHandleMap invalidatedHandles;
786785

787-
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
788786
/// A stack of nested regions that are being processed in the transform IR.
789787
/// Each region must be an ancestor of the following regions in this list.
790788
/// These are also the keys for "mappings".
791-
SmallVector<Region *> regionStack;
792-
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
789+
SmallVector<RegionScope *> regionStack;
790+
791+
/// The top-level region scope. The first (bottom) element of `regionStack`
792+
/// is the top-level region scope object.
793+
std::unique_ptr<RegionScope> topLevelRegionScope;
793794
};
794795

795796
/// Local mapping between values defined by a specific op implementing the
@@ -926,8 +927,14 @@ TransformState::RegionScope TransformState::make_region_scope(Region &region) {
926927
class TrackingListener : public RewriterBase::Listener,
927928
public TransformState::Extension {
928929
public:
930+
/// A function that returns "true" for handles that do not have to be updated.
931+
using SkipHandleFn = std::function<bool(Value)>;
932+
929933
/// Create a new TrackingListener for usage in the specified transform op.
930-
TrackingListener(TransformState &state, TransformOpInterface op);
934+
/// Optionally, a function can be specified to identify handles that should
935+
/// do not have to be updated.
936+
TrackingListener(TransformState &state, TransformOpInterface op,
937+
SkipHandleFn skipHandleFn = nullptr);
931938

932939
protected:
933940
/// Return a replacement payload op for the given op, which is going to be
@@ -1015,6 +1022,10 @@ class TrackingListener : public RewriterBase::Listener,
10151022

10161023
/// The handles that are consumed by the transform op.
10171024
DenseSet<Value> consumedHandles;
1025+
1026+
/// Handles for which this function evaluates to "true" do not have to be
1027+
/// updated. These are typically dead or consumed handles.
1028+
SkipHandleFn skipHandleFn;
10181029
};
10191030

10201031
/// A specialized listener that keeps track of cases in which no replacement

mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp

Lines changed: 61 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,23 @@
3030

3131
using namespace mlir;
3232

33+
//===----------------------------------------------------------------------===//
34+
// Helper functions
35+
//===----------------------------------------------------------------------===//
36+
37+
/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
38+
/// properly dominates `b` and `b` is not inside `a`.
39+
static bool happensBefore(Operation *a, Operation *b) {
40+
do {
41+
if (a->isProperAncestor(b))
42+
return false;
43+
if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) {
44+
return a->isBeforeInBlock(bAncestor);
45+
}
46+
} while ((a = a->getParentOp()));
47+
return false;
48+
}
49+
3350
//===----------------------------------------------------------------------===//
3451
// TransformState
3552
//===----------------------------------------------------------------------===//
@@ -44,14 +61,10 @@ transform::TransformState::TransformState(
4461
topLevelMappedValues.reserve(extraMappings.size());
4562
for (ArrayRef<MappedValue> mapping : extraMappings)
4663
topLevelMappedValues.push_back(mapping);
47-
48-
auto result =
49-
mappings.insert(std::make_pair(region, std::make_unique<Mappings>()));
50-
assert(result.second && "the region scope is already present");
51-
(void)result;
52-
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
53-
regionStack.push_back(region);
54-
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
64+
if (region) {
65+
RegionScope *scope = new RegionScope(*this, *region);
66+
topLevelRegionScope.reset(scope);
67+
}
5568
}
5669

5770
Operation *transform::TransformState::getTopLevel() const { return topLevel; }
@@ -811,6 +824,11 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
811824
LLVM_DEBUG(DBGS() << "Failing Top-level payload:\n"; getTopLevel()->print(
812825
llvm::dbgs(), mlir::OpPrintingFlags().printGenericOpForm()););
813826
});
827+
828+
// Set current transform op.
829+
regionStack.back()->currentTransform = transform;
830+
831+
// Expensive checks to detect invalid transform IR.
814832
if (options.getExpensiveChecksEnabled()) {
815833
FULL_LDBG("ExpensiveChecksEnabled\n");
816834
if (failed(checkAndRecordHandleInvalidation(transform)))
@@ -899,7 +917,24 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
899917
}
900918

901919
// Prepare rewriter and listener.
902-
transform::ErrorCheckingTrackingListener trackingListener(*this, transform);
920+
TrackingListener::SkipHandleFn skipHandleFn = [&](Value handle) {
921+
// Skip handle if it is dead.
922+
auto scopeIt =
923+
llvm::find_if(llvm::reverse(regionStack), [&](RegionScope *scope) {
924+
return handle.getParentRegion() == scope->region;
925+
});
926+
assert(scopeIt != regionStack.rend() &&
927+
"could not find region scope for handle");
928+
RegionScope *scope = *scopeIt;
929+
for (Operation *user : handle.getUsers()) {
930+
if (user != scope->currentTransform &&
931+
!happensBefore(user, scope->currentTransform))
932+
return false;
933+
}
934+
return true;
935+
};
936+
transform::ErrorCheckingTrackingListener trackingListener(*this, transform,
937+
skipHandleFn);
903938
transform::TransformRewriter rewriter(transform->getContext(),
904939
&trackingListener);
905940

@@ -1040,10 +1075,7 @@ transform::TransformState::RegionScope::~RegionScope() {
10401075
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
10411076

10421077
state.mappings.erase(region);
1043-
1044-
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
10451078
state.regionStack.pop_back();
1046-
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
10471079
}
10481080

10491081
//===----------------------------------------------------------------------===//
@@ -1150,8 +1182,10 @@ bool transform::TransformResults::isSet(unsigned resultNumber) const {
11501182
//===----------------------------------------------------------------------===//
11511183

11521184
transform::TrackingListener::TrackingListener(TransformState &state,
1153-
TransformOpInterface op)
1154-
: TransformState::Extension(state), transformOp(op) {
1185+
TransformOpInterface op,
1186+
SkipHandleFn skipHandleFn)
1187+
: TransformState::Extension(state), transformOp(op),
1188+
skipHandleFn(skipHandleFn) {
11551189
if (op) {
11561190
for (OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) {
11571191
consumedHandles.insert(opOperand->get());
@@ -1251,19 +1285,6 @@ void transform::TrackingListener::notifyOperationRemoved(Operation *op) {
12511285
});
12521286
}
12531287

1254-
/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
1255-
/// properly dominates `b` and `b` is not inside `a`.
1256-
static bool happensBefore(Operation *a, Operation *b) {
1257-
do {
1258-
if (a->isProperAncestor(b))
1259-
return false;
1260-
if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) {
1261-
return a->isBeforeInBlock(bAncestor);
1262-
}
1263-
} while ((a = a->getParentOp()));
1264-
return false;
1265-
}
1266-
12671288
void transform::TrackingListener::notifyOperationReplaced(
12681289
Operation *op, ValueRange newValues) {
12691290
assert(op->getNumResults() == newValues.size() &&
@@ -1295,18 +1316,17 @@ void transform::TrackingListener::notifyOperationReplaced(
12951316
[&](Value h) { return consumedHandles.contains(h); });
12961317
};
12971318

1298-
// Helper function to check if the handle is alive.
1299-
auto firstAliveUser = [&]() -> std::optional<OpOperand *> {
1300-
for (Value v : opHandles) {
1301-
for (OpOperand &use : v.getUses())
1302-
if (use.getOwner() != transformOp &&
1303-
!happensBefore(use.getOwner(), transformOp))
1304-
return &use;
1305-
}
1306-
return std::nullopt;
1307-
}();
1308-
1309-
if (!firstAliveUser.has_value() || handleWasConsumed()) {
1319+
// Check if there are any handles that must be updated.
1320+
Value aliveHandle;
1321+
if (skipHandleFn) {
1322+
auto it =
1323+
llvm::find_if(opHandles, [&](Value v) { return !skipHandleFn(v); });
1324+
if (it != opHandles.end())
1325+
aliveHandle = *it;
1326+
} else if (!opHandles.empty()) {
1327+
aliveHandle = opHandles.front();
1328+
}
1329+
if (!aliveHandle || handleWasConsumed()) {
13101330
// The op is tracked but the corresponding handles are dead or were
13111331
// consumed. Drop the op form the mapping.
13121332
(void)replacePayloadOp(op, nullptr);
@@ -1319,10 +1339,8 @@ void transform::TrackingListener::notifyOperationReplaced(
13191339
// If the op is tracked but no replacement op was found, send a
13201340
// notification.
13211341
if (!diag.succeeded()) {
1322-
diag.attachNote((*firstAliveUser)->getOwner()->getLoc())
1323-
<< "replacement is required because alive handle(s) exist "
1324-
<< "(first use in this op as operand number "
1325-
<< (*firstAliveUser)->getOperandNumber() << ")";
1342+
diag.attachNote(aliveHandle.getLoc())
1343+
<< "replacement is required because this handle must be updated";
13261344
notifyPayloadReplacementNotFound(op, newValues, std::move(diag));
13271345
(void)replacePayloadOp(op, nullptr);
13281346
return;

mlir/test/Dialect/Transform/test-pattern-application.mlir

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ func.func @replacement_op_not_found() {
3636
transform.sequence failures(propagate) {
3737
^bb1(%arg1: !transform.any_op):
3838
%0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
39+
// expected-note @below {{replacement is required because this handle must be updated}}
3940
%1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
4041
// expected-error @below {{tracking listener failed to find replacement op during application of this transform op}}
4142
// expected-note @below {{ran out of suitable replacement values}}
@@ -44,7 +45,6 @@ transform.sequence failures(propagate) {
4445
} : !transform.any_op
4546
// %1 must be used in some way. If no replacement payload op could be found,
4647
// an error is thrown only if the handle is not dead.
47-
// expected-note @below {{replacement is required because alive handle(s) exist (first use in this op as operand number 0)}}
4848
transform.annotate %1 "annotated" : !transform.any_op
4949
}
5050

@@ -363,3 +363,31 @@ transform.sequence failures(propagate) {
363363
legal_ops = ["func.func", "func.return", "test.new_op"]}
364364
: !transform.any_op
365365
}
366+
367+
// -----
368+
369+
module attributes { transform.with_named_sequence } {
370+
func.func @replacement_op_not_found() {
371+
// No op replacement can be found, but there are no handles that must be
372+
// updated. No error should be reported.
373+
"test.container"() ({
374+
%0 = "test.foo"() {replace_with_new_op = "test.bar"} : () -> (i32)
375+
}) : () -> ()
376+
return
377+
}
378+
379+
transform.named_sequence @patterns(%container: !transform.any_op {transform.readonly}) {
380+
transform.apply_patterns to %container {
381+
transform.apply_patterns.transform.test_patterns
382+
} : !transform.any_op
383+
transform.yield
384+
}
385+
386+
transform.sequence failures(propagate) {
387+
^bb1(%arg1: !transform.any_op):
388+
%0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
389+
%1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
390+
transform.annotate %1 "annotated" : !transform.any_op
391+
transform.include @patterns failures(propagate) (%0) : (!transform.any_op) -> ()
392+
}
393+
}

0 commit comments

Comments
 (0)