-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][transform] Improve error message of tracking listener. #66987
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1289,45 +1289,59 @@ Operation *transform::TrackingListener::getCommonDefiningOp(ValueRange values) { | |
return defOp; | ||
} | ||
|
||
FailureOr<Operation *> | ||
transform::TrackingListener::findReplacementOp(Operation *op, | ||
ValueRange newValues) const { | ||
DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp( | ||
Operation *&result, Operation *op, ValueRange newValues) const { | ||
assert(op->getNumResults() == newValues.size() && | ||
"invalid number of replacement values"); | ||
SmallVector<Value> values(newValues.begin(), newValues.end()); | ||
|
||
DiagnosedSilenceableFailure diag = emitSilenceableFailure( | ||
getTransformOp(), "tracking listener failed to find replacement op " | ||
"during application of this transform op"); | ||
|
||
do { | ||
// If the replacement values belong to different ops, drop the mapping. | ||
Operation *defOp = getCommonDefiningOp(values); | ||
if (!defOp) | ||
return failure(); | ||
if (!defOp) { | ||
diag.attachNote() << "replacement values belong to different ops"; | ||
return diag; | ||
} | ||
|
||
// If the defining op has the same type, we take it as a replacement. | ||
if (op->getName() == defOp->getName()) | ||
return defOp; | ||
if (op->getName() == defOp->getName()) { | ||
result = defOp; | ||
return DiagnosedSilenceableFailure::success(); | ||
} | ||
|
||
// Replacing an op with a constant-like equivalent is a common | ||
// canonicalization. | ||
if (defOp->hasTrait<OpTrait::ConstantLike>()) | ||
return defOp; | ||
if (defOp->hasTrait<OpTrait::ConstantLike>()) { | ||
result = defOp; | ||
return DiagnosedSilenceableFailure::success(); | ||
} | ||
|
||
values.clear(); | ||
|
||
// Skip through ops that implement FindPayloadReplacementOpInterface. | ||
if (auto findReplacementOpInterface = | ||
dyn_cast<FindPayloadReplacementOpInterface>(defOp)) { | ||
values.assign(findReplacementOpInterface.getNextOperands()); | ||
diag.attachNote(defOp->getLoc()) << "using operands provided by " | ||
"'FindPayloadReplacementOpInterface'"; | ||
continue; | ||
} | ||
|
||
// Skip through ops that implement CastOpInterface. | ||
if (isa<CastOpInterface>(defOp)) { | ||
values.assign(defOp->getOperands().begin(), defOp->getOperands().end()); | ||
diag.attachNote(defOp->getLoc()) | ||
<< "using output of 'CastOpInterface' op"; | ||
continue; | ||
} | ||
} while (!values.empty()); | ||
|
||
return failure(); | ||
diag.attachNote() << "ran out of suitable replacement values"; | ||
return diag; | ||
} | ||
|
||
LogicalResult transform::TrackingListener::notifyMatchFailure( | ||
|
@@ -1396,32 +1410,39 @@ void transform::TrackingListener::notifyOperationReplaced( | |
}; | ||
|
||
// Helper function to check if the handle is alive. | ||
auto hasAliveUser = [&]() { | ||
auto firstAliveUser = [&]() -> std::optional<OpOperand *> { | ||
for (Value v : opHandles) { | ||
for (Operation *user : v.getUsers()) | ||
if (user != transformOp && !happensBefore(user, transformOp)) | ||
return true; | ||
for (OpOperand &use : v.getUses()) | ||
if (use.getOwner() != transformOp && | ||
!happensBefore(use.getOwner(), transformOp)) | ||
return &use; | ||
} | ||
return false; | ||
}; | ||
return std::nullopt; | ||
}(); | ||
|
||
if (!hasAliveUser() || handleWasConsumed()) { | ||
if (!firstAliveUser.has_value() || handleWasConsumed()) { | ||
// The op is tracked but the corresponding handles are dead or were | ||
// consumed. Drop the op form the mapping. | ||
(void)replacePayloadOp(op, nullptr); | ||
return; | ||
} | ||
|
||
FailureOr<Operation *> replacement = findReplacementOp(op, newValues); | ||
Operation *replacement; | ||
DiagnosedSilenceableFailure diag = | ||
findReplacementOp(replacement, op, newValues); | ||
// If the op is tracked but no replacement op was found, send a | ||
// notification. | ||
if (failed(replacement)) { | ||
notifyPayloadReplacementNotFound(op, newValues); | ||
if (!diag.succeeded()) { | ||
diag.attachNote((*firstAliveUser)->getOwner()->getLoc()) | ||
<< "replacement is required because alive handle(s) exist " | ||
<< "(first use in this op as operand number " | ||
<< (*firstAliveUser)->getOperandNumber() << ")"; | ||
notifyPayloadReplacementNotFound(op, newValues, std::move(diag)); | ||
(void)replacePayloadOp(op, nullptr); | ||
return; | ||
} | ||
|
||
(void)replacePayloadOp(op, *replacement); | ||
(void)replacePayloadOp(op, replacement); | ||
} | ||
|
||
transform::ErrorCheckingTrackingListener::~ErrorCheckingTrackingListener() { | ||
|
@@ -1444,17 +1465,20 @@ bool transform::ErrorCheckingTrackingListener::failed() const { | |
} | ||
|
||
void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound( | ||
Operation *op, ValueRange values) { | ||
if (status.succeeded()) { | ||
status = emitSilenceableFailure( | ||
getTransformOp(), "tracking listener failed to find replacement op"); | ||
} | ||
Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) { | ||
|
||
// Merge potentially existing diags and store the result in the listener. | ||
SmallVector<Diagnostic> diags; | ||
diag.takeDiagnostics(diags); | ||
if (!status.succeeded()) | ||
status.takeDiagnostics(diags); | ||
status = DiagnosedSilenceableFailure::silenceableFailure(std::move(diags)); | ||
|
||
// Report more details. | ||
status.attachNote(op->getLoc()) << "[" << errorCounter << "] replaced op"; | ||
for (auto &&[index, value] : llvm::enumerate(values)) | ||
status.attachNote(value.getLoc()) | ||
<< "[" << errorCounter << "] replacement value " << index; | ||
Comment on lines
1478
to
1481
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This information could (almost) be added to the diagnostics |
||
|
||
++errorCounter; | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,8 +32,8 @@ struct TestTensorTransforms | |
TestTensorTransforms(const TestTensorTransforms &pass) : PassWrapper(pass) {} | ||
|
||
void getDependentDialects(DialectRegistry ®istry) const override { | ||
registry | ||
.insert<arith::ArithDialect, scf::SCFDialect, linalg::LinalgDialect>(); | ||
registry.insert<arith::ArithDialect, scf::SCFDialect, linalg::LinalgDialect, | ||
transform::TransformDialect>(); | ||
} | ||
|
||
StringRef getArgument() const final { | ||
|
@@ -292,10 +292,10 @@ class DummyTrackingListener : public transform::TrackingListener { | |
|
||
// Expose `findReplacementOp` as a public function, so that it can be tested. | ||
Operation *getReplacementOp(Operation *op, ValueRange newValues) const { | ||
FailureOr<Operation *> replacementOp = findReplacementOp(op, newValues); | ||
if (failed(replacementOp)) | ||
Operation *replacementOp; | ||
if (!findReplacementOp(replacementOp, op, newValues).succeeded()) | ||
return nullptr; | ||
return *replacementOp; | ||
return replacementOp; | ||
} | ||
}; | ||
} // namespace | ||
|
@@ -352,8 +352,17 @@ static LogicalResult testTrackingListenerReplacements(Operation *rootOp) { | |
transform::TransformState transformState = | ||
transform::detail::makeTransformStateForTesting(/*region=*/nullptr, | ||
/*payloadRoot=*/nullptr); | ||
DummyTrackingListener listener(transformState, | ||
transform::TransformOpInterface()); | ||
MLIRContext *context = rootOp->getContext(); | ||
OpBuilder builder(context); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This OpBuilder has no insertion point. That's likely what caused the leak. |
||
auto transformOp = builder.create<transform::NamedSequenceOp>( | ||
rootOp->getLoc(), | ||
/*sym_name=*/"test_sequence", | ||
/*function_type=*/ | ||
TypeAttr::get(FunctionType::get(context, TypeRange{}, TypeRange{})), | ||
/*sym_visibility*/ StringAttr::get(context, "public"), | ||
/*arg_attrs=*/ArrayAttr::get(context, ArrayRef<Attribute>()), | ||
/*res_attrs=*/ArrayAttr::get(context, ArrayRef<Attribute>())); | ||
DummyTrackingListener listener(transformState, transformOp); | ||
Operation *replacement = listener.getReplacementOp(replaced, replacements); | ||
if (!replacement) { | ||
replaced->emitError("listener could not find replacement op"); | ||
|
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe "tracking listener failed to find replacement op for op that is tracked by non-dead handle".
But even that could be confusing. Maybe we should just add a comment in the C++ code that mentions all the conditions (type changed, handle not dead, etc.).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was thinking literally returning
DiagnosedSilecenableFailure
instead ofLogicalResult
, then the replacement lookup can specify the message.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All of that is good input. I now assemble the error message in different places, each contributing the information they have. Please take another look.