Skip to content

[mlir][transform] TrackingListener: Improve dead handles detection #74290

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 22 additions & 11 deletions mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -310,10 +310,8 @@ class TransformState {
/// with the type of the handle value.
LogicalResult mapBlockArguments(BlockArgument argument,
ArrayRef<Operation *> operations) {
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
assert(argument.getParentRegion() == regionStack.back() &&
assert(argument.getParentRegion() == regionStack.back()->region &&
"mapping block arguments from a region other than the active one");
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
return setPayloadOps(argument, operations);
}
LogicalResult mapBlockArgument(BlockArgument argument,
Expand Down Expand Up @@ -350,9 +348,7 @@ class TransformState {
std::make_pair(&region, std::make_unique<Mappings>()));
assert(res.second && "the region scope is already present");
(void)res;
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
state.regionStack.push_back(&region);
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
state.regionStack.push_back(this);
}

/// Back-reference to the transform state.
Expand All @@ -361,7 +357,10 @@ class TransformState {
/// The region this scope is associated with.
Region *region;

friend RegionScope TransformState::make_region_scope(Region &);
/// The transform op within this region that is currently being applied.
TransformOpInterface currentTransform;

friend class transform::TransformState;
};
friend class RegionScope;

Expand Down Expand Up @@ -784,12 +783,14 @@ class TransformState {
/// location.
InvalidatedHandleMap invalidatedHandles;

#if LLVM_ENABLE_ABI_BREAKING_CHECKS
/// A stack of nested regions that are being processed in the transform IR.
/// Each region must be an ancestor of the following regions in this list.
/// These are also the keys for "mappings".
SmallVector<Region *> regionStack;
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
SmallVector<RegionScope *> regionStack;

/// The top-level region scope. The first (bottom) element of `regionStack`
/// is the top-level region scope object.
std::unique_ptr<RegionScope> topLevelRegionScope;
};

/// Local mapping between values defined by a specific op implementing the
Expand Down Expand Up @@ -926,8 +927,14 @@ TransformState::RegionScope TransformState::make_region_scope(Region &region) {
class TrackingListener : public RewriterBase::Listener,
public TransformState::Extension {
public:
/// A function that returns "true" for handles that do not have to be updated.
using SkipHandleFn = std::function<bool(Value)>;

/// Create a new TrackingListener for usage in the specified transform op.
TrackingListener(TransformState &state, TransformOpInterface op);
/// Optionally, a function can be specified to identify handles that should
/// do not have to be updated.
TrackingListener(TransformState &state, TransformOpInterface op,
SkipHandleFn skipHandleFn = nullptr);

protected:
/// Return a replacement payload op for the given op, which is going to be
Expand Down Expand Up @@ -1015,6 +1022,10 @@ class TrackingListener : public RewriterBase::Listener,

/// The handles that are consumed by the transform op.
DenseSet<Value> consumedHandles;

/// Handles for which this function evaluates to "true" do not have to be
/// updated. These are typically dead or consumed handles.
SkipHandleFn skipHandleFn;
};

/// A specialized listener that keeps track of cases in which no replacement
Expand Down
104 changes: 61 additions & 43 deletions mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,23 @@

using namespace mlir;

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
/// properly dominates `b` and `b` is not inside `a`.
static bool happensBefore(Operation *a, Operation *b) {
do {
if (a->isProperAncestor(b))
return false;
if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) {
return a->isBeforeInBlock(bAncestor);
}
} while ((a = a->getParentOp()));
return false;
}

//===----------------------------------------------------------------------===//
// TransformState
//===----------------------------------------------------------------------===//
Expand All @@ -44,14 +61,10 @@ transform::TransformState::TransformState(
topLevelMappedValues.reserve(extraMappings.size());
for (ArrayRef<MappedValue> mapping : extraMappings)
topLevelMappedValues.push_back(mapping);

auto result =
mappings.insert(std::make_pair(region, std::make_unique<Mappings>()));
assert(result.second && "the region scope is already present");
(void)result;
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
regionStack.push_back(region);
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
if (region) {
RegionScope *scope = new RegionScope(*this, *region);
topLevelRegionScope.reset(scope);
}
}

Operation *transform::TransformState::getTopLevel() const { return topLevel; }
Expand Down Expand Up @@ -811,6 +824,11 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
LLVM_DEBUG(DBGS() << "Failing Top-level payload:\n"; getTopLevel()->print(
llvm::dbgs(), mlir::OpPrintingFlags().printGenericOpForm()););
});

// Set current transform op.
regionStack.back()->currentTransform = transform;

// Expensive checks to detect invalid transform IR.
if (options.getExpensiveChecksEnabled()) {
FULL_LDBG("ExpensiveChecksEnabled\n");
if (failed(checkAndRecordHandleInvalidation(transform)))
Expand Down Expand Up @@ -899,7 +917,24 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
}

// Prepare rewriter and listener.
transform::ErrorCheckingTrackingListener trackingListener(*this, transform);
TrackingListener::SkipHandleFn skipHandleFn = [&](Value handle) {
// Skip handle if it is dead.
auto scopeIt =
llvm::find_if(llvm::reverse(regionStack), [&](RegionScope *scope) {
return handle.getParentRegion() == scope->region;
});
assert(scopeIt != regionStack.rend() &&
"could not find region scope for handle");
RegionScope *scope = *scopeIt;
for (Operation *user : handle.getUsers()) {
if (user != scope->currentTransform &&
!happensBefore(user, scope->currentTransform))
return false;
}
return true;
};
transform::ErrorCheckingTrackingListener trackingListener(*this, transform,
skipHandleFn);
transform::TransformRewriter rewriter(transform->getContext(),
&trackingListener);

Expand Down Expand Up @@ -1040,10 +1075,7 @@ transform::TransformState::RegionScope::~RegionScope() {
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS

state.mappings.erase(region);

#if LLVM_ENABLE_ABI_BREAKING_CHECKS
state.regionStack.pop_back();
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1150,8 +1182,10 @@ bool transform::TransformResults::isSet(unsigned resultNumber) const {
//===----------------------------------------------------------------------===//

transform::TrackingListener::TrackingListener(TransformState &state,
TransformOpInterface op)
: TransformState::Extension(state), transformOp(op) {
TransformOpInterface op,
SkipHandleFn skipHandleFn)
: TransformState::Extension(state), transformOp(op),
skipHandleFn(skipHandleFn) {
if (op) {
for (OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) {
consumedHandles.insert(opOperand->get());
Expand Down Expand Up @@ -1251,19 +1285,6 @@ void transform::TrackingListener::notifyOperationRemoved(Operation *op) {
});
}

/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
/// properly dominates `b` and `b` is not inside `a`.
static bool happensBefore(Operation *a, Operation *b) {
do {
if (a->isProperAncestor(b))
return false;
if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) {
return a->isBeforeInBlock(bAncestor);
}
} while ((a = a->getParentOp()));
return false;
}

void transform::TrackingListener::notifyOperationReplaced(
Operation *op, ValueRange newValues) {
assert(op->getNumResults() == newValues.size() &&
Expand Down Expand Up @@ -1295,18 +1316,17 @@ void transform::TrackingListener::notifyOperationReplaced(
[&](Value h) { return consumedHandles.contains(h); });
};

// Helper function to check if the handle is alive.
auto firstAliveUser = [&]() -> std::optional<OpOperand *> {
for (Value v : opHandles) {
for (OpOperand &use : v.getUses())
if (use.getOwner() != transformOp &&
!happensBefore(use.getOwner(), transformOp))
return &use;
}
return std::nullopt;
}();

if (!firstAliveUser.has_value() || handleWasConsumed()) {
// Check if there are any handles that must be updated.
Value aliveHandle;
if (skipHandleFn) {
auto it =
llvm::find_if(opHandles, [&](Value v) { return !skipHandleFn(v); });
if (it != opHandles.end())
aliveHandle = *it;
} else if (!opHandles.empty()) {
aliveHandle = opHandles.front();
}
if (!aliveHandle || handleWasConsumed()) {
// The op is tracked but the corresponding handles are dead or were
// consumed. Drop the op form the mapping.
(void)replacePayloadOp(op, nullptr);
Expand All @@ -1319,10 +1339,8 @@ void transform::TrackingListener::notifyOperationReplaced(
// If the op is tracked but no replacement op was found, send a
// notification.
if (!diag.succeeded()) {
diag.attachNote((*firstAliveUser)->getOwner()->getLoc())
<< "replacement is required because alive handle(s) exist "
<< "(first use in this op as operand number "
<< (*firstAliveUser)->getOperandNumber() << ")";
Comment on lines -1324 to -1325
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any chance to preserve this? Feels like useful information...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried to, but it would make the API quite complex. SkipHandleFn is pretty simple in the current implementation: is the handle needed or not. It would have to return a user, but the name "skip handle function" would not make sense anymore and we would have to mention the concept of liveness in the TrackingListener.

diag.attachNote(aliveHandle.getLoc())
<< "replacement is required because this handle must be updated";
notifyPayloadReplacementNotFound(op, newValues, std::move(diag));
(void)replacePayloadOp(op, nullptr);
return;
Expand Down
30 changes: 29 additions & 1 deletion mlir/test/Dialect/Transform/test-pattern-application.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ func.func @replacement_op_not_found() {
transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// expected-note @below {{replacement is required because this handle must be updated}}
%1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// expected-error @below {{tracking listener failed to find replacement op during application of this transform op}}
// expected-note @below {{ran out of suitable replacement values}}
Expand All @@ -44,7 +45,6 @@ transform.sequence failures(propagate) {
} : !transform.any_op
// %1 must be used in some way. If no replacement payload op could be found,
// an error is thrown only if the handle is not dead.
// expected-note @below {{replacement is required because alive handle(s) exist (first use in this op as operand number 0)}}
transform.annotate %1 "annotated" : !transform.any_op
}

Expand Down Expand Up @@ -363,3 +363,31 @@ transform.sequence failures(propagate) {
legal_ops = ["func.func", "func.return", "test.new_op"]}
: !transform.any_op
}

// -----

module attributes { transform.with_named_sequence } {
func.func @replacement_op_not_found() {
// No op replacement can be found, but there are no handles that must be
// updated. No error should be reported.
"test.container"() ({
%0 = "test.foo"() {replace_with_new_op = "test.bar"} : () -> (i32)
}) : () -> ()
return
}

transform.named_sequence @patterns(%container: !transform.any_op {transform.readonly}) {
transform.apply_patterns to %container {
transform.apply_patterns.transform.test_patterns
} : !transform.any_op
transform.yield
}

transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.annotate %1 "annotated" : !transform.any_op
transform.include @patterns failures(propagate) (%0) : (!transform.any_op) -> ()
}
}