Skip to content

Commit a6d09d4

Browse files
committed
Add a -verify-roundtrip option to mlir-opt intended to validate custom printer/parser completeness
Running: MLIR_OPT_CHECK_IR_ROUNDTRIP=1 ninja check-mlir will now exercises all of our test with a round-trip to bytecode and a comparison for equality. Reviewed By: rriddle, ftynse, jpienaar Differential Revision: https://reviews.llvm.org/D90088
1 parent 7c847ac commit a6d09d4

File tree

5 files changed

+129
-33
lines changed

5 files changed

+129
-33
lines changed

mlir/include/mlir/IR/OperationSupport.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1103,7 +1103,7 @@ class OpPrintingFlags {
11031103
OpPrintingFlags &enableDebugInfo(bool enable = true, bool prettyForm = false);
11041104

11051105
/// Always print operations in the generic form.
1106-
OpPrintingFlags &printGenericOpForm();
1106+
OpPrintingFlags &printGenericOpForm(bool enable = true);
11071107

11081108
/// Skip printing regions.
11091109
OpPrintingFlags &skipRegions(bool skip = true);

mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,13 @@ class MlirOptMainConfig {
163163
}
164164
bool shouldVerifyPasses() const { return verifyPassesFlag; }
165165

166+
/// Set whether to run the verifier after each transformation pass.
167+
MlirOptMainConfig &verifyRoundtrip(bool verify) {
168+
verifyRoundtripFlag = verify;
169+
return *this;
170+
}
171+
bool shouldVerifyRoundtrip() const { return verifyRoundtripFlag; }
172+
166173
protected:
167174
/// Allow operation with no registered dialects.
168175
/// This option is for convenience during testing only and discouraged in
@@ -212,6 +219,9 @@ class MlirOptMainConfig {
212219

213220
/// Run the verifier after each transformation pass.
214221
bool verifyPassesFlag = true;
222+
223+
/// Verify that the input IR round-trips perfectly.
224+
bool verifyRoundtripFlag = false;
215225
};
216226

217227
/// This defines the function type used to setup the pass manager. This can be

mlir/lib/IR/AsmPrinter.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,8 +219,8 @@ OpPrintingFlags &OpPrintingFlags::enableDebugInfo(bool enable,
219219
}
220220

221221
/// Always print operations in the generic form.
222-
OpPrintingFlags &OpPrintingFlags::printGenericOpForm() {
223-
printGenericOpFormFlag = true;
222+
OpPrintingFlags &OpPrintingFlags::printGenericOpForm(bool enable) {
223+
printGenericOpFormFlag = enable;
224224
return *this;
225225
}
226226

mlir/lib/Tools/mlir-opt/MlirOptMain.cpp

Lines changed: 111 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,11 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
139139
cl::desc("Run the verifier after each transformation pass"),
140140
cl::location(verifyPassesFlag), cl::init(true));
141141

142+
static cl::opt<bool, /*ExternalStorage=*/true> verifyRoundtrip(
143+
"verify-roundtrip",
144+
cl::desc("Round-trip the IR after parsing and ensure it succeeds"),
145+
cl::location(verifyRoundtripFlag), cl::init(false));
146+
142147
static cl::list<std::string> passPlugins(
143148
"load-pass-plugin", cl::desc("Load passes from plugin library"));
144149
/// Set the callback to load a pass plugin.
@@ -213,6 +218,104 @@ void MlirOptMainConfigCLOptions::setDialectPluginsCallback(
213218
});
214219
}
215220

221+
LogicalResult loadIRDLDialects(StringRef irdlFile, MLIRContext &ctx) {
222+
DialectRegistry registry;
223+
registry.insert<irdl::IRDLDialect>();
224+
ctx.appendDialectRegistry(registry);
225+
226+
// Set up the input file.
227+
std::string errorMessage;
228+
std::unique_ptr<MemoryBuffer> file = openInputFile(irdlFile, &errorMessage);
229+
if (!file) {
230+
emitError(UnknownLoc::get(&ctx)) << errorMessage;
231+
return failure();
232+
}
233+
234+
// Give the buffer to the source manager.
235+
// This will be picked up by the parser.
236+
SourceMgr sourceMgr;
237+
sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
238+
239+
SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &ctx);
240+
241+
// Parse the input file.
242+
OwningOpRef<ModuleOp> module(parseSourceFile<ModuleOp>(sourceMgr, &ctx));
243+
244+
// Load IRDL dialects.
245+
return irdl::loadDialects(module.get());
246+
}
247+
248+
// Return success if the module can correctly round-trip. This intended to test
249+
// that the custom printers/parsers are complete.
250+
static LogicalResult doVerifyRoundTrip(Operation *op,
251+
const MlirOptMainConfig &config,
252+
bool useBytecode) {
253+
// We use a new context to avoid resource handle renaming issue in the diff.
254+
MLIRContext roundtripContext;
255+
OwningOpRef<Operation *> roundtripModule;
256+
roundtripContext.appendDialectRegistry(
257+
op->getContext()->getDialectRegistry());
258+
if (op->getContext()->allowsUnregisteredDialects())
259+
roundtripContext.allowUnregisteredDialects();
260+
StringRef irdlFile = config.getIrdlFile();
261+
if (!irdlFile.empty() && failed(loadIRDLDialects(irdlFile, roundtripContext)))
262+
return failure();
263+
264+
// Print a first time with custom format (or bytecode) and parse it back to
265+
// the roundtripModule.
266+
{
267+
std::string buffer;
268+
llvm::raw_string_ostream ostream(buffer);
269+
if (useBytecode) {
270+
if (failed(writeBytecodeToFile(op, ostream))) {
271+
op->emitOpError() << "failed to write bytecode, cannot verify round-trip.\n";
272+
return failure();
273+
}
274+
} else {
275+
op->print(ostream,
276+
OpPrintingFlags().printGenericOpForm(false).enableDebugInfo());
277+
}
278+
FallbackAsmResourceMap fallbackResourceMap;
279+
ParserConfig parseConfig(&roundtripContext, /*verifyAfterParse=*/true,
280+
&fallbackResourceMap);
281+
roundtripModule =
282+
parseSourceString<Operation *>(ostream.str(), parseConfig);
283+
if (!roundtripModule) {
284+
op->emitOpError() << "failed to parse bytecode back, cannot verify round-trip.\n";
285+
return failure();
286+
}
287+
}
288+
289+
// Print in the generic form for the reference module and the round-tripped
290+
// one and compare the outputs.
291+
std::string reference, roundtrip;
292+
{
293+
llvm::raw_string_ostream ostreamref(reference);
294+
op->print(ostreamref,
295+
OpPrintingFlags().printGenericOpForm().enableDebugInfo());
296+
llvm::raw_string_ostream ostreamrndtrip(roundtrip);
297+
roundtripModule.get()->print(
298+
ostreamrndtrip,
299+
OpPrintingFlags().printGenericOpForm().enableDebugInfo());
300+
}
301+
if (reference != roundtrip) {
302+
// TODO implement a diff.
303+
return op->emitOpError() << "roundTrip testing roundtripped module differs from reference:\n<<<<<<Reference\n"
304+
<< reference << "\n=====\n"
305+
<< roundtrip << "\n>>>>>roundtripped\n";
306+
}
307+
308+
return success();
309+
}
310+
311+
static LogicalResult doVerifyRoundTrip(Operation *op,
312+
const MlirOptMainConfig &config) {
313+
// Textual round-trip isn't fully robust at the moment (for example implicit
314+
// terminator are losing location informations).
315+
316+
return doVerifyRoundTrip(op, config, /*useBytecode=*/true);
317+
}
318+
216319
/// Perform the actions on the input file indicated by the command line flags
217320
/// within the specified context.
218321
///
@@ -247,10 +350,16 @@ performActions(raw_ostream &os,
247350
TimingScope parserTiming = timing.nest("Parser");
248351
OwningOpRef<Operation *> op = parseSourceFileForTool(
249352
sourceMgr, parseConfig, !config.shouldUseExplicitModule());
250-
context->enableMultithreading(wasThreadingEnabled);
353+
parserTiming.stop();
251354
if (!op)
252355
return failure();
253-
parserTiming.stop();
356+
357+
// Perform round-trip verification if requested
358+
if (config.shouldVerifyRoundtrip() &&
359+
failed(doVerifyRoundTrip(op.get(), config)))
360+
return failure();
361+
362+
context->enableMultithreading(wasThreadingEnabled);
254363

255364
// Prepare the pass manager, applying command-line and reproducer options.
256365
PassManager pm(op.get()->getName(), PassManager::Nesting::Implicit);
@@ -286,33 +395,6 @@ performActions(raw_ostream &os,
286395
return success();
287396
}
288397

289-
LogicalResult loadIRDLDialects(StringRef irdlFile, MLIRContext &ctx) {
290-
DialectRegistry registry;
291-
registry.insert<irdl::IRDLDialect>();
292-
ctx.appendDialectRegistry(registry);
293-
294-
// Set up the input file.
295-
std::string errorMessage;
296-
std::unique_ptr<MemoryBuffer> file = openInputFile(irdlFile, &errorMessage);
297-
if (!file) {
298-
emitError(UnknownLoc::get(&ctx)) << errorMessage;
299-
return failure();
300-
}
301-
302-
// Give the buffer to the source manager.
303-
// This will be picked up by the parser.
304-
SourceMgr sourceMgr;
305-
sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
306-
307-
SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &ctx);
308-
309-
// Parse the input file.
310-
OwningOpRef<ModuleOp> module(parseSourceFile<ModuleOp>(sourceMgr, &ctx));
311-
312-
// Load IRDL dialects.
313-
return irdl::loadDialects(module.get());
314-
}
315-
316398
/// Parses the memory buffer. If successfully, run a series of passes against
317399
/// it and print the result.
318400
static LogicalResult processBuffer(raw_ostream &os,

mlir/test/lit.cfg.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ def add_runtime(name):
6565

6666
tool_dirs = [config.mlir_tools_dir, config.llvm_tools_dir]
6767
tools = [
68-
'mlir-opt',
6968
'mlir-tblgen',
7069
'mlir-translate',
7170
'mlir-lsp-server',
@@ -125,6 +124,11 @@ def add_runtime(name):
125124
ToolSubst('%PYTHON', python_executable, unresolved='ignore'),
126125
])
127126

127+
if "MLIR_OPT_CHECK_IR_ROUNDTRIP" in os.environ:
128+
tools.extend([
129+
ToolSubst('mlir-opt', 'mlir-opt --verify-roundtrip', unresolved='fatal'),
130+
])
131+
128132
llvm_config.add_tool_substitutions(tools, tool_dirs)
129133

130134

0 commit comments

Comments
 (0)