Skip to content

Commit ccfc2d6

Browse files
[mlir][transform] Remove cachedNames expensive check (#73961)
This check was trying to find cases of invalid API usage: incorrect/missing handle side effects and/or incorrect rewriter usage. This check is not implemented correctly and can report false positives in case of pointer reuse (different op created at same location). It is unclear if such a check can be implemented given that we have both tracking listener-based handle updates and handle consumption. Fixes #72931.
1 parent fc74db4 commit ccfc2d6

File tree

2 files changed

+3
-133
lines changed

2 files changed

+3
-133
lines changed

mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -789,18 +789,6 @@ class TransformState {
789789
/// Each region must be an ancestor of the following regions in this list.
790790
/// These are also the keys for "mappings".
791791
SmallVector<Region *> regionStack;
792-
793-
/// This cache stores operation names for operations that are tracked in the
794-
/// transform dialect state. It is used to detect missing memory side effects
795-
/// and op tracking.
796-
///
797-
/// All tracked ops are added to this cache before a transform op is applied.
798-
/// After the application of the transform op, the names of all tracked ops
799-
/// are compared with the names in the cache. If there is a mismatch (or a
800-
/// crash), op tracking is missing somewhere. This is typically a missing
801-
/// "consumesHandle" side effect or a pattern that removes an op without
802-
/// notifying a TrackingListener.
803-
DenseMap<Operation *, OperationName> cachedNames;
804792
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
805793
};
806794

mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp

Lines changed: 3 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -386,23 +386,6 @@ transform::TransformState::replacePayloadOp(Operation *op,
386386
dropMappingEntry(mappings.reverse, op, handle);
387387
}
388388

389-
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
390-
if (options.getExpensiveChecksEnabled()) {
391-
auto it = cachedNames.find(op);
392-
assert(it != cachedNames.end() && "entry not found");
393-
assert(it->second == op->getName() && "operation name mismatch");
394-
cachedNames.erase(it);
395-
if (replacement) {
396-
auto insertion =
397-
cachedNames.insert({replacement, replacement->getName()});
398-
if (!insertion.second) {
399-
assert(insertion.first->second == replacement->getName() &&
400-
"operation is already cached with a different name");
401-
}
402-
}
403-
}
404-
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
405-
406389
// Replace the pointed-to object of all handles with the replacement object.
407390
// In case a payload op was erased (replacement object is nullptr), a nullptr
408391
// is stored in the mapping. These nullptrs are removed after each transform.
@@ -494,10 +477,10 @@ void transform::TransformState::recordOpHandleInvalidationOne(
494477
unsigned operandNo = consumingHandle.getOperandNumber();
495478
for (Operation *ancestor : potentialAncestors) {
496479
// clang-format off
497-
DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
480+
DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
498481
{ (DBGS() << "----handle one ancestor: " << *ancestor << "\n"); });
499-
DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
500-
{ (DBGS() << "----of payload with name: "
482+
DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
483+
{ (DBGS() << "----of payload with name: "
501484
<< payloadOp->getName().getIdentifier() << "\n"); });
502485
DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
503486
{ (DBGS() << "----of payload: " << *payloadOp << "\n"); });
@@ -872,29 +855,6 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
872855
FULL_LDBG("--not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
873856
}
874857
}
875-
876-
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
877-
// Cache Operation* -> OperationName mappings. These will be checked after
878-
// the transform has been applied to detect incorrect memory side effects
879-
// and missing op tracking.
880-
for (std::unique_ptr<Mappings> &mapping :
881-
llvm::make_second_range(mappings)) {
882-
for (Operation *op : llvm::make_first_range(mapping->reverse)) {
883-
auto insertion = cachedNames.insert({op, op->getName()});
884-
if (!insertion.second) {
885-
if (insertion.first->second != op->getName()) {
886-
// Operation is already in the cache, but with a different name.
887-
DiagnosedDefiniteFailure diag =
888-
emitDefiniteFailure(transform->getLoc())
889-
<< "expensive checks failure: operation mismatch, expected "
890-
<< insertion.first->second;
891-
diag.attachNote(op->getLoc()) << "payload op: " << op->getName();
892-
return diag;
893-
}
894-
}
895-
}
896-
}
897-
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
898858
}
899859

900860
// Find which operands are consumed.
@@ -908,22 +868,11 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
908868
// IR after that.
909869
SmallVector<Value> origOpFlatResults;
910870
SmallVector<Operation *> origAssociatedOps;
911-
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
912-
DenseSet<Operation *> consumedPayloadOps;
913-
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
914871
for (OpOperand *opOperand : consumedOperands) {
915872
Value operand = opOperand->get();
916873
if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
917874
for (Operation *payloadOp : getPayloadOps(operand)) {
918875
llvm::append_range(origOpFlatResults, payloadOp->getResults());
919-
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
920-
if (options.getExpensiveChecksEnabled()) {
921-
// Store all consumed payload ops (and their nested ops) in a set for
922-
// extra error checking.
923-
payloadOp->walk(
924-
[&](Operation *op) { consumedPayloadOps.insert(op); });
925-
}
926-
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
927876
}
928877
continue;
929878
}
@@ -1004,63 +953,6 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
1004953
}
1005954
}
1006955

1007-
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
1008-
if (options.getExpensiveChecksEnabled()) {
1009-
// Remove erased ops from the transform state.
1010-
for (Operation *op : consumedPayloadOps) {
1011-
// This payload op was consumed but it may still be mapped to one or
1012-
// multiple handles. Forget all handles that are mapped to the op, so that
1013-
// there are no dangling pointers in the transform dialect state. This is
1014-
// necessary so that the `cachedNames`-based checks work correctly.
1015-
//
1016-
// Note: Dangling pointers to erased payload ops are allowed if the
1017-
// corresponding handles are not used anymore. There is another
1018-
// "expensive-check" that looks for future uses of dangling payload op
1019-
// pointers (through arbitrary handles). Removing handles to erased ops
1020-
// does not interfere with the other expensive checks: handle invalidation
1021-
// happens earlier and keeps track of invalidated handles with
1022-
// pre-generated error messages, so we do not need the association to
1023-
// still be there when the invalidated handle is accessed.
1024-
SmallVector<Value> handles;
1025-
(void)getHandlesForPayloadOp(op, handles, /*includeOutOfScope=*/true);
1026-
for (Value handle : handles)
1027-
forgetMapping(handle, /*origOpFlatResults=*/ValueRange(),
1028-
/*allowOutOfScope=*/true);
1029-
cachedNames.erase(op);
1030-
}
1031-
1032-
// Check cached operation names.
1033-
for (std::unique_ptr<Mappings> &mapping :
1034-
llvm::make_second_range(mappings)) {
1035-
for (Operation *op : llvm::make_first_range(mapping->reverse)) {
1036-
// Make sure that the name of the op has not changed. If it has changed,
1037-
// the op was removed and a new op was allocated at the same memory
1038-
// location. This means that we are missing op tracking somewhere.
1039-
auto cacheIt = cachedNames.find(op);
1040-
if (cacheIt == cachedNames.end()) {
1041-
DiagnosedDefiniteFailure diag =
1042-
emitDefiniteFailure(transform->getLoc())
1043-
<< "expensive checks failure: operation not found in cache";
1044-
diag.attachNote(op->getLoc()) << "payload op";
1045-
return diag;
1046-
}
1047-
// If the `getName` call (or the above `attachNote`) is crashing, we
1048-
// have a dangling pointer. This usually means that an op was erased but
1049-
// the transform dialect was not made aware of that; e.g., missing
1050-
// "consumesHandle" or rewriter usage.
1051-
if (cacheIt->second != op->getName()) {
1052-
DiagnosedDefiniteFailure diag =
1053-
emitDefiniteFailure(transform->getLoc())
1054-
<< "expensive checks failure: operation mismatch, expected "
1055-
<< cacheIt->second;
1056-
diag.attachNote(op->getLoc()) << "payload op: " << op->getName();
1057-
return diag;
1058-
}
1059-
}
1060-
}
1061-
}
1062-
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
1063-
1064956
if (failed(updateStateFromResults(results, transform->getResults())))
1065957
return DiagnosedSilenceableFailure::definiteFailure();
1066958

@@ -1150,16 +1042,6 @@ transform::TransformState::RegionScope::~RegionScope() {
11501042
state.mappings.erase(region);
11511043

11521044
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
1153-
// If the last handle to a payload op has gone out of scope, we no longer
1154-
// need to store the cached name. Pointers may get reused, leading to
1155-
// incorrect associations in the cache.
1156-
for (Operation *op : referencedOps) {
1157-
SmallVector<Value> handles;
1158-
if (succeeded(state.getHandlesForPayloadOp(op, handles)))
1159-
continue;
1160-
state.cachedNames.erase(op);
1161-
}
1162-
11631045
state.regionStack.pop_back();
11641046
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
11651047
}

0 commit comments

Comments
 (0)