Skip to content

[mlir] implement -verify-diagnostics=only-expected #135131

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 36 additions & 14 deletions mlir/examples/transform-opt/mlir-transform-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,22 @@ struct MlirTransformOptCLOptions {
cl::desc("Allow operations coming from an unregistered dialect"),
cl::init(false)};

cl::opt<bool> verifyDiagnostics{
"verify-diagnostics",
cl::desc("Check that emitted diagnostics match expected-* lines "
"on the corresponding line"),
cl::init(false)};
cl::opt<mlir::SourceMgrDiagnosticVerifierHandler::Level> verifyDiagnostics{
"verify-diagnostics", llvm::cl::ValueOptional,
cl::desc("Check that emitted diagnostics match expected-* lines on the "
"corresponding line"),
cl::values(
clEnumValN(
mlir::SourceMgrDiagnosticVerifierHandler::Level::All, "all",
"Check all diagnostics (expected, unexpected, near-misses)"),
// Implicit value: when passed with no arguments, e.g.
// `--verify-diagnostics` or `--verify-diagnostics=`.
clEnumValN(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a comment that you are repeating enum all for backward compatibility?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh this is probably better/more relevant in the non-example one.

mlir::SourceMgrDiagnosticVerifierHandler::Level::All, "",
"Check all diagnostics (expected, unexpected, near-misses)"),
clEnumValN(
mlir::SourceMgrDiagnosticVerifierHandler::Level::OnlyExpected,
"only-expected", "Check only expected diagnostics"))};

cl::opt<std::string> payloadFilename{cl::Positional, cl::desc("<input file>"),
cl::init("-")};
Expand Down Expand Up @@ -102,12 +113,17 @@ class DiagnosticHandlerWrapper {

/// Constructs the diagnostic handler of the specified kind of the given
/// source manager and context.
DiagnosticHandlerWrapper(Kind kind, llvm::SourceMgr &mgr,
mlir::MLIRContext *context) {
if (kind == Kind::EmitDiagnostics)
DiagnosticHandlerWrapper(
Kind kind, llvm::SourceMgr &mgr, mlir::MLIRContext *context,
std::optional<mlir::SourceMgrDiagnosticVerifierHandler::Level> level =
{}) {
if (kind == Kind::EmitDiagnostics) {
handler = new mlir::SourceMgrDiagnosticHandler(mgr, context);
else
handler = new mlir::SourceMgrDiagnosticVerifierHandler(mgr, context);
} else {
assert(level.has_value() && "expected level");
handler =
new mlir::SourceMgrDiagnosticVerifierHandler(mgr, context, *level);
}
}

/// This object is non-copyable but movable.
Expand Down Expand Up @@ -150,7 +166,9 @@ class TransformSourceMgr {
public:
/// Constructs the source manager indicating whether diagnostic messages will
/// be verified later on.
explicit TransformSourceMgr(bool verifyDiagnostics)
explicit TransformSourceMgr(
std::optional<mlir::SourceMgrDiagnosticVerifierHandler::Level>
verifyDiagnostics)
: verifyDiagnostics(verifyDiagnostics) {}

/// Deconstructs the source manager. Note that `checkResults` must have been
Expand Down Expand Up @@ -179,7 +197,8 @@ class TransformSourceMgr {
// verification needs to happen and store it.
if (verifyDiagnostics) {
diagHandlers.emplace_back(
DiagnosticHandlerWrapper::Kind::VerifyDiagnostics, mgr, &context);
DiagnosticHandlerWrapper::Kind::VerifyDiagnostics, mgr, &context,
verifyDiagnostics);
} else {
diagHandlers.emplace_back(DiagnosticHandlerWrapper::Kind::EmitDiagnostics,
mgr, &context);
Expand All @@ -204,7 +223,8 @@ class TransformSourceMgr {

private:
/// Indicates whether diagnostic message verification is requested.
const bool verifyDiagnostics;
const std::optional<mlir::SourceMgrDiagnosticVerifierHandler::Level>
verifyDiagnostics;

/// Indicates that diagnostic message verification has taken place, and the
/// deconstruction is therefore safe.
Expand Down Expand Up @@ -248,7 +268,9 @@ static llvm::LogicalResult processPayloadBuffer(
context.allowUnregisteredDialects(clOptions->allowUnregisteredDialects);
mlir::ParserConfig config(&context);
TransformSourceMgr sourceMgr(
/*verifyDiagnostics=*/clOptions->verifyDiagnostics);
/*verifyDiagnostics=*/clOptions->verifyDiagnostics.getNumOccurrences()
? std::optional{clOptions->verifyDiagnostics.getValue()}
: std::nullopt);

// Parse the input buffer that will be used as transform payload.
mlir::OwningOpRef<mlir::Operation *> payloadRoot =
Expand Down
7 changes: 5 additions & 2 deletions mlir/include/mlir/IR/Diagnostics.h
Original file line number Diff line number Diff line change
Expand Up @@ -626,9 +626,12 @@ struct SourceMgrDiagnosticVerifierHandlerImpl;
/// corresponding line of the source file.
class SourceMgrDiagnosticVerifierHandler : public SourceMgrDiagnosticHandler {
public:
enum class Level { None = 0, All, OnlyExpected };
SourceMgrDiagnosticVerifierHandler(llvm::SourceMgr &srcMgr, MLIRContext *ctx,
raw_ostream &out);
SourceMgrDiagnosticVerifierHandler(llvm::SourceMgr &srcMgr, MLIRContext *ctx);
raw_ostream &out,
Level level = Level::All);
SourceMgrDiagnosticVerifierHandler(llvm::SourceMgr &srcMgr, MLIRContext *ctx,
Level level = Level::All);
~SourceMgrDiagnosticVerifierHandler();

/// Returns the status of the handler and verifies that all expected
Expand Down
16 changes: 13 additions & 3 deletions mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,11 +185,20 @@ class MlirOptMainConfig {

/// Set whether to check that emitted diagnostics match `expected-*` lines on
/// the corresponding line. This is meant for implementing diagnostic tests.
MlirOptMainConfig &verifyDiagnostics(bool verify) {
MlirOptMainConfig &
verifyDiagnostics(SourceMgrDiagnosticVerifierHandler::Level verify) {
verifyDiagnosticsFlag = verify;
return *this;
}
bool shouldVerifyDiagnostics() const { return verifyDiagnosticsFlag; }

bool shouldVerifyDiagnostics() const {
return verifyDiagnosticsFlag !=
SourceMgrDiagnosticVerifierHandler::Level::None;
}

SourceMgrDiagnosticVerifierHandler::Level verifyDiagnosticsLevel() const {
return verifyDiagnosticsFlag;
}

/// Set whether to run the verifier after each transformation pass.
MlirOptMainConfig &verifyPasses(bool verify) {
Expand Down Expand Up @@ -276,7 +285,8 @@ class MlirOptMainConfig {

/// Set whether to check that emitted diagnostics match `expected-*` lines on
/// the corresponding line. This is meant for implementing diagnostic tests.
bool verifyDiagnosticsFlag = false;
SourceMgrDiagnosticVerifierHandler::Level verifyDiagnosticsFlag =
SourceMgrDiagnosticVerifierHandler::Level::None;

/// Run the verifier after each transformation pass.
bool verifyPassesFlag = true;
Expand Down
23 changes: 18 additions & 5 deletions mlir/lib/IR/Diagnostics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,9 @@ struct ExpectedDiag {
};

struct SourceMgrDiagnosticVerifierHandlerImpl {
SourceMgrDiagnosticVerifierHandlerImpl() : status(success()) {}
SourceMgrDiagnosticVerifierHandlerImpl(
SourceMgrDiagnosticVerifierHandler::Level level)
: status(success()), level(level) {}

/// Returns the expected diagnostics for the given source file.
std::optional<MutableArrayRef<ExpectedDiag>>
Expand All @@ -672,6 +674,10 @@ struct SourceMgrDiagnosticVerifierHandlerImpl {
computeExpectedDiags(raw_ostream &os, llvm::SourceMgr &mgr,
const llvm::MemoryBuffer *buf);

SourceMgrDiagnosticVerifierHandler::Level getVerifyLevel() const {
return level;
}

/// The current status of the verifier.
LogicalResult status;

Expand All @@ -685,6 +691,10 @@ struct SourceMgrDiagnosticVerifierHandlerImpl {
llvm::Regex expected =
llvm::Regex("expected-(error|note|remark|warning)(-re)? "
"*(@([+-][0-9]+|above|below|unknown))? *{{(.*)}}$");

/// Verification level.
SourceMgrDiagnosticVerifierHandler::Level level =
SourceMgrDiagnosticVerifierHandler::Level::All;
};
} // namespace detail
} // namespace mlir
Expand Down Expand Up @@ -803,9 +813,9 @@ SourceMgrDiagnosticVerifierHandlerImpl::computeExpectedDiags(
}

SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler(
llvm::SourceMgr &srcMgr, MLIRContext *ctx, raw_ostream &out)
llvm::SourceMgr &srcMgr, MLIRContext *ctx, raw_ostream &out, Level level)
: SourceMgrDiagnosticHandler(srcMgr, ctx, out),
impl(new SourceMgrDiagnosticVerifierHandlerImpl()) {
impl(new SourceMgrDiagnosticVerifierHandlerImpl(level)) {
// Compute the expected diagnostics for each of the current files in the
// source manager.
for (unsigned i = 0, e = mgr.getNumBuffers(); i != e; ++i)
Expand All @@ -823,8 +833,8 @@ SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler(
}

SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler(
llvm::SourceMgr &srcMgr, MLIRContext *ctx)
: SourceMgrDiagnosticVerifierHandler(srcMgr, ctx, llvm::errs()) {}
llvm::SourceMgr &srcMgr, MLIRContext *ctx, Level level)
: SourceMgrDiagnosticVerifierHandler(srcMgr, ctx, llvm::errs(), level) {}

SourceMgrDiagnosticVerifierHandler::~SourceMgrDiagnosticVerifierHandler() {
// Ensure that all expected diagnostics were handled.
Expand Down Expand Up @@ -898,6 +908,9 @@ void SourceMgrDiagnosticVerifierHandler::process(LocationAttr loc,
}
}

if (impl->getVerifyLevel() == Level::OnlyExpected)
return;

// Otherwise, emit an error for the near miss.
if (nearMiss)
mgr.PrintMessage(os, nearMiss->fileLoc, llvm::SourceMgr::DK_Error,
Expand Down
28 changes: 22 additions & 6 deletions mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,11 +163,26 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
cl::desc("Split marker to use for merging the ouput"),
cl::location(outputSplitMarkerFlag), cl::init(kDefaultSplitMarker));

static cl::opt<bool, /*ExternalStorage=*/true> verifyDiagnostics(
"verify-diagnostics",
cl::desc("Check that emitted diagnostics match "
"expected-* lines on the corresponding line"),
cl::location(verifyDiagnosticsFlag), cl::init(false));
static cl::opt<SourceMgrDiagnosticVerifierHandler::Level,
/*ExternalStorage=*/true>
verifyDiagnostics{
"verify-diagnostics", llvm::cl::ValueOptional,
cl::desc("Check that emitted diagnostics match expected-* lines on "
"the corresponding line"),
cl::location(verifyDiagnosticsFlag),
cl::values(
clEnumValN(SourceMgrDiagnosticVerifierHandler::Level::All,
"all",
"Check all diagnostics (expected, unexpected, "
"near-misses)"),
// Implicit value: when passed with no arguments, e.g.
// `--verify-diagnostics` or `--verify-diagnostics=`.
clEnumValN(SourceMgrDiagnosticVerifierHandler::Level::All, "",
"Check all diagnostics (expected, unexpected, "
"near-misses)"),
clEnumValN(
SourceMgrDiagnosticVerifierHandler::Level::OnlyExpected,
"only-expected", "Check only expected diagnostics"))};

static cl::opt<bool, /*ExternalStorage=*/true> verifyPasses(
"verify-each",
Expand Down Expand Up @@ -537,7 +552,8 @@ static LogicalResult processBuffer(raw_ostream &os,
return performActions(os, sourceMgr, &context, config);
}

SourceMgrDiagnosticVerifierHandler sourceMgrHandler(*sourceMgr, &context);
SourceMgrDiagnosticVerifierHandler sourceMgrHandler(
*sourceMgr, &context, config.verifyDiagnosticsLevel());

// Do any processing requested by command line flags. We don't care whether
// these actions succeed or fail, we only care what diagnostics they produce
Expand Down
30 changes: 21 additions & 9 deletions mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,23 @@ LogicalResult mlir::mlirTranslateMain(int argc, char **argv,
"default marker and process each chunk independently"),
llvm::cl::init("")};

static llvm::cl::opt<bool> verifyDiagnostics(
"verify-diagnostics",
llvm::cl::desc("Check that emitted diagnostics match "
"expected-* lines on the corresponding line"),
llvm::cl::init(false));
static llvm::cl::opt<SourceMgrDiagnosticVerifierHandler::Level>
verifyDiagnostics{
"verify-diagnostics", llvm::cl::ValueOptional,
llvm::cl::desc("Check that emitted diagnostics match expected-* "
"lines on the corresponding line"),
llvm::cl::values(
clEnumValN(
SourceMgrDiagnosticVerifierHandler::Level::All, "all",
"Check all diagnostics (expected, unexpected, near-misses)"),
// Implicit value: when passed with no arguments, e.g.
// `--verify-diagnostics` or `--verify-diagnostics=`.
clEnumValN(
SourceMgrDiagnosticVerifierHandler::Level::All, "",
"Check all diagnostics (expected, unexpected, near-misses)"),
clEnumValN(
SourceMgrDiagnosticVerifierHandler::Level::OnlyExpected,
"only-expected", "Check only expected diagnostics"))};

static llvm::cl::opt<bool> errorDiagnosticsOnly(
"error-diagnostics-only",
Expand Down Expand Up @@ -149,17 +161,17 @@ LogicalResult mlir::mlirTranslateMain(int argc, char **argv,

MLIRContext context;
context.allowUnregisteredDialects(allowUnregisteredDialects);
context.printOpOnDiagnostic(!verifyDiagnostics);
context.printOpOnDiagnostic(verifyDiagnostics.getNumOccurrences() == 0);
auto sourceMgr = std::make_shared<llvm::SourceMgr>();
sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());

if (verifyDiagnostics) {
if (verifyDiagnostics.getNumOccurrences()) {
// In the diagnostic verification flow, we ignore whether the
// translation failed (in most cases, it is expected to fail) and we do
// not filter non-error diagnostics even if `errorDiagnosticsOnly` is
// set. Instead, we check if the diagnostics were produced as expected.
SourceMgrDiagnosticVerifierHandler sourceMgrHandler(*sourceMgr,
&context);
SourceMgrDiagnosticVerifierHandler sourceMgrHandler(
*sourceMgr, &context, verifyDiagnostics);
(void)(*translationRequested)(sourceMgr, os, &context);
result = sourceMgrHandler.verify();
} else if (errorDiagnosticsOnly) {
Expand Down
1 change: 1 addition & 0 deletions mlir/test/Pass/full_diagnostics.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(test-pass-failure{gen-diagnostics}))' -verify-diagnostics
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(test-pass-failure{gen-diagnostics}))' -verify-diagnostics=all

// Test that all errors are reported.
// expected-error@below {{illegal operation}}
Expand Down
17 changes: 17 additions & 0 deletions mlir/test/Pass/full_diagnostics_only_expected.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(test-pass-failure{gen-diagnostics}))' -verify-diagnostics=only-expected

// Test that only expected errors are reported.
// reports {{illegal operation}} but unchecked
func.func @TestAlwaysIllegalOperationPass1() {
return
}

// expected-error@+1 {{illegal operation}}
func.func @TestAlwaysIllegalOperationPass2() {
return
}

// reports {{illegal operation}} but unchecked
func.func @TestAlwaysIllegalOperationPass3() {
return
}
24 changes: 24 additions & 0 deletions mlir/test/mlir-translate/verify-only-expected.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Check that verify-diagnostics=only-expected passes with only one actual `expected-error`
// RUN: mlir-translate %s --allow-unregistered-dialect -verify-diagnostics=only-expected -split-input-file -mlir-to-llvmir

// Check that verify-diagnostics=all fails because we're missing two `expected-error`
// RUN: not mlir-translate %s --allow-unregistered-dialect -verify-diagnostics=all -split-input-file -mlir-to-llvmir 2>&1 | FileCheck %s --check-prefix=CHECK-VERIFY-ALL
// CHECK-VERIFY-ALL: unexpected error: cannot be converted to LLVM IR: missing `LLVMTranslationDialectInterface` registration for dialect for op: simple.terminator1
// CHECK-VERIFY-ALL: unexpected error: cannot be converted to LLVM IR: missing `LLVMTranslationDialectInterface` registration for dialect for op: simple.terminator3

llvm.func @trivial() {
"simple.terminator1"() : () -> ()
}

// -----

llvm.func @trivial() {
// expected-error @+1 {{cannot be converted to LLVM IR: missing `LLVMTranslationDialectInterface` registration for dialect for op: simple.terminator2}}
"simple.terminator2"() : () -> ()
}

// -----

llvm.func @trivial() {
"simple.terminator3"() : () -> ()
}