@@ -139,6 +139,11 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
139
139
cl::desc (" Run the verifier after each transformation pass" ),
140
140
cl::location (verifyPassesFlag), cl::init (true ));
141
141
142
+ static cl::opt<bool , /* ExternalStorage=*/ true > verifyRoundtrip (
143
+ " verify-roundtrip" ,
144
+ cl::desc (" Round-trip the IR after parsing and ensure it succeeds" ),
145
+ cl::location (verifyRoundtripFlag), cl::init (false ));
146
+
142
147
static cl::list<std::string> passPlugins (
143
148
" load-pass-plugin" , cl::desc (" Load passes from plugin library" ));
144
149
// / Set the callback to load a pass plugin.
@@ -213,6 +218,104 @@ void MlirOptMainConfigCLOptions::setDialectPluginsCallback(
213
218
});
214
219
}
215
220
221
+ LogicalResult loadIRDLDialects (StringRef irdlFile, MLIRContext &ctx) {
222
+ DialectRegistry registry;
223
+ registry.insert <irdl::IRDLDialect>();
224
+ ctx.appendDialectRegistry (registry);
225
+
226
+ // Set up the input file.
227
+ std::string errorMessage;
228
+ std::unique_ptr<MemoryBuffer> file = openInputFile (irdlFile, &errorMessage);
229
+ if (!file) {
230
+ emitError (UnknownLoc::get (&ctx)) << errorMessage;
231
+ return failure ();
232
+ }
233
+
234
+ // Give the buffer to the source manager.
235
+ // This will be picked up by the parser.
236
+ SourceMgr sourceMgr;
237
+ sourceMgr.AddNewSourceBuffer (std::move (file), SMLoc ());
238
+
239
+ SourceMgrDiagnosticHandler sourceMgrHandler (sourceMgr, &ctx);
240
+
241
+ // Parse the input file.
242
+ OwningOpRef<ModuleOp> module (parseSourceFile<ModuleOp>(sourceMgr, &ctx));
243
+
244
+ // Load IRDL dialects.
245
+ return irdl::loadDialects (module .get ());
246
+ }
247
+
248
+ // Return success if the module can correctly round-trip. This intended to test
249
+ // that the custom printers/parsers are complete.
250
+ static LogicalResult doVerifyRoundTrip (Operation *op,
251
+ const MlirOptMainConfig &config,
252
+ bool useBytecode) {
253
+ // We use a new context to avoid resource handle renaming issue in the diff.
254
+ MLIRContext roundtripContext;
255
+ OwningOpRef<Operation *> roundtripModule;
256
+ roundtripContext.appendDialectRegistry (
257
+ op->getContext ()->getDialectRegistry ());
258
+ if (op->getContext ()->allowsUnregisteredDialects ())
259
+ roundtripContext.allowUnregisteredDialects ();
260
+ StringRef irdlFile = config.getIrdlFile ();
261
+ if (!irdlFile.empty () && failed (loadIRDLDialects (irdlFile, roundtripContext)))
262
+ return failure ();
263
+
264
+ // Print a first time with custom format (or bytecode) and parse it back to
265
+ // the roundtripModule.
266
+ {
267
+ std::string buffer;
268
+ llvm::raw_string_ostream ostream (buffer);
269
+ if (useBytecode) {
270
+ if (failed (writeBytecodeToFile (op, ostream))) {
271
+ op->emitOpError () << " failed to write bytecode, cannot verify round-trip.\n " ;
272
+ return failure ();
273
+ }
274
+ } else {
275
+ op->print (ostream,
276
+ OpPrintingFlags ().printGenericOpForm (false ).enableDebugInfo ());
277
+ }
278
+ FallbackAsmResourceMap fallbackResourceMap;
279
+ ParserConfig parseConfig (&roundtripContext, /* verifyAfterParse=*/ true ,
280
+ &fallbackResourceMap);
281
+ roundtripModule =
282
+ parseSourceString<Operation *>(ostream.str (), parseConfig);
283
+ if (!roundtripModule) {
284
+ op->emitOpError () << " failed to parse bytecode back, cannot verify round-trip.\n " ;
285
+ return failure ();
286
+ }
287
+ }
288
+
289
+ // Print in the generic form for the reference module and the round-tripped
290
+ // one and compare the outputs.
291
+ std::string reference, roundtrip;
292
+ {
293
+ llvm::raw_string_ostream ostreamref (reference);
294
+ op->print (ostreamref,
295
+ OpPrintingFlags ().printGenericOpForm ().enableDebugInfo ());
296
+ llvm::raw_string_ostream ostreamrndtrip (roundtrip);
297
+ roundtripModule.get ()->print (
298
+ ostreamrndtrip,
299
+ OpPrintingFlags ().printGenericOpForm ().enableDebugInfo ());
300
+ }
301
+ if (reference != roundtrip) {
302
+ // TODO implement a diff.
303
+ return op->emitOpError () << " roundTrip testing roundtripped module differs from reference:\n <<<<<<Reference\n "
304
+ << reference << " \n =====\n "
305
+ << roundtrip << " \n >>>>>roundtripped\n " ;
306
+ }
307
+
308
+ return success ();
309
+ }
310
+
311
+ static LogicalResult doVerifyRoundTrip (Operation *op,
312
+ const MlirOptMainConfig &config) {
313
+ // Textual round-trip isn't fully robust at the moment (for example implicit
314
+ // terminator are losing location informations).
315
+
316
+ return doVerifyRoundTrip (op, config, /* useBytecode=*/ true );
317
+ }
318
+
216
319
// / Perform the actions on the input file indicated by the command line flags
217
320
// / within the specified context.
218
321
// /
@@ -247,10 +350,16 @@ performActions(raw_ostream &os,
247
350
TimingScope parserTiming = timing.nest (" Parser" );
248
351
OwningOpRef<Operation *> op = parseSourceFileForTool (
249
352
sourceMgr, parseConfig, !config.shouldUseExplicitModule ());
250
- context-> enableMultithreading (wasThreadingEnabled );
353
+ parserTiming. stop ( );
251
354
if (!op)
252
355
return failure ();
253
- parserTiming.stop ();
356
+
357
+ // Perform round-trip verification if requested
358
+ if (config.shouldVerifyRoundtrip () &&
359
+ failed (doVerifyRoundTrip (op.get (), config)))
360
+ return failure ();
361
+
362
+ context->enableMultithreading (wasThreadingEnabled);
254
363
255
364
// Prepare the pass manager, applying command-line and reproducer options.
256
365
PassManager pm (op.get ()->getName (), PassManager::Nesting::Implicit);
@@ -286,33 +395,6 @@ performActions(raw_ostream &os,
286
395
return success ();
287
396
}
288
397
289
- LogicalResult loadIRDLDialects (StringRef irdlFile, MLIRContext &ctx) {
290
- DialectRegistry registry;
291
- registry.insert <irdl::IRDLDialect>();
292
- ctx.appendDialectRegistry (registry);
293
-
294
- // Set up the input file.
295
- std::string errorMessage;
296
- std::unique_ptr<MemoryBuffer> file = openInputFile (irdlFile, &errorMessage);
297
- if (!file) {
298
- emitError (UnknownLoc::get (&ctx)) << errorMessage;
299
- return failure ();
300
- }
301
-
302
- // Give the buffer to the source manager.
303
- // This will be picked up by the parser.
304
- SourceMgr sourceMgr;
305
- sourceMgr.AddNewSourceBuffer (std::move (file), SMLoc ());
306
-
307
- SourceMgrDiagnosticHandler sourceMgrHandler (sourceMgr, &ctx);
308
-
309
- // Parse the input file.
310
- OwningOpRef<ModuleOp> module (parseSourceFile<ModuleOp>(sourceMgr, &ctx));
311
-
312
- // Load IRDL dialects.
313
- return irdl::loadDialects (module .get ());
314
- }
315
-
316
398
// / Parses the memory buffer. If successfully, run a series of passes against
317
399
// / it and print the result.
318
400
static LogicalResult processBuffer (raw_ostream &os,
0 commit comments