Skip to content

Commit 03e29a4

Browse files
authored
[mlir][Pass] Enable the option for reproducer generation without crashing (#75421)
This PR adds API `makeReproducer` and cl::opt flag `--mlir-generate-reproducer=<filename>` in order to allow for mlir reproducer dumps even when the pipeline doesn't crash. This PR also decouples the code that handles generation of an MLIR reproducer from the crash recovery portion. The purpose is to allow for generating reproducers outside of the context of a compiler crash. This will be useful for frameworks and runtimes that use MLIR where it is needed to reproduce the pipeline behavior for reasons outside of diagnosing crashes. An example is for diagnosing performance issues using offline tools, where being able to dump the reproducer from a runtime compiler would be helpful.
1 parent 76cb0bb commit 03e29a4

File tree

7 files changed

+127
-53
lines changed

7 files changed

+127
-53
lines changed

mlir/include/mlir/Pass/PassManager.h

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,27 @@ enum class PassDisplayMode {
207207
Pipeline,
208208
};
209209

210+
/// Streams on which to output crash reproducer.
211+
struct ReproducerStream {
212+
virtual ~ReproducerStream() = default;
213+
214+
/// Description of the reproducer stream.
215+
virtual StringRef description() = 0;
216+
217+
/// Stream on which to output reproducer.
218+
virtual raw_ostream &os() = 0;
219+
};
220+
221+
/// Method type for constructing ReproducerStream.
222+
using ReproducerStreamFactory =
223+
std::function<std::unique_ptr<ReproducerStream>(std::string &error)>;
224+
225+
std::string
226+
makeReproducer(StringRef anchorName,
227+
const llvm::iterator_range<OpPassManager::pass_iterator> &passes,
228+
Operation *op, StringRef outputFile, bool disableThreads = false,
229+
bool verifyPasses = false);
230+
210231
/// The main pass manager and pipeline builder.
211232
class PassManager : public OpPassManager {
212233
public:
@@ -243,21 +264,6 @@ class PassManager : public OpPassManager {
243264
void enableCrashReproducerGeneration(StringRef outputFile,
244265
bool genLocalReproducer = false);
245266

246-
/// Streams on which to output crash reproducer.
247-
struct ReproducerStream {
248-
virtual ~ReproducerStream() = default;
249-
250-
/// Description of the reproducer stream.
251-
virtual StringRef description() = 0;
252-
253-
/// Stream on which to output reproducer.
254-
virtual raw_ostream &os() = 0;
255-
};
256-
257-
/// Method type for constructing ReproducerStream.
258-
using ReproducerStreamFactory =
259-
std::function<std::unique_ptr<ReproducerStream>(std::string &error)>;
260-
261267
/// Enable support for the pass manager to generate a reproducer on the event
262268
/// of a crash or a pass failure. `factory` is used to construct the streams
263269
/// to write the generated reproducer to. If `genLocalReproducer` is true, the

mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,9 @@ class MlirOptMainConfig {
173173
}
174174
bool shouldVerifyRoundtrip() const { return verifyRoundtripFlag; }
175175

176+
/// Reproducer file generation (no crash required).
177+
StringRef getReproducerFilename() const { return generateReproducerFileFlag; }
178+
176179
protected:
177180
/// Allow operation with no registered dialects.
178181
/// This option is for convenience during testing only and discouraged in
@@ -228,6 +231,9 @@ class MlirOptMainConfig {
228231

229232
/// Verify that the input IR round-trips perfectly.
230233
bool verifyRoundtripFlag = false;
234+
235+
/// The reproducer output filename (no crash required).
236+
std::string generateReproducerFileFlag = "";
231237
};
232238

233239
/// This defines the function type used to setup the pass manager. This can be

mlir/lib/Pass/Pass.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -382,16 +382,22 @@ StringRef OpPassManager::getOpAnchorName() const {
382382

383383
/// Prints out the passes of the pass manager as the textual representation
384384
/// of pipelines.
385-
void OpPassManager::printAsTextualPipeline(raw_ostream &os) const {
386-
os << getOpAnchorName() << "(";
385+
void printAsTextualPipeline(
386+
raw_ostream &os, StringRef anchorName,
387+
const llvm::iterator_range<OpPassManager::pass_iterator> &passes) {
388+
os << anchorName << "(";
387389
llvm::interleave(
388-
impl->passes,
389-
[&](const std::unique_ptr<Pass> &pass) {
390-
pass->printAsTextualPipeline(os);
391-
},
390+
passes, [&](mlir::Pass &pass) { pass.printAsTextualPipeline(os); },
392391
[&]() { os << ","; });
393392
os << ")";
394393
}
394+
void OpPassManager::printAsTextualPipeline(raw_ostream &os) const {
395+
StringRef anchorName = getOpAnchorName();
396+
::printAsTextualPipeline(
397+
os, anchorName,
398+
{MutableArrayRef<std::unique_ptr<Pass>>{impl->passes}.begin(),
399+
MutableArrayRef<std::unique_ptr<Pass>>{impl->passes}.end()});
400+
}
395401

396402
void OpPassManager::dump() {
397403
llvm::errs() << "Pass Manager with " << impl->passes.size() << " passes:\n";

mlir/lib/Pass/PassCrashRecovery.cpp

Lines changed: 58 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ namespace detail {
3838
/// reproducers when a signal is raised, such as a segfault.
3939
struct RecoveryReproducerContext {
4040
RecoveryReproducerContext(std::string passPipelineStr, Operation *op,
41-
PassManager::ReproducerStreamFactory &streamFactory,
41+
ReproducerStreamFactory &streamFactory,
4242
bool verifyPasses);
4343
~RecoveryReproducerContext();
4444

@@ -67,7 +67,7 @@ struct RecoveryReproducerContext {
6767

6868
/// The factory for the reproducer output stream to use when generating the
6969
/// reproducer.
70-
PassManager::ReproducerStreamFactory &streamFactory;
70+
ReproducerStreamFactory &streamFactory;
7171

7272
/// Various pass manager and context flags.
7373
bool disableThreads;
@@ -92,7 +92,7 @@ llvm::ManagedStatic<llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
9292

9393
RecoveryReproducerContext::RecoveryReproducerContext(
9494
std::string passPipelineStr, Operation *op,
95-
PassManager::ReproducerStreamFactory &streamFactory, bool verifyPasses)
95+
ReproducerStreamFactory &streamFactory, bool verifyPasses)
9696
: pipelineElements(std::move(passPipelineStr)),
9797
preCrashOperation(op->clone()), streamFactory(streamFactory),
9898
disableThreads(!op->getContext()->isMultithreadingEnabled()),
@@ -106,22 +106,24 @@ RecoveryReproducerContext::~RecoveryReproducerContext() {
106106
disable();
107107
}
108108

109-
void RecoveryReproducerContext::generate(std::string &description) {
109+
static void appendReproducer(std::string &description, Operation *op,
110+
const ReproducerStreamFactory &factory,
111+
const std::string &pipelineElements,
112+
bool disableThreads, bool verifyPasses) {
110113
llvm::raw_string_ostream descOS(description);
111114

112115
// Try to create a new output stream for this crash reproducer.
113116
std::string error;
114-
std::unique_ptr<PassManager::ReproducerStream> stream = streamFactory(error);
117+
std::unique_ptr<ReproducerStream> stream = factory(error);
115118
if (!stream) {
116119
descOS << "failed to create output stream: " << error;
117120
return;
118121
}
119122
descOS << "reproducer generated at `" << stream->description() << "`";
120123

121-
std::string pipeline = (preCrashOperation->getName().getStringRef() + "(" +
122-
pipelineElements + ")")
123-
.str();
124-
AsmState state(preCrashOperation);
124+
std::string pipeline =
125+
(op->getName().getStringRef() + "(" + pipelineElements + ")").str();
126+
AsmState state(op);
125127
state.attachResourcePrinter(
126128
"mlir_reproducer", [&](Operation *op, AsmResourceBuilder &builder) {
127129
builder.buildString("pipeline", pipeline);
@@ -130,7 +132,12 @@ void RecoveryReproducerContext::generate(std::string &description) {
130132
});
131133

132134
// Output the .mlir module.
133-
preCrashOperation->print(stream->os(), state);
135+
op->print(stream->os(), state);
136+
}
137+
138+
void RecoveryReproducerContext::generate(std::string &description) {
139+
appendReproducer(description, preCrashOperation, streamFactory,
140+
pipelineElements, disableThreads, verifyPasses);
134141
}
135142

136143
void RecoveryReproducerContext::disable() {
@@ -175,12 +182,11 @@ void RecoveryReproducerContext::registerSignalHandler() {
175182
//===----------------------------------------------------------------------===//
176183

177184
struct PassCrashReproducerGenerator::Impl {
178-
Impl(PassManager::ReproducerStreamFactory &streamFactory,
179-
bool localReproducer)
185+
Impl(ReproducerStreamFactory &streamFactory, bool localReproducer)
180186
: streamFactory(streamFactory), localReproducer(localReproducer) {}
181187

182188
/// The factory to use when generating a crash reproducer.
183-
PassManager::ReproducerStreamFactory streamFactory;
189+
ReproducerStreamFactory streamFactory;
184190

185191
/// Flag indicating if reproducer generation should be localized to the
186192
/// failing pass.
@@ -198,7 +204,7 @@ struct PassCrashReproducerGenerator::Impl {
198204
};
199205

200206
PassCrashReproducerGenerator::PassCrashReproducerGenerator(
201-
PassManager::ReproducerStreamFactory &streamFactory, bool localReproducer)
207+
ReproducerStreamFactory &streamFactory, bool localReproducer)
202208
: impl(std::make_unique<Impl>(streamFactory, localReproducer)) {}
203209
PassCrashReproducerGenerator::~PassCrashReproducerGenerator() = default;
204210

@@ -382,9 +388,9 @@ struct CrashReproducerInstrumentation : public PassInstrumentation {
382388
//===----------------------------------------------------------------------===//
383389

384390
namespace {
385-
/// This class represents a default instance of PassManager::ReproducerStream
391+
/// This class represents a default instance of mlir::ReproducerStream
386392
/// that is backed by a file.
387-
struct FileReproducerStream : public PassManager::ReproducerStream {
393+
struct FileReproducerStream : public mlir::ReproducerStream {
388394
FileReproducerStream(std::unique_ptr<llvm::ToolOutputFile> outputFile)
389395
: outputFile(std::move(outputFile)) {}
390396
~FileReproducerStream() override { outputFile->keep(); }
@@ -418,22 +424,45 @@ LogicalResult PassManager::runWithCrashRecovery(Operation *op,
418424
return passManagerResult;
419425
}
420426

421-
void PassManager::enableCrashReproducerGeneration(StringRef outputFile,
422-
bool genLocalReproducer) {
427+
static ReproducerStreamFactory
428+
makeReproducerStreamFactory(StringRef outputFile) {
423429
// Capture the filename by value in case outputFile is out of scope when
424430
// invoked.
425431
std::string filename = outputFile.str();
426-
enableCrashReproducerGeneration(
427-
[filename](std::string &error) -> std::unique_ptr<ReproducerStream> {
428-
std::unique_ptr<llvm::ToolOutputFile> outputFile =
429-
mlir::openOutputFile(filename, &error);
430-
if (!outputFile) {
431-
error = "Failed to create reproducer stream: " + error;
432-
return nullptr;
433-
}
434-
return std::make_unique<FileReproducerStream>(std::move(outputFile));
435-
},
436-
genLocalReproducer);
432+
return [filename](std::string &error) -> std::unique_ptr<ReproducerStream> {
433+
std::unique_ptr<llvm::ToolOutputFile> outputFile =
434+
mlir::openOutputFile(filename, &error);
435+
if (!outputFile) {
436+
error = "Failed to create reproducer stream: " + error;
437+
return nullptr;
438+
}
439+
return std::make_unique<FileReproducerStream>(std::move(outputFile));
440+
};
441+
}
442+
443+
void printAsTextualPipeline(
444+
raw_ostream &os, StringRef anchorName,
445+
const llvm::iterator_range<OpPassManager::pass_iterator> &passes);
446+
447+
std::string mlir::makeReproducer(
448+
StringRef anchorName,
449+
const llvm::iterator_range<OpPassManager::pass_iterator> &passes,
450+
Operation *op, StringRef outputFile, bool disableThreads,
451+
bool verifyPasses) {
452+
453+
std::string description;
454+
std::string pipelineStr;
455+
llvm::raw_string_ostream passOS(pipelineStr);
456+
::printAsTextualPipeline(passOS, anchorName, passes);
457+
appendReproducer(description, op, makeReproducerStreamFactory(outputFile),
458+
pipelineStr, disableThreads, verifyPasses);
459+
return description;
460+
}
461+
462+
void PassManager::enableCrashReproducerGeneration(StringRef outputFile,
463+
bool genLocalReproducer) {
464+
enableCrashReproducerGeneration(makeReproducerStreamFactory(outputFile),
465+
genLocalReproducer);
437466
}
438467

439468
void PassManager::enableCrashReproducerGeneration(

mlir/lib/Pass/PassDetail.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,8 @@ class OpToOpPassAdaptor
9898

9999
class PassCrashReproducerGenerator {
100100
public:
101-
PassCrashReproducerGenerator(
102-
PassManager::ReproducerStreamFactory &streamFactory,
103-
bool localReproducer);
101+
PassCrashReproducerGenerator(ReproducerStreamFactory &streamFactory,
102+
bool localReproducer);
104103
~PassCrashReproducerGenerator();
105104

106105
/// Initialize the generator in preparation for reproducer generation. The

mlir/lib/Tools/mlir-opt/MlirOptMain.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,16 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
151151

152152
static cl::list<std::string> passPlugins(
153153
"load-pass-plugin", cl::desc("Load passes from plugin library"));
154+
155+
static cl::opt<std::string, /*ExternalStorage=*/true>
156+
generateReproducerFile(
157+
"mlir-generate-reproducer",
158+
llvm::cl::desc(
159+
"Generate an mlir reproducer at the provided filename"
160+
" (no crash required)"),
161+
cl::location(generateReproducerFileFlag), cl::init(""),
162+
cl::value_desc("filename"));
163+
154164
/// Set the callback to load a pass plugin.
155165
passPlugins.setCallback([&](const std::string &pluginPath) {
156166
auto plugin = PassPlugin::load(pluginPath);
@@ -384,6 +394,14 @@ performActions(raw_ostream &os,
384394
if (failed(pm.run(*op)))
385395
return failure();
386396

397+
// Generate reproducers if requested
398+
if (!config.getReproducerFilename().empty()) {
399+
StringRef anchorName = pm.getAnyOpAnchorName();
400+
const auto &passes = pm.getPasses();
401+
makeReproducer(anchorName, passes, op.get(),
402+
config.getReproducerFilename());
403+
}
404+
387405
// Print the output.
388406
TimingScope outputTiming = timing.nest("Output");
389407
if (config.shouldEmitBytecode()) {
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// RUN: mlir-opt %s -pass-pipeline='builtin.module(builtin.module(test-module-pass))' --mlir-generate-reproducer=%t -verify-diagnostics
2+
// RUN: cat %t | FileCheck -check-prefix=REPRO %s
3+
4+
module @inner_mod1 {
5+
module @foo {}
6+
}
7+
8+
// REPRO: module @inner_mod1
9+
// REPRO: module @foo {
10+
// REPRO: pipeline: "builtin.module(any(builtin.module(test-module-pass)))"

0 commit comments

Comments
 (0)