Skip to content

Commit f40e620

Browse files
Reapply "[mlir][transform] Improve error message of tracking listener. (#66987)"
This commit reapplies #66987, which got original contained a memory leak and got reverted by 78c8ab5. The leak is now fixed. Original description: This PR extends the error message of the tracking listener when replacement ops cannot be found. That may happen if the applied patterns replace an op by an op of a different kind or by block arguments. However, this only matters if there are alive handles to the replaced op. The new error message mentions that explicitly and reports the alive handles.
1 parent 0219bd3 commit f40e620

File tree

4 files changed

+82
-42
lines changed

4 files changed

+82
-42
lines changed

mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -797,7 +797,8 @@ class TransformResults {
797797
/// corresponds to the given list of payload IR ops. Each result must be set
798798
/// by the transformation exactly once in case of transformation succeeding.
799799
/// The value must have a type implementing TransformHandleTypeInterface.
800-
template <typename Range> void set(OpResult value, Range &&ops) {
800+
template <typename Range>
801+
void set(OpResult value, Range &&ops) {
801802
int64_t position = value.getResultNumber();
802803
assert(position < static_cast<int64_t>(operations.size()) &&
803804
"setting results for a non-existent handle");
@@ -972,8 +973,9 @@ class TrackingListener : public RewriterBase::Listener,
972973
///
973974
/// Derived classes may override `findReplacementOp` to specify custom
974975
/// replacement rules.
975-
virtual FailureOr<Operation *> findReplacementOp(Operation *op,
976-
ValueRange newValues) const;
976+
virtual DiagnosedSilenceableFailure
977+
findReplacementOp(Operation *&result, Operation *op,
978+
ValueRange newValues) const;
977979

978980
/// Notify the listener that the pattern failed to match the given operation,
979981
/// and provide a callback to populate a diagnostic with the reason why the
@@ -985,8 +987,9 @@ class TrackingListener : public RewriterBase::Listener,
985987
/// This function is called when a tracked payload op is dropped because no
986988
/// replacement op was found. Derived classes can implement this function for
987989
/// custom error handling.
988-
virtual void notifyPayloadReplacementNotFound(Operation *op,
989-
ValueRange values) {}
990+
virtual void
991+
notifyPayloadReplacementNotFound(Operation *op, ValueRange values,
992+
DiagnosedSilenceableFailure &&diag) {}
990993

991994
/// Return the single op that defines all given values (if any).
992995
static Operation *getCommonDefiningOp(ValueRange values);
@@ -1026,8 +1029,9 @@ class ErrorCheckingTrackingListener : public TrackingListener {
10261029
bool failed() const;
10271030

10281031
protected:
1029-
void notifyPayloadReplacementNotFound(Operation *op,
1030-
ValueRange values) override;
1032+
void
1033+
notifyPayloadReplacementNotFound(Operation *op, ValueRange values,
1034+
DiagnosedSilenceableFailure &&diag) override;
10311035

10321036
private:
10331037
/// The error state of this listener. "Success" indicates that no error

mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp

Lines changed: 51 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1291,45 +1291,59 @@ Operation *transform::TrackingListener::getCommonDefiningOp(ValueRange values) {
12911291
return defOp;
12921292
}
12931293

1294-
FailureOr<Operation *>
1295-
transform::TrackingListener::findReplacementOp(Operation *op,
1296-
ValueRange newValues) const {
1294+
DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp(
1295+
Operation *&result, Operation *op, ValueRange newValues) const {
12971296
assert(op->getNumResults() == newValues.size() &&
12981297
"invalid number of replacement values");
12991298
SmallVector<Value> values(newValues.begin(), newValues.end());
13001299

1300+
DiagnosedSilenceableFailure diag = emitSilenceableFailure(
1301+
getTransformOp(), "tracking listener failed to find replacement op "
1302+
"during application of this transform op");
1303+
13011304
do {
13021305
// If the replacement values belong to different ops, drop the mapping.
13031306
Operation *defOp = getCommonDefiningOp(values);
1304-
if (!defOp)
1305-
return failure();
1307+
if (!defOp) {
1308+
diag.attachNote() << "replacement values belong to different ops";
1309+
return diag;
1310+
}
13061311

13071312
// If the defining op has the same type, we take it as a replacement.
1308-
if (op->getName() == defOp->getName())
1309-
return defOp;
1313+
if (op->getName() == defOp->getName()) {
1314+
result = defOp;
1315+
return DiagnosedSilenceableFailure::success();
1316+
}
13101317

13111318
// Replacing an op with a constant-like equivalent is a common
13121319
// canonicalization.
1313-
if (defOp->hasTrait<OpTrait::ConstantLike>())
1314-
return defOp;
1320+
if (defOp->hasTrait<OpTrait::ConstantLike>()) {
1321+
result = defOp;
1322+
return DiagnosedSilenceableFailure::success();
1323+
}
13151324

13161325
values.clear();
13171326

13181327
// Skip through ops that implement FindPayloadReplacementOpInterface.
13191328
if (auto findReplacementOpInterface =
13201329
dyn_cast<FindPayloadReplacementOpInterface>(defOp)) {
13211330
values.assign(findReplacementOpInterface.getNextOperands());
1331+
diag.attachNote(defOp->getLoc()) << "using operands provided by "
1332+
"'FindPayloadReplacementOpInterface'";
13221333
continue;
13231334
}
13241335

13251336
// Skip through ops that implement CastOpInterface.
13261337
if (isa<CastOpInterface>(defOp)) {
13271338
values.assign(defOp->getOperands().begin(), defOp->getOperands().end());
1339+
diag.attachNote(defOp->getLoc())
1340+
<< "using output of 'CastOpInterface' op";
13281341
continue;
13291342
}
13301343
} while (!values.empty());
13311344

1332-
return failure();
1345+
diag.attachNote() << "ran out of suitable replacement values";
1346+
return diag;
13331347
}
13341348

13351349
LogicalResult transform::TrackingListener::notifyMatchFailure(
@@ -1398,32 +1412,39 @@ void transform::TrackingListener::notifyOperationReplaced(
13981412
};
13991413

14001414
// Helper function to check if the handle is alive.
1401-
auto hasAliveUser = [&]() {
1415+
auto firstAliveUser = [&]() -> std::optional<OpOperand *> {
14021416
for (Value v : opHandles) {
1403-
for (Operation *user : v.getUsers())
1404-
if (user != transformOp && !happensBefore(user, transformOp))
1405-
return true;
1417+
for (OpOperand &use : v.getUses())
1418+
if (use.getOwner() != transformOp &&
1419+
!happensBefore(use.getOwner(), transformOp))
1420+
return &use;
14061421
}
1407-
return false;
1408-
};
1422+
return std::nullopt;
1423+
}();
14091424

1410-
if (!hasAliveUser() || handleWasConsumed()) {
1425+
if (!firstAliveUser.has_value() || handleWasConsumed()) {
14111426
// The op is tracked but the corresponding handles are dead or were
14121427
// consumed. Drop the op form the mapping.
14131428
(void)replacePayloadOp(op, nullptr);
14141429
return;
14151430
}
14161431

1417-
FailureOr<Operation *> replacement = findReplacementOp(op, newValues);
1432+
Operation *replacement;
1433+
DiagnosedSilenceableFailure diag =
1434+
findReplacementOp(replacement, op, newValues);
14181435
// If the op is tracked but no replacement op was found, send a
14191436
// notification.
1420-
if (failed(replacement)) {
1421-
notifyPayloadReplacementNotFound(op, newValues);
1437+
if (!diag.succeeded()) {
1438+
diag.attachNote((*firstAliveUser)->getOwner()->getLoc())
1439+
<< "replacement is required because alive handle(s) exist "
1440+
<< "(first use in this op as operand number "
1441+
<< (*firstAliveUser)->getOperandNumber() << ")";
1442+
notifyPayloadReplacementNotFound(op, newValues, std::move(diag));
14221443
(void)replacePayloadOp(op, nullptr);
14231444
return;
14241445
}
14251446

1426-
(void)replacePayloadOp(op, *replacement);
1447+
(void)replacePayloadOp(op, replacement);
14271448
}
14281449

14291450
transform::ErrorCheckingTrackingListener::~ErrorCheckingTrackingListener() {
@@ -1446,17 +1467,20 @@ bool transform::ErrorCheckingTrackingListener::failed() const {
14461467
}
14471468

14481469
void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound(
1449-
Operation *op, ValueRange values) {
1450-
if (status.succeeded()) {
1451-
status = emitSilenceableFailure(
1452-
getTransformOp(), "tracking listener failed to find replacement op");
1453-
}
1470+
Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) {
1471+
1472+
// Merge potentially existing diags and store the result in the listener.
1473+
SmallVector<Diagnostic> diags;
1474+
diag.takeDiagnostics(diags);
1475+
if (!status.succeeded())
1476+
status.takeDiagnostics(diags);
1477+
status = DiagnosedSilenceableFailure::silenceableFailure(std::move(diags));
14541478

1479+
// Report more details.
14551480
status.attachNote(op->getLoc()) << "[" << errorCounter << "] replaced op";
14561481
for (auto &&[index, value] : llvm::enumerate(values))
14571482
status.attachNote(value.getLoc())
14581483
<< "[" << errorCounter << "] replacement value " << index;
1459-
14601484
++errorCounter;
14611485
}
14621486

mlir/test/Dialect/Transform/test-pattern-application.mlir

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,14 @@ transform.sequence failures(propagate) {
3737
^bb1(%arg1: !transform.any_op):
3838
%0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
3939
%1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
40-
// expected-error @below {{tracking listener failed to find replacement op}}
40+
// expected-error @below {{tracking listener failed to find replacement op during application of this transform op}}
41+
// expected-note @below {{ran out of suitable replacement values}}
4142
transform.apply_patterns to %0 {
4243
transform.apply_patterns.transform.test_patterns
4344
} : !transform.any_op
4445
// %1 must be used in some way. If no replacement payload op could be found,
4546
// an error is thrown only if the handle is not dead.
47+
// expected-note @below {{replacement is required because alive handle(s) exist (first use in this op as operand number 0)}}
4648
transform.annotate %1 "annotated" : !transform.any_op
4749
}
4850

mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ struct TestTensorTransforms
3232
TestTensorTransforms(const TestTensorTransforms &pass) : PassWrapper(pass) {}
3333

3434
void getDependentDialects(DialectRegistry &registry) const override {
35-
registry
36-
.insert<arith::ArithDialect, scf::SCFDialect, linalg::LinalgDialect>();
35+
registry.insert<arith::ArithDialect, scf::SCFDialect, linalg::LinalgDialect,
36+
transform::TransformDialect>();
3737
}
3838

3939
StringRef getArgument() const final {
@@ -292,10 +292,10 @@ class DummyTrackingListener : public transform::TrackingListener {
292292

293293
// Expose `findReplacementOp` as a public function, so that it can be tested.
294294
Operation *getReplacementOp(Operation *op, ValueRange newValues) const {
295-
FailureOr<Operation *> replacementOp = findReplacementOp(op, newValues);
296-
if (failed(replacementOp))
295+
Operation *replacementOp;
296+
if (!findReplacementOp(replacementOp, op, newValues).succeeded())
297297
return nullptr;
298-
return *replacementOp;
298+
return replacementOp;
299299
}
300300
};
301301
} // namespace
@@ -352,8 +352,18 @@ static LogicalResult testTrackingListenerReplacements(Operation *rootOp) {
352352
transform::TransformState transformState =
353353
transform::detail::makeTransformStateForTesting(/*region=*/nullptr,
354354
/*payloadRoot=*/nullptr);
355-
DummyTrackingListener listener(transformState,
356-
transform::TransformOpInterface());
355+
MLIRContext *context = rootOp->getContext();
356+
OpBuilder builder(context);
357+
OwningOpRef<transform::NamedSequenceOp> transformOp =
358+
builder.create<transform::NamedSequenceOp>(
359+
rootOp->getLoc(),
360+
/*sym_name=*/"test_sequence",
361+
/*function_type=*/
362+
TypeAttr::get(FunctionType::get(context, TypeRange{}, TypeRange{})),
363+
/*sym_visibility*/ StringAttr::get(context, "public"),
364+
/*arg_attrs=*/ArrayAttr::get(context, ArrayRef<Attribute>()),
365+
/*res_attrs=*/ArrayAttr::get(context, ArrayRef<Attribute>()));
366+
DummyTrackingListener listener(transformState, transformOp.get());
357367
Operation *replacement = listener.getReplacementOp(replaced, replacements);
358368
if (!replacement) {
359369
replaced->emitError("listener could not find replacement op");

0 commit comments

Comments
 (0)