Skip to content

[mlir][transform] Improve error message of tracking listener. #66987

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -774,7 +774,8 @@ class TransformResults {
/// corresponds to the given list of payload IR ops. Each result must be set
/// by the transformation exactly once in case of transformation succeeding.
/// The value must have a type implementing TransformHandleTypeInterface.
template <typename Range> void set(OpResult value, Range &&ops) {
template <typename Range>
void set(OpResult value, Range &&ops) {
int64_t position = value.getResultNumber();
assert(position < static_cast<int64_t>(operations.size()) &&
"setting results for a non-existent handle");
Expand Down Expand Up @@ -929,8 +930,9 @@ class TrackingListener : public RewriterBase::Listener,
///
/// Derived classes may override `findReplacementOp` to specify custom
/// replacement rules.
virtual FailureOr<Operation *> findReplacementOp(Operation *op,
ValueRange newValues) const;
virtual DiagnosedSilenceableFailure
findReplacementOp(Operation *&result, Operation *op,
ValueRange newValues) const;

/// Notify the listener that the pattern failed to match the given operation,
/// and provide a callback to populate a diagnostic with the reason why the
Expand All @@ -942,8 +944,9 @@ class TrackingListener : public RewriterBase::Listener,
/// This function is called when a tracked payload op is dropped because no
/// replacement op was found. Derived classes can implement this function for
/// custom error handling.
virtual void notifyPayloadReplacementNotFound(Operation *op,
ValueRange values) {}
virtual void
notifyPayloadReplacementNotFound(Operation *op, ValueRange values,
DiagnosedSilenceableFailure &&diag) {}

/// Return the single op that defines all given values (if any).
static Operation *getCommonDefiningOp(ValueRange values);
Expand Down Expand Up @@ -983,8 +986,9 @@ class ErrorCheckingTrackingListener : public TrackingListener {
bool failed() const;

protected:
void notifyPayloadReplacementNotFound(Operation *op,
ValueRange values) override;
void
notifyPayloadReplacementNotFound(Operation *op, ValueRange values,
DiagnosedSilenceableFailure &&diag) override;

private:
/// The error state of this listener. "Success" indicates that no error
Expand Down
78 changes: 51 additions & 27 deletions mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1289,45 +1289,59 @@ Operation *transform::TrackingListener::getCommonDefiningOp(ValueRange values) {
return defOp;
}

FailureOr<Operation *>
transform::TrackingListener::findReplacementOp(Operation *op,
ValueRange newValues) const {
DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp(
Operation *&result, Operation *op, ValueRange newValues) const {
assert(op->getNumResults() == newValues.size() &&
"invalid number of replacement values");
SmallVector<Value> values(newValues.begin(), newValues.end());

DiagnosedSilenceableFailure diag = emitSilenceableFailure(
getTransformOp(), "tracking listener failed to find replacement op "
"during application of this transform op");

do {
// If the replacement values belong to different ops, drop the mapping.
Operation *defOp = getCommonDefiningOp(values);
if (!defOp)
return failure();
if (!defOp) {
diag.attachNote() << "replacement values belong to different ops";
return diag;
}

// If the defining op has the same type, we take it as a replacement.
if (op->getName() == defOp->getName())
return defOp;
if (op->getName() == defOp->getName()) {
result = defOp;
return DiagnosedSilenceableFailure::success();
}

// Replacing an op with a constant-like equivalent is a common
// canonicalization.
if (defOp->hasTrait<OpTrait::ConstantLike>())
return defOp;
if (defOp->hasTrait<OpTrait::ConstantLike>()) {
result = defOp;
return DiagnosedSilenceableFailure::success();
}

values.clear();

// Skip through ops that implement FindPayloadReplacementOpInterface.
if (auto findReplacementOpInterface =
dyn_cast<FindPayloadReplacementOpInterface>(defOp)) {
values.assign(findReplacementOpInterface.getNextOperands());
diag.attachNote(defOp->getLoc()) << "using operands provided by "
"'FindPayloadReplacementOpInterface'";
continue;
}

// Skip through ops that implement CastOpInterface.
if (isa<CastOpInterface>(defOp)) {
values.assign(defOp->getOperands().begin(), defOp->getOperands().end());
diag.attachNote(defOp->getLoc())
<< "using output of 'CastOpInterface' op";
continue;
}
} while (!values.empty());

return failure();
diag.attachNote() << "ran out of suitable replacement values";
return diag;
}

LogicalResult transform::TrackingListener::notifyMatchFailure(
Expand Down Expand Up @@ -1396,32 +1410,39 @@ void transform::TrackingListener::notifyOperationReplaced(
};

// Helper function to check if the handle is alive.
auto hasAliveUser = [&]() {
auto firstAliveUser = [&]() -> std::optional<OpOperand *> {
for (Value v : opHandles) {
for (Operation *user : v.getUsers())
if (user != transformOp && !happensBefore(user, transformOp))
return true;
for (OpOperand &use : v.getUses())
if (use.getOwner() != transformOp &&
!happensBefore(use.getOwner(), transformOp))
return &use;
}
return false;
};
return std::nullopt;
}();

if (!hasAliveUser() || handleWasConsumed()) {
if (!firstAliveUser.has_value() || handleWasConsumed()) {
// The op is tracked but the corresponding handles are dead or were
// consumed. Drop the op form the mapping.
(void)replacePayloadOp(op, nullptr);
return;
}

FailureOr<Operation *> replacement = findReplacementOp(op, newValues);
Operation *replacement;
DiagnosedSilenceableFailure diag =
findReplacementOp(replacement, op, newValues);
// If the op is tracked but no replacement op was found, send a
// notification.
if (failed(replacement)) {
notifyPayloadReplacementNotFound(op, newValues);
if (!diag.succeeded()) {
diag.attachNote((*firstAliveUser)->getOwner()->getLoc())
<< "replacement is required because alive handle(s) exist "
<< "(first use in this op as operand number "
<< (*firstAliveUser)->getOperandNumber() << ")";
notifyPayloadReplacementNotFound(op, newValues, std::move(diag));
(void)replacePayloadOp(op, nullptr);
return;
}

(void)replacePayloadOp(op, *replacement);
(void)replacePayloadOp(op, replacement);
}

transform::ErrorCheckingTrackingListener::~ErrorCheckingTrackingListener() {
Expand All @@ -1444,17 +1465,20 @@ bool transform::ErrorCheckingTrackingListener::failed() const {
}

void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound(
Operation *op, ValueRange values) {
if (status.succeeded()) {
status = emitSilenceableFailure(
getTransformOp(), "tracking listener failed to find replacement op");
Copy link
Member

@matthias-springer matthias-springer Sep 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe "tracking listener failed to find replacement op for op that is tracked by non-dead handle".

But even that could be confusing. Maybe we should just add a comment in the C++ code that mentions all the conditions (type changed, handle not dead, etc.).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking literally returning DiagnosedSilecenableFailure instead of LogicalResult, then the replacement lookup can specify the message.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All of that is good input. I now assemble the error message in different places, each contributing the information they have. Please take another look.

}
Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) {

// Merge potentially existing diags and store the result in the listener.
SmallVector<Diagnostic> diags;
diag.takeDiagnostics(diags);
if (!status.succeeded())
status.takeDiagnostics(diags);
status = DiagnosedSilenceableFailure::silenceableFailure(std::move(diags));

// Report more details.
status.attachNote(op->getLoc()) << "[" << errorCounter << "] replaced op";
for (auto &&[index, value] : llvm::enumerate(values))
status.attachNote(value.getLoc())
<< "[" << errorCounter << "] replacement value " << index;
Comment on lines 1478 to 1481
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This information could (almost) be added to the diagnostics findReplacementOp; however, errorCounter is not available there. Are we sure we need this? (There are no tests where errorCounter > 0.) Note that some of the diagnostics that this PR currently adds also don't report the current error count...


++errorCounter;
}

Expand Down
4 changes: 3 additions & 1 deletion mlir/test/Dialect/Transform/test-pattern-application.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,14 @@ transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// expected-error @below {{tracking listener failed to find replacement op}}
// expected-error @below {{tracking listener failed to find replacement op during application of this transform op}}
// expected-note @below {{ran out of suitable replacement values}}
transform.apply_patterns to %0 {
transform.apply_patterns.transform.test_patterns
} : !transform.any_op
// %1 must be used in some way. If no replacement payload op could be found,
// an error is thrown only if the handle is not dead.
// expected-note @below {{replacement is required because alive handle(s) exist (first use in this op as operand number 0)}}
transform.annotate %1 "annotated" : !transform.any_op
}

Expand Down
23 changes: 16 additions & 7 deletions mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ struct TestTensorTransforms
TestTensorTransforms(const TestTensorTransforms &pass) : PassWrapper(pass) {}

void getDependentDialects(DialectRegistry &registry) const override {
registry
.insert<arith::ArithDialect, scf::SCFDialect, linalg::LinalgDialect>();
registry.insert<arith::ArithDialect, scf::SCFDialect, linalg::LinalgDialect,
transform::TransformDialect>();
}

StringRef getArgument() const final {
Expand Down Expand Up @@ -292,10 +292,10 @@ class DummyTrackingListener : public transform::TrackingListener {

// Expose `findReplacementOp` as a public function, so that it can be tested.
Operation *getReplacementOp(Operation *op, ValueRange newValues) const {
FailureOr<Operation *> replacementOp = findReplacementOp(op, newValues);
if (failed(replacementOp))
Operation *replacementOp;
if (!findReplacementOp(replacementOp, op, newValues).succeeded())
return nullptr;
return *replacementOp;
return replacementOp;
}
};
} // namespace
Expand Down Expand Up @@ -352,8 +352,17 @@ static LogicalResult testTrackingListenerReplacements(Operation *rootOp) {
transform::TransformState transformState =
transform::detail::makeTransformStateForTesting(/*region=*/nullptr,
/*payloadRoot=*/nullptr);
DummyTrackingListener listener(transformState,
transform::TransformOpInterface());
MLIRContext *context = rootOp->getContext();
OpBuilder builder(context);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This OpBuilder has no insertion point. That's likely what caused the leak.

auto transformOp = builder.create<transform::NamedSequenceOp>(
rootOp->getLoc(),
/*sym_name=*/"test_sequence",
/*function_type=*/
TypeAttr::get(FunctionType::get(context, TypeRange{}, TypeRange{})),
/*sym_visibility*/ StringAttr::get(context, "public"),
/*arg_attrs=*/ArrayAttr::get(context, ArrayRef<Attribute>()),
/*res_attrs=*/ArrayAttr::get(context, ArrayRef<Attribute>()));
DummyTrackingListener listener(transformState, transformOp);
Operation *replacement = listener.getReplacementOp(replaced, replacements);
if (!replacement) {
replaced->emitError("listener could not find replacement op");
Expand Down