Skip to content

[mlir][transform] TrackingListener: Improve dead handles detection #74290

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

matthias-springer
Copy link
Member

The tracking listener should not report op replacement errors for payload ops that are not mapped to any live handles. The handle liveless analysis did not work properly with transform IR that has named sequences.

A handle is live if it has a user after the transform op that is currently being applied. With named sequences, we need to maintain a stack of currently applied transform ops. That stack already exists (regionStack), the only thing that's missing is the current transform op for each stack frame.

This commit fixes #72931.

@llvmbot
Copy link
Member

llvmbot commented Dec 4, 2023

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

The tracking listener should not report op replacement errors for payload ops that are not mapped to any live handles. The handle liveless analysis did not work properly with transform IR that has named sequences.

A handle is live if it has a user after the transform op that is currently being applied. With named sequences, we need to maintain a stack of currently applied transform ops. That stack already exists (regionStack), the only thing that's missing is the current transform op for each stack frame.

This commit fixes #72931.


Full diff: https://github.com/llvm/llvm-project/pull/74290.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h (+19-11)
  • (modified) mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp (+61-43)
  • (modified) mlir/test/Dialect/Transform/test-pattern-application.mlir (+29-1)
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 2fdc15db9ad85..35de8a2e1fa5f 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -310,10 +310,8 @@ class TransformState {
   /// with the type of the handle value.
   LogicalResult mapBlockArguments(BlockArgument argument,
                                   ArrayRef<Operation *> operations) {
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
-    assert(argument.getParentRegion() == regionStack.back() &&
+    assert(argument.getParentRegion() == regionStack.back()->region &&
            "mapping block arguments from a region other than the active one");
-#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
     return setPayloadOps(argument, operations);
   }
   LogicalResult mapBlockArgument(BlockArgument argument,
@@ -350,9 +348,7 @@ class TransformState {
           std::make_pair(&region, std::make_unique<Mappings>()));
       assert(res.second && "the region scope is already present");
       (void)res;
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
-      state.regionStack.push_back(&region);
-#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+      state.regionStack.push_back(this);
     }
 
     /// Back-reference to the transform state.
@@ -361,7 +357,9 @@ class TransformState {
     /// The region this scope is associated with.
     Region *region;
 
-    friend RegionScope TransformState::make_region_scope(Region &);
+    TransformOpInterface currentTransform;
+
+    friend class transform::TransformState;
   };
   friend class RegionScope;
 
@@ -784,12 +782,12 @@ class TransformState {
   /// location.
   InvalidatedHandleMap invalidatedHandles;
 
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
   /// A stack of nested regions that are being processed in the transform IR.
   /// Each region must be an ancestor of the following regions in this list.
   /// These are also the keys for "mappings".
-  SmallVector<Region *> regionStack;
-#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+  SmallVector<RegionScope *> regionStack;
+
+  std::unique_ptr<RegionScope> topLevelRegionScope;
 };
 
 /// Local mapping between values defined by a specific op implementing the
@@ -926,8 +924,14 @@ TransformState::RegionScope TransformState::make_region_scope(Region &region) {
 class TrackingListener : public RewriterBase::Listener,
                          public TransformState::Extension {
 public:
+  /// A function that returns "true" for handles that do not have to be updated.
+  using SkipHandleFn = std::function<bool(Value)>;
+
   /// Create a new TrackingListener for usage in the specified transform op.
-  TrackingListener(TransformState &state, TransformOpInterface op);
+  /// Optionally, a function can be specified to identify handles that should
+  /// do not have to be updated.
+  TrackingListener(TransformState &state, TransformOpInterface op,
+                   SkipHandleFn skipHandleFn = nullptr);
 
 protected:
   /// Return a replacement payload op for the given op, which is going to be
@@ -1015,6 +1019,10 @@ class TrackingListener : public RewriterBase::Listener,
 
   /// The handles that are consumed by the transform op.
   DenseSet<Value> consumedHandles;
+
+  /// Handles for which this function evaluates to "true" do not have to be
+  /// updated. These are typically dead or consumed handles.
+  SkipHandleFn skipHandleFn;
 };
 
 /// A specialized listener that keeps track of cases in which no replacement
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index de5b7a81286bc..cd66a0e566f6c 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -30,6 +30,23 @@
 
 using namespace mlir;
 
+//===----------------------------------------------------------------------===//
+// Helper functions
+//===----------------------------------------------------------------------===//
+
+/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
+/// properly dominates `b` and `b` is not inside `a`.
+static bool happensBefore(Operation *a, Operation *b) {
+  do {
+    if (a->isProperAncestor(b))
+      return false;
+    if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) {
+      return a->isBeforeInBlock(bAncestor);
+    }
+  } while ((a = a->getParentOp()));
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // TransformState
 //===----------------------------------------------------------------------===//
@@ -44,14 +61,10 @@ transform::TransformState::TransformState(
   topLevelMappedValues.reserve(extraMappings.size());
   for (ArrayRef<MappedValue> mapping : extraMappings)
     topLevelMappedValues.push_back(mapping);
-
-  auto result =
-      mappings.insert(std::make_pair(region, std::make_unique<Mappings>()));
-  assert(result.second && "the region scope is already present");
-  (void)result;
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
-  regionStack.push_back(region);
-#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+  if (region) {
+    RegionScope *scope = new RegionScope(*this, *region);
+    topLevelRegionScope.reset(scope);
+  }
 }
 
 Operation *transform::TransformState::getTopLevel() const { return topLevel; }
@@ -811,6 +824,11 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
     LLVM_DEBUG(DBGS() << "Failing Top-level payload:\n"; getTopLevel()->print(
         llvm::dbgs(), mlir::OpPrintingFlags().printGenericOpForm()););
   });
+
+  // Set current transform op.
+  regionStack.back()->currentTransform = transform;
+
+  // Expensive checks to detect invalid transform IR.
   if (options.getExpensiveChecksEnabled()) {
     FULL_LDBG("ExpensiveChecksEnabled\n");
     if (failed(checkAndRecordHandleInvalidation(transform)))
@@ -899,7 +917,24 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
   }
 
   // Prepare rewriter and listener.
-  transform::ErrorCheckingTrackingListener trackingListener(*this, transform);
+  TrackingListener::SkipHandleFn skipHandleFn = [&](Value handle) {
+    // Skip handle if it is dead.
+    auto scopeIt =
+        llvm::find_if(llvm::reverse(regionStack), [&](RegionScope *scope) {
+          return handle.getParentRegion() == scope->region;
+        });
+    assert(scopeIt != regionStack.rend() &&
+           "could not find region scope for handle");
+    RegionScope *scope = *scopeIt;
+    for (Operation *user : handle.getUsers()) {
+      if (user != scope->currentTransform &&
+          !happensBefore(user, scope->currentTransform))
+        return false;
+    }
+    return true;
+  };
+  transform::ErrorCheckingTrackingListener trackingListener(*this, transform,
+                                                            skipHandleFn);
   transform::TransformRewriter rewriter(transform->getContext(),
                                         &trackingListener);
 
@@ -1040,10 +1075,7 @@ transform::TransformState::RegionScope::~RegionScope() {
 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
 
   state.mappings.erase(region);
-
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
   state.regionStack.pop_back();
-#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
 }
 
 //===----------------------------------------------------------------------===//
@@ -1150,8 +1182,10 @@ bool transform::TransformResults::isSet(unsigned resultNumber) const {
 //===----------------------------------------------------------------------===//
 
 transform::TrackingListener::TrackingListener(TransformState &state,
-                                              TransformOpInterface op)
-    : TransformState::Extension(state), transformOp(op) {
+                                              TransformOpInterface op,
+                                              SkipHandleFn skipHandleFn)
+    : TransformState::Extension(state), transformOp(op),
+      skipHandleFn(skipHandleFn) {
   if (op) {
     for (OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) {
       consumedHandles.insert(opOperand->get());
@@ -1251,19 +1285,6 @@ void transform::TrackingListener::notifyOperationRemoved(Operation *op) {
   });
 }
 
-/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
-/// properly dominates `b` and `b` is not inside `a`.
-static bool happensBefore(Operation *a, Operation *b) {
-  do {
-    if (a->isProperAncestor(b))
-      return false;
-    if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) {
-      return a->isBeforeInBlock(bAncestor);
-    }
-  } while ((a = a->getParentOp()));
-  return false;
-}
-
 void transform::TrackingListener::notifyOperationReplaced(
     Operation *op, ValueRange newValues) {
   assert(op->getNumResults() == newValues.size() &&
@@ -1295,18 +1316,17 @@ void transform::TrackingListener::notifyOperationReplaced(
                         [&](Value h) { return consumedHandles.contains(h); });
   };
 
-  // Helper function to check if the handle is alive.
-  auto firstAliveUser = [&]() -> std::optional<OpOperand *> {
-    for (Value v : opHandles) {
-      for (OpOperand &use : v.getUses())
-        if (use.getOwner() != transformOp &&
-            !happensBefore(use.getOwner(), transformOp))
-          return &use;
-    }
-    return std::nullopt;
-  }();
-
-  if (!firstAliveUser.has_value() || handleWasConsumed()) {
+  // Check if there are any handles that must be updated.
+  Value aliveHandle;
+  if (skipHandleFn) {
+    auto it =
+        llvm::find_if(opHandles, [&](Value v) { return !skipHandleFn(v); });
+    if (it != opHandles.end())
+      aliveHandle = *it;
+  } else if (!opHandles.empty()) {
+    aliveHandle = opHandles.front();
+  }
+  if (!aliveHandle || handleWasConsumed()) {
     // The op is tracked but the corresponding handles are dead or were
     // consumed. Drop the op form the mapping.
     (void)replacePayloadOp(op, nullptr);
@@ -1319,10 +1339,8 @@ void transform::TrackingListener::notifyOperationReplaced(
   // If the op is tracked but no replacement op was found, send a
   // notification.
   if (!diag.succeeded()) {
-    diag.attachNote((*firstAliveUser)->getOwner()->getLoc())
-        << "replacement is required because alive handle(s) exist "
-        << "(first use in this op as operand number "
-        << (*firstAliveUser)->getOperandNumber() << ")";
+    diag.attachNote(aliveHandle.getLoc())
+        << "replacement is required because this handle must be updated";
     notifyPayloadReplacementNotFound(op, newValues, std::move(diag));
     (void)replacePayloadOp(op, nullptr);
     return;
diff --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir
index 2d57d4aa2547f..2fd47c6bae396 100644
--- a/mlir/test/Dialect/Transform/test-pattern-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir
@@ -36,6 +36,7 @@ func.func @replacement_op_not_found() {
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  // expected-note @below {{replacement is required because this handle must be updated}}
   %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   // expected-error @below {{tracking listener failed to find replacement op during application of this transform op}}
   // expected-note @below {{ran out of suitable replacement values}}
@@ -44,7 +45,6 @@ transform.sequence failures(propagate) {
   } : !transform.any_op
   // %1 must be used in some way. If no replacement payload op could be found,
   // an error is thrown only if the handle is not dead.
-  // expected-note @below {{replacement is required because alive handle(s) exist (first use in this op as operand number 0)}}
   transform.annotate %1 "annotated" : !transform.any_op
 }
 
@@ -363,3 +363,31 @@ transform.sequence failures(propagate) {
      legal_ops = ["func.func", "func.return", "test.new_op"]}
       : !transform.any_op
 }
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+func.func @replacement_op_not_found() {
+  // No op replacement can be found, but there are no handles that must be
+  // updated. No error should be reported.
+  "test.container"() ({
+    %0 = "test.foo"() {replace_with_new_op = "test.bar"} : () -> (i32)
+  }) : () -> ()
+  return
+}
+
+transform.named_sequence @patterns(%container: !transform.any_op {transform.readonly}) {
+  transform.apply_patterns to %container {
+    transform.apply_patterns.transform.test_patterns
+  } : !transform.any_op
+  transform.yield
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.annotate %1 "annotated" : !transform.any_op
+  transform.include @patterns failures(propagate) (%0) : (!transform.any_op) -> ()
+}
+}

Copy link
Member

@ftynse ftynse left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should update the liveness check to be a dataflow analysis (it's an almost canonical one) that we run once before the interpreter starts.

Comment on lines -1324 to -1325
<< "(first use in this op as operand number "
<< (*firstAliveUser)->getOperandNumber() << ")";
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any chance to preserve this? Feels like useful information...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried to, but it would make the API quite complex. SkipHandleFn is pretty simple in the current implementation: is the handle needed or not. It would have to return a user, but the name "skip handle function" would not make sense anymore and we would have to mention the concept of liveness in the TrackingListener.

The tracking listener should not report op replacement errors for
payload ops that are not mapped to any live handles. The handle liveless
analysis did not work properly with transform IR that has named
sequences.

A handle is live if it has a user after the transform op that is
currently being applied. With named sequences, we need to maintain a
stack of currently applied transform ops. That stack already exists
(`regionStack`), the only thing that's missing is the current transform
op for each stack frame.

This commit fixes llvm#72931.
@matthias-springer matthias-springer force-pushed the tracking_listener_regions branch from 4ce7606 to 649c9e0 Compare December 6, 2023 07:31
@matthias-springer
Copy link
Member Author

Maybe we should update the liveness check to be a dataflow analysis (it's an almost canonical one) that we run once before the interpreter starts.

I'm going to look at this once I understood the dataflow analysis framework and its API.

@matthias-springer matthias-springer merged commit e8ae0e7 into llvm:main Dec 6, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Failing test case with transform dialect due to potentially mis-tracked operation.
3 participants