Skip to content

Commit 78c8ab5

Browse files
committed
Revert "[mlir][transform] Improve error message of tracking listener. (#66987)"
Breaks https://lab.llvm.org/buildbot/#/builders/5/builds/36953 This reverts commit a753045.
1 parent 7b70af2 commit 78c8ab5

File tree

4 files changed

+42
-81
lines changed

4 files changed

+42
-81
lines changed

mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -797,8 +797,7 @@ class TransformResults {
797797
/// corresponds to the given list of payload IR ops. Each result must be set
798798
/// by the transformation exactly once in case of transformation succeeding.
799799
/// The value must have a type implementing TransformHandleTypeInterface.
800-
template <typename Range>
801-
void set(OpResult value, Range &&ops) {
800+
template <typename Range> void set(OpResult value, Range &&ops) {
802801
int64_t position = value.getResultNumber();
803802
assert(position < static_cast<int64_t>(operations.size()) &&
804803
"setting results for a non-existent handle");
@@ -973,9 +972,8 @@ class TrackingListener : public RewriterBase::Listener,
973972
///
974973
/// Derived classes may override `findReplacementOp` to specify custom
975974
/// replacement rules.
976-
virtual DiagnosedSilenceableFailure
977-
findReplacementOp(Operation *&result, Operation *op,
978-
ValueRange newValues) const;
975+
virtual FailureOr<Operation *> findReplacementOp(Operation *op,
976+
ValueRange newValues) const;
979977

980978
/// Notify the listener that the pattern failed to match the given operation,
981979
/// and provide a callback to populate a diagnostic with the reason why the
@@ -987,9 +985,8 @@ class TrackingListener : public RewriterBase::Listener,
987985
/// This function is called when a tracked payload op is dropped because no
988986
/// replacement op was found. Derived classes can implement this function for
989987
/// custom error handling.
990-
virtual void
991-
notifyPayloadReplacementNotFound(Operation *op, ValueRange values,
992-
DiagnosedSilenceableFailure &&diag) {}
988+
virtual void notifyPayloadReplacementNotFound(Operation *op,
989+
ValueRange values) {}
993990

994991
/// Return the single op that defines all given values (if any).
995992
static Operation *getCommonDefiningOp(ValueRange values);
@@ -1029,9 +1026,8 @@ class ErrorCheckingTrackingListener : public TrackingListener {
10291026
bool failed() const;
10301027

10311028
protected:
1032-
void
1033-
notifyPayloadReplacementNotFound(Operation *op, ValueRange values,
1034-
DiagnosedSilenceableFailure &&diag) override;
1029+
void notifyPayloadReplacementNotFound(Operation *op,
1030+
ValueRange values) override;
10351031

10361032
private:
10371033
/// The error state of this listener. "Success" indicates that no error

mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp

Lines changed: 27 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1291,59 +1291,45 @@ Operation *transform::TrackingListener::getCommonDefiningOp(ValueRange values) {
12911291
return defOp;
12921292
}
12931293

1294-
DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp(
1295-
Operation *&result, Operation *op, ValueRange newValues) const {
1294+
FailureOr<Operation *>
1295+
transform::TrackingListener::findReplacementOp(Operation *op,
1296+
ValueRange newValues) const {
12961297
assert(op->getNumResults() == newValues.size() &&
12971298
"invalid number of replacement values");
12981299
SmallVector<Value> values(newValues.begin(), newValues.end());
12991300

1300-
DiagnosedSilenceableFailure diag = emitSilenceableFailure(
1301-
getTransformOp(), "tracking listener failed to find replacement op "
1302-
"during application of this transform op");
1303-
13041301
do {
13051302
// If the replacement values belong to different ops, drop the mapping.
13061303
Operation *defOp = getCommonDefiningOp(values);
1307-
if (!defOp) {
1308-
diag.attachNote() << "replacement values belong to different ops";
1309-
return diag;
1310-
}
1304+
if (!defOp)
1305+
return failure();
13111306

13121307
// If the defining op has the same type, we take it as a replacement.
1313-
if (op->getName() == defOp->getName()) {
1314-
result = defOp;
1315-
return DiagnosedSilenceableFailure::success();
1316-
}
1308+
if (op->getName() == defOp->getName())
1309+
return defOp;
13171310

13181311
// Replacing an op with a constant-like equivalent is a common
13191312
// canonicalization.
1320-
if (defOp->hasTrait<OpTrait::ConstantLike>()) {
1321-
result = defOp;
1322-
return DiagnosedSilenceableFailure::success();
1323-
}
1313+
if (defOp->hasTrait<OpTrait::ConstantLike>())
1314+
return defOp;
13241315

13251316
values.clear();
13261317

13271318
// Skip through ops that implement FindPayloadReplacementOpInterface.
13281319
if (auto findReplacementOpInterface =
13291320
dyn_cast<FindPayloadReplacementOpInterface>(defOp)) {
13301321
values.assign(findReplacementOpInterface.getNextOperands());
1331-
diag.attachNote(defOp->getLoc()) << "using operands provided by "
1332-
"'FindPayloadReplacementOpInterface'";
13331322
continue;
13341323
}
13351324

13361325
// Skip through ops that implement CastOpInterface.
13371326
if (isa<CastOpInterface>(defOp)) {
13381327
values.assign(defOp->getOperands().begin(), defOp->getOperands().end());
1339-
diag.attachNote(defOp->getLoc())
1340-
<< "using output of 'CastOpInterface' op";
13411328
continue;
13421329
}
13431330
} while (!values.empty());
13441331

1345-
diag.attachNote() << "ran out of suitable replacement values";
1346-
return diag;
1332+
return failure();
13471333
}
13481334

13491335
LogicalResult transform::TrackingListener::notifyMatchFailure(
@@ -1412,39 +1398,32 @@ void transform::TrackingListener::notifyOperationReplaced(
14121398
};
14131399

14141400
// Helper function to check if the handle is alive.
1415-
auto firstAliveUser = [&]() -> std::optional<OpOperand *> {
1401+
auto hasAliveUser = [&]() {
14161402
for (Value v : opHandles) {
1417-
for (OpOperand &use : v.getUses())
1418-
if (use.getOwner() != transformOp &&
1419-
!happensBefore(use.getOwner(), transformOp))
1420-
return &use;
1403+
for (Operation *user : v.getUsers())
1404+
if (user != transformOp && !happensBefore(user, transformOp))
1405+
return true;
14211406
}
1422-
return std::nullopt;
1423-
}();
1407+
return false;
1408+
};
14241409

1425-
if (!firstAliveUser.has_value() || handleWasConsumed()) {
1410+
if (!hasAliveUser() || handleWasConsumed()) {
14261411
// The op is tracked but the corresponding handles are dead or were
14271412
// consumed. Drop the op form the mapping.
14281413
(void)replacePayloadOp(op, nullptr);
14291414
return;
14301415
}
14311416

1432-
Operation *replacement;
1433-
DiagnosedSilenceableFailure diag =
1434-
findReplacementOp(replacement, op, newValues);
1417+
FailureOr<Operation *> replacement = findReplacementOp(op, newValues);
14351418
// If the op is tracked but no replacement op was found, send a
14361419
// notification.
1437-
if (!diag.succeeded()) {
1438-
diag.attachNote((*firstAliveUser)->getOwner()->getLoc())
1439-
<< "replacement is required because alive handle(s) exist "
1440-
<< "(first use in this op as operand number "
1441-
<< (*firstAliveUser)->getOperandNumber() << ")";
1442-
notifyPayloadReplacementNotFound(op, newValues, std::move(diag));
1420+
if (failed(replacement)) {
1421+
notifyPayloadReplacementNotFound(op, newValues);
14431422
(void)replacePayloadOp(op, nullptr);
14441423
return;
14451424
}
14461425

1447-
(void)replacePayloadOp(op, replacement);
1426+
(void)replacePayloadOp(op, *replacement);
14481427
}
14491428

14501429
transform::ErrorCheckingTrackingListener::~ErrorCheckingTrackingListener() {
@@ -1467,20 +1446,17 @@ bool transform::ErrorCheckingTrackingListener::failed() const {
14671446
}
14681447

14691448
void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound(
1470-
Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) {
1471-
1472-
// Merge potentially existing diags and store the result in the listener.
1473-
SmallVector<Diagnostic> diags;
1474-
diag.takeDiagnostics(diags);
1475-
if (!status.succeeded())
1476-
status.takeDiagnostics(diags);
1477-
status = DiagnosedSilenceableFailure::silenceableFailure(std::move(diags));
1449+
Operation *op, ValueRange values) {
1450+
if (status.succeeded()) {
1451+
status = emitSilenceableFailure(
1452+
getTransformOp(), "tracking listener failed to find replacement op");
1453+
}
14781454

1479-
// Report more details.
14801455
status.attachNote(op->getLoc()) << "[" << errorCounter << "] replaced op";
14811456
for (auto &&[index, value] : llvm::enumerate(values))
14821457
status.attachNote(value.getLoc())
14831458
<< "[" << errorCounter << "] replacement value " << index;
1459+
14841460
++errorCounter;
14851461
}
14861462

mlir/test/Dialect/Transform/test-pattern-application.mlir

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,12 @@ transform.sequence failures(propagate) {
3737
^bb1(%arg1: !transform.any_op):
3838
%0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
3939
%1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
40-
// expected-error @below {{tracking listener failed to find replacement op during application of this transform op}}
41-
// expected-note @below {{ran out of suitable replacement values}}
40+
// expected-error @below {{tracking listener failed to find replacement op}}
4241
transform.apply_patterns to %0 {
4342
transform.apply_patterns.transform.test_patterns
4443
} : !transform.any_op
4544
// %1 must be used in some way. If no replacement payload op could be found,
4645
// an error is thrown only if the handle is not dead.
47-
// expected-note @below {{replacement is required because alive handle(s) exist (first use in this op as operand number 0)}}
4846
transform.annotate %1 "annotated" : !transform.any_op
4947
}
5048

mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ struct TestTensorTransforms
3232
TestTensorTransforms(const TestTensorTransforms &pass) : PassWrapper(pass) {}
3333

3434
void getDependentDialects(DialectRegistry &registry) const override {
35-
registry.insert<arith::ArithDialect, scf::SCFDialect, linalg::LinalgDialect,
36-
transform::TransformDialect>();
35+
registry
36+
.insert<arith::ArithDialect, scf::SCFDialect, linalg::LinalgDialect>();
3737
}
3838

3939
StringRef getArgument() const final {
@@ -292,10 +292,10 @@ class DummyTrackingListener : public transform::TrackingListener {
292292

293293
// Expose `findReplacementOp` as a public function, so that it can be tested.
294294
Operation *getReplacementOp(Operation *op, ValueRange newValues) const {
295-
Operation *replacementOp;
296-
if (!findReplacementOp(replacementOp, op, newValues).succeeded())
295+
FailureOr<Operation *> replacementOp = findReplacementOp(op, newValues);
296+
if (failed(replacementOp))
297297
return nullptr;
298-
return replacementOp;
298+
return *replacementOp;
299299
}
300300
};
301301
} // namespace
@@ -352,17 +352,8 @@ static LogicalResult testTrackingListenerReplacements(Operation *rootOp) {
352352
transform::TransformState transformState =
353353
transform::detail::makeTransformStateForTesting(/*region=*/nullptr,
354354
/*payloadRoot=*/nullptr);
355-
MLIRContext *context = rootOp->getContext();
356-
OpBuilder builder(context);
357-
auto transformOp = builder.create<transform::NamedSequenceOp>(
358-
rootOp->getLoc(),
359-
/*sym_name=*/"test_sequence",
360-
/*function_type=*/
361-
TypeAttr::get(FunctionType::get(context, TypeRange{}, TypeRange{})),
362-
/*sym_visibility*/ StringAttr::get(context, "public"),
363-
/*arg_attrs=*/ArrayAttr::get(context, ArrayRef<Attribute>()),
364-
/*res_attrs=*/ArrayAttr::get(context, ArrayRef<Attribute>()));
365-
DummyTrackingListener listener(transformState, transformOp);
355+
DummyTrackingListener listener(transformState,
356+
transform::TransformOpInterface());
366357
Operation *replacement = listener.getReplacementOp(replaced, replacements);
367358
if (!replacement) {
368359
replaced->emitError("listener could not find replacement op");

0 commit comments

Comments
 (0)