Commit 4730954

[mlir][transform] Fix and improve "cached names" check
When running with "expensive checks", the transform dialect interpreter maintains a payload `Operation *` -> `OperationName` cache. This cache is used to detect invalid API usage such as missing/incorrect handle consumption/production side effects and/or payload IR modifications that bypass the rewriter. There was a bug in the check that could cause issues such as #72931. (#72986 was just a workaround and did not fix the underlying issue.)

- Payload ops mapped to newly produced handles are now added to the cache. This is in addition to adding/checking all mapped payload ops at the beginning of each transform op, for extra safety.
- Consumed ops (and their children) are now removed from the cache before applying the transform op. This used to happen after applying the transform op, which is incorrect in cases such as: (1) the transform op replaces a consumed payload op with another op, (2) the new op reuses the same memory pointer and (3) the new op is added to a newly produced handle. In such a case the previous implementation removed the newly created op from the cache.
- No assumptions can be made about whether an op should be in the cache or not. The code previously reported an error when an op was not found in the cache. This is problematic in cases such as: (1) the transform op consumes the handle mapped to a payload op A and (2) the implementation of the transform op removes/replaces an op nested within A, which is mapped to another handle. This triggers a listener notification, which tries to remove the nested op from the cache. However, because consumed ops (and their children) are removed from the cache before applying the transform op, the nested op will not be in the cache, so assuming that it is would be incorrect.
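
The ordering problem described above (removing consumed ops from the cache after, rather than before, applying the transform op) can be illustrated with a small standalone model. This is only a sketch under simplifying assumptions, not the MLIR implementation: `FakeOp`, `NameCache`, and `newOpStaysCached` are hypothetical names, the cache is a plain `std::unordered_map`, and address reuse is modelled by renaming one object in place.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for a payload op: only its address and name matter.
struct FakeOp {
  std::string name;
};

// Simplified model of the `Operation *` -> name cache.
using NameCache = std::unordered_map<const FakeOp *, std::string>;

// Models a transform that consumes an op, replaces it with a new op at the
// same address, and maps the new op to a newly produced handle.
// Returns true if the new op is still correctly cached afterwards.
bool newOpStaysCached(bool evictConsumedFirst) {
  FakeOp slot{"scf.for"};
  const FakeOp *consumed = &slot;
  NameCache cache{{consumed, slot.name}};

  // Fixed order: drop consumed ops from the cache before the transform runs.
  if (evictConsumedFirst)
    cache.erase(consumed);

  // The "transform": the old op is gone, a different op now lives at the
  // same address, and it is cached because it is mapped to a new handle.
  slot.name = "scf.while";
  const FakeOp *created = &slot;
  assert(created == consumed && "model assumes the address is reused");
  cache[created] = created->name;

  // Old order: drop consumed ops only after the transform has run. The
  // pointer is the same, so this evicts the entry of the new op.
  if (!evictConsumedFirst)
    cache.erase(consumed);

  auto it = cache.find(created);
  return it != cache.end() && it->second == created->name;
}
```

In this model, `newOpStaysCached(true)` keeps the new op in the cache, while `newOpStaysCached(false)` loses it, which mirrors the failure mode the commit message attributes to the previous implementation.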
1 parent 56c72c7 commit 4730954

1 file changed: +41 -67 lines changed

mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp

Lines changed: 41 additions & 67 deletions
@@ -228,6 +228,22 @@ transform::TransformState::setPayloadOps(Value value,
   for (Operation *op : targets)
     mappings.reverse[op].push_back(value);

+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+  if (options.getExpensiveChecksEnabled()) {
+    for (Operation *op : targets) {
+      auto insertion = cachedNames.insert({op, op->getName()});
+      if (!insertion.second) {
+        if (insertion.first->second != op->getName()) {
+          // Operation is already in the cache, but with a different name.
+          return emitError(value.getLoc())
+                 << "expensive checks failure: operation mismatch, expected "
+                 << insertion.first->second;
+        }
+      }
+    }
+  }
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+
   return success();
 }

@@ -389,15 +405,20 @@ transform::TransformState::replacePayloadOp(Operation *op,
 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
   if (options.getExpensiveChecksEnabled()) {
     auto it = cachedNames.find(op);
-    assert(it != cachedNames.end() && "entry not found");
-    assert(it->second == op->getName() && "operation name mismatch");
-    cachedNames.erase(it);
-    if (replacement) {
-      auto insertion =
-          cachedNames.insert({replacement, replacement->getName()});
-      if (!insertion.second) {
-        assert(insertion.first->second == replacement->getName() &&
-               "operation is already cached with a different name");
+    // Payload ops (and their children) mapped to consumed handles were already
+    // removed from the cache. We can make no assumption about which ops are in
+    // the cache and which are not. But if an op is in the cache, the name must
+    // match.
+    if (it != cachedNames.end()) {
+      assert(it->second == op->getName() && "operation name mismatch");
+      cachedNames.erase(it);
+      if (replacement) {
+        auto insertion =
+            cachedNames.insert({replacement, replacement->getName()});
+        if (!insertion.second) {
+          assert(insertion.first->second == replacement->getName() &&
+                 "operation is already cached with a different name");
+        }
       }
     }
   }
@@ -908,20 +929,16 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
   // IR after that.
   SmallVector<Value> origOpFlatResults;
   SmallVector<Operation *> origAssociatedOps;
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
-  DenseSet<Operation *> consumedPayloadOps;
-#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
   for (OpOperand *opOperand : consumedOperands) {
     Value operand = opOperand->get();
     if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
       for (Operation *payloadOp : getPayloadOps(operand)) {
         llvm::append_range(origOpFlatResults, payloadOp->getResults());
 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
         if (options.getExpensiveChecksEnabled()) {
-          // Store all consumed payload ops (and their nested ops) in a set for
-          // extra error checking.
-          payloadOp->walk(
-              [&](Operation *op) { consumedPayloadOps.insert(op); });
+          // Remove all consumed payload ops (and their nested ops) from the
+          // name cache.
+          payloadOp->walk([&](Operation *op) { cachedNames.erase(op); });
         }
 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
       }
@@ -1004,46 +1021,23 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
     }
   }

+  if (failed(updateStateFromResults(results, transform->getResults())))
+    return DiagnosedSilenceableFailure::definiteFailure();
+
 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
   if (options.getExpensiveChecksEnabled()) {
-    // Remove erased ops from the transform state.
-    for (Operation *op : consumedPayloadOps) {
-      // This payload op was consumed but it may still be mapped to one or
-      // multiple handles. Forget all handles that are mapped to the op, so that
-      // there are no dangling pointers in the transform dialect state. This is
-      // necessary so that the `cachedNames`-based checks work correctly.
-      //
-      // Note: Dangling pointers to erased payload ops are allowed if the
-      // corresponding handles are not used anymore. There is another
-      // "expensive-check" that looks for future uses of dangling payload op
-      // pointers (through arbitrary handles). Removing handles to erased ops
-      // does not interfere with the other expensive checks: handle invalidation
-      // happens earlier and keeps track of invalidated handles with
-      // pre-generated error messages, so we do not need the association to
-      // still be there when the invalidated handle is accessed.
-      SmallVector<Value> handles;
-      (void)getHandlesForPayloadOp(op, handles, /*includeOutOfScope=*/true);
-      for (Value handle : handles)
-        forgetMapping(handle, /*origOpFlatResults=*/ValueRange(),
-                      /*allowOutOfScope=*/true);
-      cachedNames.erase(op);
-    }
-
     // Check cached operation names.
     for (std::unique_ptr<Mappings> &mapping :
          llvm::make_second_range(mappings)) {
       for (Operation *op : llvm::make_first_range(mapping->reverse)) {
         // Make sure that the name of the op has not changed. If it has changed,
         // the op was removed and a new op was allocated at the same memory
-        // location. This means that we are missing op tracking somewhere.
+        // location. This means that we are missing op tracking somewhere. We
+        // can make no assumption about which ops are in the cache and which are
+        // not. But if an op is in the cache, the name must match.
         auto cacheIt = cachedNames.find(op);
-        if (cacheIt == cachedNames.end()) {
-          DiagnosedDefiniteFailure diag =
-              emitDefiniteFailure(transform->getLoc())
-              << "expensive checks failure: operation not found in cache";
-          diag.attachNote(op->getLoc()) << "payload op";
-          return diag;
-        }
+        if (cacheIt == cachedNames.end())
+          continue;
         // If the `getName` call (or the above `attachNote`) is crashing, we
         // have a dangling pointer. This usually means that an op was erased but
         // the transform dialect was not made aware of that; e.g., missing
@@ -1061,9 +1055,6 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
   }
 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS

-  if (failed(updateStateFromResults(results, transform->getResults())))
-    return DiagnosedSilenceableFailure::definiteFailure();
-
   printOnFailureRAII.release();
   DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, {
     DBGS() << "Top-level payload:\n";
@@ -1140,26 +1131,9 @@ transform::TransformState::RegionScope::~RegionScope() {
     }
   }

-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
-  // Remember pointers to payload ops referenced by the handles going out of
-  // scope.
-  SmallVector<Operation *> referencedOps =
-      llvm::to_vector(llvm::make_first_range(state.mappings[region]->reverse));
-#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
-
   state.mappings.erase(region);

 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
-  // If the last handle to a payload op has gone out of scope, we no longer
-  // need to store the cached name. Pointers may get reused, leading to
-  // incorrect associations in the cache.
-  for (Operation *op : referencedOps) {
-    SmallVector<Value> handles;
-    if (succeeded(state.getHandlesForPayloadOp(op, handles)))
-      continue;
-    state.cachedNames.erase(op);
-  }
-
   state.regionStack.pop_back();
 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
 }
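
The two cache policies this diff settles on, "insert, or verify the name if already present" and "an absent entry is not an error, but a present entry must match", can be sketched in isolation as follows. This is a minimal standalone model, not MLIR code: `FakeOp`, `NameCache`, `cacheOrVerify`, `nameStillMatches`, and `usageExample` are hypothetical names, and `std::unordered_map` stands in for the real cache container.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for a payload op: only its address and name matter.
struct FakeOp {
  std::string name;
};

// Simplified model of the `Operation *` -> name cache.
using NameCache = std::unordered_map<const FakeOp *, std::string>;

// Insert-or-verify: add the op to the cache; if it is already cached, the
// cached name must equal the current name. Returns false on a mismatch.
bool cacheOrVerify(NameCache &cache, const FakeOp &op) {
  auto insertion = cache.insert({&op, op.name});
  return insertion.second || insertion.first->second == op.name;
}

// Lenient check: an op that is not in the cache is accepted, because no
// assumption is made about cache membership. A cached op must match.
bool nameStillMatches(const NameCache &cache, const FakeOp &op) {
  auto it = cache.find(&op);
  if (it == cache.end())
    return true;
  return it->second == op.name;
}

// Usage example; call from main() to exercise the helpers.
void usageExample() {
  FakeOp op{"func.func"};
  NameCache cache;
  assert(nameStillMatches(cache, op)); // not cached: accepted
  assert(cacheOrVerify(cache, op));    // first insertion succeeds
  op.name = "scf.for"; // simulate a different op reusing the address
  assert(!nameStillMatches(cache, op)); // stale entry is detected
}
```

The lenient lookup trades some detection power for robustness: a missing entry no longer reports a failure, so only stale entries (same pointer, different name) are flagged.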
