[mlir][transform] Fix and improve "cached names" check #73583

Closed

@matthias-springer (Member) commented Nov 27, 2023

When running with "expensive checks", the transform dialect interpreter
maintains a payload Operation * -> OperationName cache. This cache
is used to detect invalid API usage such as missing/incorrect handle
consumption/production side effects and/or payload IR modifications that
bypass the rewriter.
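
As a rough illustration of what such a cache can detect, the sketch below models it with a toy `ToyOp` struct and a `std::unordered_map`. None of these names come from MLIR; `ToyOp`, `NameCache`, `checkCachedName`, and `simulateUntrackedReplacement` are hypothetical, and placement new into a fixed buffer is used only to make pointer reuse deterministic.

```cpp
#include <cassert>
#include <new>
#include <string>
#include <unordered_map>

// Toy stand-in for a payload operation. This is NOT mlir::Operation.
struct ToyOp {
  std::string name;
};

// Toy stand-in for the `Operation *` -> `OperationName` cache.
using NameCache = std::unordered_map<const ToyOp *, std::string>;

enum class CacheCheck { NotCached, Match, Mismatch };

// Compare the cached name (if any) against the current name of the op.
// A missing entry is reported separately and is not treated as an error.
CacheCheck checkCachedName(const NameCache &cache, const ToyOp *op) {
  auto it = cache.find(op);
  if (it == cache.end())
    return CacheCheck::NotCached;
  return it->second == op->name ? CacheCheck::Match : CacheCheck::Mismatch;
}

// Simulate an op being destroyed and a different op being created at the
// same address without the cache being told about it.
CacheCheck simulateUntrackedReplacement() {
  alignas(ToyOp) unsigned char storage[sizeof(ToyOp)];
  ToyOp *op = new (storage) ToyOp{"toy.foo"};

  NameCache cache;
  cache.emplace(op, op->name);

  // "Erase" the op behind the cache's back and reuse its memory.
  op->~ToyOp();
  ToyOp *reused = new (storage) ToyOp{"toy.bar"};

  CacheCheck result = checkCachedName(cache, reused);
  reused->~ToyOp();
  return result;
}
```

In this model, a name mismatch for a reused pointer is the symptom that some modification bypassed whatever is responsible for keeping the cache up to date.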

A bug in the check could cause issues such as #72931. (#72986 was only a
workaround and did not fix the underlying issue.)

  • Payload ops mapped to newly produced handles are now added to the
    cache. This is in addition to adding/checking all mapped payload ops
    at the beginning of each transform op, for extra safety.
  • Remove consumed ops (and their children) before applying the transform
    op. This used to happen after applying the transform op, which is
    incorrect in cases such as: (1) transform op replaces a consumed
    payload op with another op, (2) the new op reuses the same memory
    pointer and (3) the new op is added to a newly produced handle. In
    such a case the previous implementation removed the newly created op
    from the cache.
  • No assumptions can be made about whether an op should be in the cache
    or not. The code previously reported an error when an op was not found
    in the cache. This is problematic in cases such as: (1) the transform
    op consumes the handle mapped to a payload op A and (2) the
    implementation of the transform op removes/replaces an op nested in
    A, which is mapped to another handle. This triggers a listener
    notification, which tries to remove the nested op from the cache.
    However, because consumed ops (and their children) are removed from
    the cache before applying the transform op, the nested op is no
    longer in the cache at that point, so treating a missing entry as an
    error would be incorrect.
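
The ordering issue from the second bullet can be sketched with a toy model. This is only an illustration under simplifying assumptions: `ToyOp`, `NameCache`, `replaceInPlace`, and `newOpCachedAfterApply` are hypothetical names, not MLIR APIs, and the "transform" is reduced to replacing one op with another at the same address.

```cpp
#include <cassert>
#include <new>
#include <string>
#include <unordered_map>

// Toy stand-in for a payload operation. This is NOT mlir::Operation.
struct ToyOp {
  std::string name;
};

using NameCache = std::unordered_map<const ToyOp *, std::string>;

// Hypothetical transform step: destroys the consumed op, creates a new op at
// the same address, and records the new op in the cache (as if it had been
// mapped to a newly produced handle).
ToyOp *replaceInPlace(void *storage, ToyOp *consumed, NameCache &cache) {
  consumed->~ToyOp();
  ToyOp *fresh = new (storage) ToyOp{"toy.new"};
  cache[fresh] = fresh->name;
  return fresh;
}

// Returns true if the new op is still correctly cached after the transform.
// `eraseBeforeApply` selects between the two orderings discussed above.
bool newOpCachedAfterApply(bool eraseBeforeApply) {
  alignas(ToyOp) unsigned char storage[sizeof(ToyOp)];
  ToyOp *consumed = new (storage) ToyOp{"toy.old"};

  NameCache cache;
  cache[consumed] = consumed->name;
  const ToyOp *consumedKey = consumed;

  if (eraseBeforeApply)
    cache.erase(consumedKey); // ordering after this change

  ToyOp *fresh = replaceInPlace(storage, consumed, cache);

  if (!eraseBeforeApply)
    cache.erase(consumedKey); // old ordering: also drops the new op's entry

  auto it = cache.find(fresh);
  bool cached = it != cache.end() && it->second == fresh->name;
  fresh->~ToyOp();
  return cached;
}
```

With erasure before the transform, the entry for the new op survives. With erasure afterwards, the stale key aliases the new op and its entry is dropped, which mirrors the failure mode described in the bullet.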

@llvmbot (Member) commented Nov 27, 2023

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Full diff: https://github.com/llvm/llvm-project/pull/73583.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp (+107-103)
  • (modified) mlir/test/Dialect/Transform/transform-state-extension.mlir (+3)
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index d0cd879d560c887..15d8a7e34c94032 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -30,6 +30,33 @@
 
 using namespace mlir;
 
+//===----------------------------------------------------------------------===//
+// Helper functions
+//===----------------------------------------------------------------------===//
+
+/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
+/// properly dominates `b` and `b` is not inside `a`.
+static bool happensBefore(Operation *a, Operation *b) {
+  do {
+    if (a->isProperAncestor(b))
+      return false;
+    if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) {
+      return a->isBeforeInBlock(bAncestor);
+    }
+  } while ((a = a->getParentOp()));
+  return false;
+}
+
+/// Return nullptr if `v` is dead (has no further uses) after `op`. Otherwise,
+/// return an arbitrary alive use. This return value is typically used in error
+/// messages or for debugging purposes.
+static OpOperand *getAliveUse(Value v, Operation *op) {
+  for (OpOperand &use : v.getUses())
+    if (use.getOwner() != op && !happensBefore(use.getOwner(), op))
+      return &use;
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // TransformState
 //===----------------------------------------------------------------------===//
@@ -216,6 +243,10 @@ transform::TransformState::setPayloadOps(Value value,
   if (failed(result.checkAndReport()))
     return failure();
 
+  // Do not maintain mappings for dead handles.
+  if (value.getUses().empty())
+    return success();
+
   // Setting new payload for the value without cleaning it first is a misuse of
   // the API, assert here.
   SmallVector<Operation *> storedTargets(targets.begin(), targets.end());
@@ -228,6 +259,22 @@ transform::TransformState::setPayloadOps(Value value,
   for (Operation *op : targets)
     mappings.reverse[op].push_back(value);
 
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+  if (options.getExpensiveChecksEnabled()) {
+    for (Operation *op : targets) {
+      auto insertion = cachedNames.insert({op, op->getName()});
+      if (!insertion.second) {
+        if (insertion.first->second != op->getName()) {
+          // Operation is already in the cache, but with a different name.
+          return emitError(value.getLoc())
+                 << "expensive checks failure: operation mismatch, expected "
+                 << insertion.first->second;
+        }
+      }
+    }
+  }
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+
   return success();
 }
 
@@ -252,6 +299,10 @@ transform::TransformState::setPayloadValues(Value handle,
   if (failed(result.checkAndReport()))
     return failure();
 
+  // Do not maintain mappings for dead handles.
+  if (handle.getUses().empty())
+    return success();
+
   Mappings &mappings = getMapping(handle);
   bool inserted =
       mappings.values.insert({handle, std::move(payloadValueVector)}).second;
@@ -285,6 +336,10 @@ LogicalResult transform::TransformState::setParams(Value value,
   if (failed(result.checkAndReport()))
     return failure();
 
+  // Do not maintain mappings for dead handles.
+  if (value.getUses().empty())
+    return success();
+
   Mappings &mappings = getMapping(value);
   bool inserted =
       mappings.params.insert({value, llvm::to_vector(params)}).second;
@@ -389,15 +444,20 @@ transform::TransformState::replacePayloadOp(Operation *op,
 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
   if (options.getExpensiveChecksEnabled()) {
     auto it = cachedNames.find(op);
-    assert(it != cachedNames.end() && "entry not found");
-    assert(it->second == op->getName() && "operation name mismatch");
-    cachedNames.erase(it);
-    if (replacement) {
-      auto insertion =
-          cachedNames.insert({replacement, replacement->getName()});
-      if (!insertion.second) {
-        assert(insertion.first->second == replacement->getName() &&
-               "operation is already cached with a different name");
+    // Payload ops (and their children) mapped to consumed handles were already
+    // removed from the cache. We can make no assumption about which ops are in
+    // the cache and which are not. But if an op is in the cache, the name must
+    // match.
+    if (it != cachedNames.end()) {
+      assert(it->second == op->getName() && "operation name mismatch");
+      cachedNames.erase(it);
+      if (replacement) {
+        auto insertion =
+            cachedNames.insert({replacement, replacement->getName()});
+        if (!insertion.second) {
+          assert(insertion.first->second == replacement->getName() &&
+                 "operation is already cached with a different name");
+        }
       }
     }
   }
@@ -494,10 +554,10 @@ void transform::TransformState::recordOpHandleInvalidationOne(
   unsigned operandNo = consumingHandle.getOperandNumber();
   for (Operation *ancestor : potentialAncestors) {
     // clang-format off
-    DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, 
+    DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
       { (DBGS() << "----handle one ancestor: " << *ancestor << "\n"); });
-    DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, 
-      { (DBGS() << "----of payload with name: " 
+    DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
+      { (DBGS() << "----of payload with name: "
                 << payloadOp->getName().getIdentifier() << "\n"); });
     DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
       { (DBGS() << "----of payload: " << *payloadOp << "\n"); });
@@ -908,9 +968,6 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
   // IR after that.
   SmallVector<Value> origOpFlatResults;
   SmallVector<Operation *> origAssociatedOps;
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
-  DenseSet<Operation *> consumedPayloadOps;
-#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
   for (OpOperand *opOperand : consumedOperands) {
     Value operand = opOperand->get();
     if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
@@ -918,10 +975,9 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
         llvm::append_range(origOpFlatResults, payloadOp->getResults());
 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
         if (options.getExpensiveChecksEnabled()) {
-          // Store all consumed payload ops (and their nested ops) in a set for
-          // extra error checking.
-          payloadOp->walk(
-              [&](Operation *op) { consumedPayloadOps.insert(op); });
+          // Remove all consumed payload ops (and their nested ops) from the
+          // name cache.
+          payloadOp->walk([&](Operation *op) { cachedNames.erase(op); });
         }
 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
       }
@@ -992,58 +1048,41 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
   if (result.isSilenceableFailure())
     results.setRemainingToEmpty(transform);
 
-  // Remove the mapping for the operand if it is consumed by the operation. This
-  // allows us to catch use-after-free with assertions later on.
-  for (OpOperand *opOperand : consumedOperands) {
-    Value operand = opOperand->get();
+  // Remove the mapping for the operand if it is consumed by the operation. Also
+  // remove the mapping for handles that are now dead. This allows us to catch
+  // use-after-free with assertions later on.
+  for (OpOperand &opOperand : transform->getOpOperands()) {
+    Value operand = opOperand.get();
+    if (getAliveUse(operand, transform) != nullptr)
+      continue;
+    bool wasConsumed = llvm::is_contained(consumedOperands, &opOperand);
     if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
-      forgetMapping(operand, origOpFlatResults);
+      forgetMapping(operand,
+                    wasConsumed ? ValueRange(origOpFlatResults) : ValueRange());
     } else if (llvm::isa<TransformValueHandleTypeInterface>(
                    operand.getType())) {
-      forgetValueMapping(operand, origAssociatedOps);
+      forgetValueMapping(operand, wasConsumed ? ArrayRef(origAssociatedOps)
+                                              : ArrayRef<Operation *>());
     }
   }
 
+  if (failed(updateStateFromResults(results, transform->getResults())))
+    return DiagnosedSilenceableFailure::definiteFailure();
+
 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
   if (options.getExpensiveChecksEnabled()) {
-    // Remove erased ops from the transform state.
-    for (Operation *op : consumedPayloadOps) {
-      // This payload op was consumed but it may still be mapped to one or
-      // multiple handles. Forget all handles that are mapped to the op, so that
-      // there are no dangling pointers in the transform dialect state. This is
-      // necessary so that the `cachedNames`-based checks work correctly.
-      //
-      // Note: Dangling pointers to erased payload ops are allowed if the
-      // corresponding handles are not used anymore. There is another
-      // "expensive-check" that looks for future uses of dangling payload op
-      // pointers (through arbitrary handles). Removing handles to erased ops
-      // does not interfere with the other expensive checks: handle invalidation
-      // happens earlier and keeps track of invalidated handles with
-      // pre-generated error messages, so we do not need the association to
-      // still be there when the invalidated handle is accessed.
-      SmallVector<Value> handles;
-      (void)getHandlesForPayloadOp(op, handles, /*includeOutOfScope=*/true);
-      for (Value handle : handles)
-        forgetMapping(handle, /*origOpFlatResults=*/ValueRange(),
-                      /*allowOutOfScope=*/true);
-      cachedNames.erase(op);
-    }
-
     // Check cached operation names.
     for (std::unique_ptr<Mappings> &mapping :
          llvm::make_second_range(mappings)) {
       for (Operation *op : llvm::make_first_range(mapping->reverse)) {
         // Make sure that the name of the op has not changed. If it has changed,
         // the op was removed and a new op was allocated at the same memory
-        // location. This means that we are missing op tracking somewhere.
+        // location. This means that we are missing op tracking somewhere. We
+        // can make no assumption about which ops are in the cache and which are
+        // not. But if an op is in the cache, the name must match.
         auto cacheIt = cachedNames.find(op);
-        if (cacheIt == cachedNames.end()) {
-          DiagnosedDefiniteFailure diag =
-              emitDefiniteFailure(transform->getLoc())
-              << "expensive checks failure: operation not found in cache";
-          diag.attachNote(op->getLoc()) << "payload op";
-          return diag;
-        }
+        if (cacheIt == cachedNames.end())
+          continue;
         // If the `getName` call (or the above `attachNote`) is crashing, we
         // have a dangling pointer. This usually means that an op was erased but
         // the transform dialect was not made aware of that; e.g., missing
@@ -1061,9 +1100,6 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
   }
 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
 
-  if (failed(updateStateFromResults(results, transform->getResults())))
-    return DiagnosedSilenceableFailure::definiteFailure();
-
   printOnFailureRAII.release();
   DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, {
     DBGS() << "Top-level payload:\n";
@@ -1140,26 +1176,9 @@ transform::TransformState::RegionScope::~RegionScope() {
     }
   }
 
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
-  // Remember pointers to payload ops referenced by the handles going out of
-  // scope.
-  SmallVector<Operation *> referencedOps =
-      llvm::to_vector(llvm::make_first_range(state.mappings[region]->reverse));
-#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
-
   state.mappings.erase(region);
 
 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
-  // If the last handle to a payload op has gone out of scope, we no longer
-  // need to store the cached name. Pointers may get reused, leading to
-  // incorrect associations in the cache.
-  for (Operation *op : referencedOps) {
-    SmallVector<Value> handles;
-    if (succeeded(state.getHandlesForPayloadOp(op, handles)))
-      continue;
-    state.cachedNames.erase(op);
-  }
-
   state.regionStack.pop_back();
 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
 }
@@ -1369,19 +1388,6 @@ void transform::TrackingListener::notifyOperationRemoved(Operation *op) {
   });
 }
 
-/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
-/// properly dominates `b` and `b` is not inside `a`.
-static bool happensBefore(Operation *a, Operation *b) {
-  do {
-    if (a->isProperAncestor(b))
-      return false;
-    if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) {
-      return a->isBeforeInBlock(bAncestor);
-    }
-  } while ((a = a->getParentOp()));
-  return false;
-}
-
 void transform::TrackingListener::notifyOperationReplaced(
     Operation *op, ValueRange newValues) {
   assert(op->getNumResults() == newValues.size() &&
@@ -1413,20 +1419,18 @@ void transform::TrackingListener::notifyOperationReplaced(
                         [&](Value h) { return consumedHandles.contains(h); });
   };
 
-  // Helper function to check if the handle is alive.
-  auto firstAliveUser = [&]() -> std::optional<OpOperand *> {
-    for (Value v : opHandles) {
-      for (OpOperand &use : v.getUses())
-        if (use.getOwner() != transformOp &&
-            !happensBefore(use.getOwner(), transformOp))
-          return &use;
+  // Check if there are any live handles.
+  OpOperand *aliveUse = nullptr;
+  for (Value v : opHandles) {
+    if (OpOperand *use = getAliveUse(v, transformOp)) {
+      aliveUse = use;
+      break;
     }
-    return std::nullopt;
-  }();
+  }
 
-  if (!firstAliveUser.has_value() || handleWasConsumed()) {
-    // The op is tracked but the corresponding handles are dead or were
-    // consumed. Drop the op form the mapping.
+  if (!aliveUse || handleWasConsumed()) {
+    // The op is tracked but the corresponding handles are dead. Drop the op
+    // from the mapping.
     (void)replacePayloadOp(op, nullptr);
     return;
   }
@@ -1437,10 +1441,10 @@ void transform::TrackingListener::notifyOperationReplaced(
   // If the op is tracked but no replacement op was found, send a
   // notification.
   if (!diag.succeeded()) {
-    diag.attachNote((*firstAliveUser)->getOwner()->getLoc())
+    diag.attachNote(aliveUse->getOwner()->getLoc())
         << "replacement is required because alive handle(s) exist "
         << "(first use in this op as operand number "
-        << (*firstAliveUser)->getOperandNumber() << ")";
+        << aliveUse->getOperandNumber() << ")";
     notifyPayloadReplacementNotFound(op, newValues, std::move(diag));
     (void)replacePayloadOp(op, nullptr);
     return;
diff --git a/mlir/test/Dialect/Transform/transform-state-extension.mlir b/mlir/test/Dialect/Transform/transform-state-extension.mlir
index a26293fbe51ca61..cd115027d0f0002 100644
--- a/mlir/test/Dialect/Transform/transform-state-extension.mlir
+++ b/mlir/test/Dialect/Transform/transform-state-extension.mlir
@@ -76,6 +76,9 @@ transform.sequence failures(propagate) {
   %dummy = test_remap_operand_to_self %arg0 : (!transform.any_op) -> !transform.any_op
   %valuehandle = transform.get_result %dummy[0] : (!transform.any_op) -> !transform.any_value
   test_remap_operand_to_self %dummy : (!transform.any_op) -> ()
+  // Use %valuehandle so that the SSA value is not dead. This prevents the
+  // transform dialect interpreter from discarding the handle.
+  test_print_number_of_associated_payload_ir_values %valuehandle : !transform.any_value
 }
 
 // -----
