Skip to content

[mlir][Pass] Enable the option for reproducer generation without crashing #75421

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 21 additions & 15 deletions mlir/include/mlir/Pass/PassManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,27 @@ enum class PassDisplayMode {
Pipeline,
};

/// Streams on which to output crash reproducer.
struct ReproducerStream {
virtual ~ReproducerStream() = default;

/// Description of the reproducer stream.
virtual StringRef description() = 0;

/// Stream on which to output reproducer.
virtual raw_ostream &os() = 0;
};

/// Method type for constructing ReproducerStream.
using ReproducerStreamFactory =
std::function<std::unique_ptr<ReproducerStream>(std::string &error)>;

std::string
makeReproducer(StringRef anchorName,
const llvm::iterator_range<OpPassManager::pass_iterator> &passes,
Operation *op, StringRef outputFile, bool disableThreads = false,
bool verifyPasses = false);

/// The main pass manager and pipeline builder.
class PassManager : public OpPassManager {
public:
Expand Down Expand Up @@ -243,21 +264,6 @@ class PassManager : public OpPassManager {
void enableCrashReproducerGeneration(StringRef outputFile,
bool genLocalReproducer = false);

/// Streams on which to output crash reproducer.
struct ReproducerStream {
virtual ~ReproducerStream() = default;

/// Description of the reproducer stream.
virtual StringRef description() = 0;

/// Stream on which to output reproducer.
virtual raw_ostream &os() = 0;
};

/// Method type for constructing ReproducerStream.
using ReproducerStreamFactory =
std::function<std::unique_ptr<ReproducerStream>(std::string &error)>;

/// Enable support for the pass manager to generate a reproducer on the event
/// of a crash or a pass failure. `factory` is used to construct the streams
/// to write the generated reproducer to. If `genLocalReproducer` is true, the
Expand Down
6 changes: 6 additions & 0 deletions mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,9 @@ class MlirOptMainConfig {
}
bool shouldVerifyRoundtrip() const { return verifyRoundtripFlag; }

/// Reproducer file generation (no crash required).
StringRef getReproducerFilename() const { return generateReproducerFileFlag; }

protected:
/// Allow operation with no registered dialects.
/// This option is for convenience during testing only and discouraged in
Expand Down Expand Up @@ -228,6 +231,9 @@ class MlirOptMainConfig {

/// Verify that the input IR round-trips perfectly.
bool verifyRoundtripFlag = false;

/// The reproducer output filename (no crash required).
std::string generateReproducerFileFlag = "";
};

/// This defines the function type used to setup the pass manager. This can be
Expand Down
18 changes: 12 additions & 6 deletions mlir/lib/Pass/Pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -382,16 +382,22 @@ StringRef OpPassManager::getOpAnchorName() const {

/// Prints out the passes of the pass manager as the textual representation
/// of pipelines.
void OpPassManager::printAsTextualPipeline(raw_ostream &os) const {
os << getOpAnchorName() << "(";
void printAsTextualPipeline(
raw_ostream &os, StringRef anchorName,
const llvm::iterator_range<OpPassManager::pass_iterator> &passes) {
os << anchorName << "(";
llvm::interleave(
impl->passes,
[&](const std::unique_ptr<Pass> &pass) {
pass->printAsTextualPipeline(os);
},
passes, [&](mlir::Pass &pass) { pass.printAsTextualPipeline(os); },
[&]() { os << ","; });
os << ")";
}
void OpPassManager::printAsTextualPipeline(raw_ostream &os) const {
StringRef anchorName = getOpAnchorName();
::printAsTextualPipeline(
os, anchorName,
{MutableArrayRef<std::unique_ptr<Pass>>{impl->passes}.begin(),
MutableArrayRef<std::unique_ptr<Pass>>{impl->passes}.end()});
}

void OpPassManager::dump() {
llvm::errs() << "Pass Manager with " << impl->passes.size() << " passes:\n";
Expand Down
87 changes: 58 additions & 29 deletions mlir/lib/Pass/PassCrashRecovery.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ namespace detail {
/// reproducers when a signal is raised, such as a segfault.
struct RecoveryReproducerContext {
RecoveryReproducerContext(std::string passPipelineStr, Operation *op,
PassManager::ReproducerStreamFactory &streamFactory,
ReproducerStreamFactory &streamFactory,
bool verifyPasses);
~RecoveryReproducerContext();

Expand Down Expand Up @@ -67,7 +67,7 @@ struct RecoveryReproducerContext {

/// The factory for the reproducer output stream to use when generating the
/// reproducer.
PassManager::ReproducerStreamFactory &streamFactory;
ReproducerStreamFactory &streamFactory;

/// Various pass manager and context flags.
bool disableThreads;
Expand All @@ -92,7 +92,7 @@ llvm::ManagedStatic<llvm::SmallSetVector<RecoveryReproducerContext *, 1>>

RecoveryReproducerContext::RecoveryReproducerContext(
std::string passPipelineStr, Operation *op,
PassManager::ReproducerStreamFactory &streamFactory, bool verifyPasses)
ReproducerStreamFactory &streamFactory, bool verifyPasses)
: pipelineElements(std::move(passPipelineStr)),
preCrashOperation(op->clone()), streamFactory(streamFactory),
disableThreads(!op->getContext()->isMultithreadingEnabled()),
Expand All @@ -106,22 +106,24 @@ RecoveryReproducerContext::~RecoveryReproducerContext() {
disable();
}

void RecoveryReproducerContext::generate(std::string &description) {
static void appendReproducer(std::string &description, Operation *op,
const ReproducerStreamFactory &factory,
const std::string &pipelineElements,
bool disableThreads, bool verifyPasses) {
llvm::raw_string_ostream descOS(description);

// Try to create a new output stream for this crash reproducer.
std::string error;
std::unique_ptr<PassManager::ReproducerStream> stream = streamFactory(error);
std::unique_ptr<ReproducerStream> stream = factory(error);
if (!stream) {
descOS << "failed to create output stream: " << error;
return;
}
descOS << "reproducer generated at `" << stream->description() << "`";

std::string pipeline = (preCrashOperation->getName().getStringRef() + "(" +
pipelineElements + ")")
.str();
AsmState state(preCrashOperation);
std::string pipeline =
(op->getName().getStringRef() + "(" + pipelineElements + ")").str();
AsmState state(op);
state.attachResourcePrinter(
"mlir_reproducer", [&](Operation *op, AsmResourceBuilder &builder) {
builder.buildString("pipeline", pipeline);
Expand All @@ -130,7 +132,12 @@ void RecoveryReproducerContext::generate(std::string &description) {
});

// Output the .mlir module.
preCrashOperation->print(stream->os(), state);
op->print(stream->os(), state);
}

void RecoveryReproducerContext::generate(std::string &description) {
appendReproducer(description, preCrashOperation, streamFactory,
pipelineElements, disableThreads, verifyPasses);
}

void RecoveryReproducerContext::disable() {
Expand Down Expand Up @@ -175,12 +182,11 @@ void RecoveryReproducerContext::registerSignalHandler() {
//===----------------------------------------------------------------------===//

struct PassCrashReproducerGenerator::Impl {
Impl(PassManager::ReproducerStreamFactory &streamFactory,
bool localReproducer)
Impl(ReproducerStreamFactory &streamFactory, bool localReproducer)
: streamFactory(streamFactory), localReproducer(localReproducer) {}

/// The factory to use when generating a crash reproducer.
PassManager::ReproducerStreamFactory streamFactory;
ReproducerStreamFactory streamFactory;

/// Flag indicating if reproducer generation should be localized to the
/// failing pass.
Expand All @@ -198,7 +204,7 @@ struct PassCrashReproducerGenerator::Impl {
};

PassCrashReproducerGenerator::PassCrashReproducerGenerator(
PassManager::ReproducerStreamFactory &streamFactory, bool localReproducer)
ReproducerStreamFactory &streamFactory, bool localReproducer)
: impl(std::make_unique<Impl>(streamFactory, localReproducer)) {}
PassCrashReproducerGenerator::~PassCrashReproducerGenerator() = default;

Expand Down Expand Up @@ -382,9 +388,9 @@ struct CrashReproducerInstrumentation : public PassInstrumentation {
//===----------------------------------------------------------------------===//

namespace {
/// This class represents a default instance of PassManager::ReproducerStream
/// This class represents a default instance of mlir::ReproducerStream
/// that is backed by a file.
struct FileReproducerStream : public PassManager::ReproducerStream {
struct FileReproducerStream : public mlir::ReproducerStream {
FileReproducerStream(std::unique_ptr<llvm::ToolOutputFile> outputFile)
: outputFile(std::move(outputFile)) {}
~FileReproducerStream() override { outputFile->keep(); }
Expand Down Expand Up @@ -418,22 +424,45 @@ LogicalResult PassManager::runWithCrashRecovery(Operation *op,
return passManagerResult;
}

void PassManager::enableCrashReproducerGeneration(StringRef outputFile,
bool genLocalReproducer) {
static ReproducerStreamFactory
makeReproducerStreamFactory(StringRef outputFile) {
// Capture the filename by value in case outputFile is out of scope when
// invoked.
std::string filename = outputFile.str();
enableCrashReproducerGeneration(
[filename](std::string &error) -> std::unique_ptr<ReproducerStream> {
std::unique_ptr<llvm::ToolOutputFile> outputFile =
mlir::openOutputFile(filename, &error);
if (!outputFile) {
error = "Failed to create reproducer stream: " + error;
return nullptr;
}
return std::make_unique<FileReproducerStream>(std::move(outputFile));
},
genLocalReproducer);
return [filename](std::string &error) -> std::unique_ptr<ReproducerStream> {
std::unique_ptr<llvm::ToolOutputFile> outputFile =
mlir::openOutputFile(filename, &error);
if (!outputFile) {
error = "Failed to create reproducer stream: " + error;
return nullptr;
}
return std::make_unique<FileReproducerStream>(std::move(outputFile));
};
}

void printAsTextualPipeline(
raw_ostream &os, StringRef anchorName,
const llvm::iterator_range<OpPassManager::pass_iterator> &passes);

std::string mlir::makeReproducer(
StringRef anchorName,
const llvm::iterator_range<OpPassManager::pass_iterator> &passes,
Operation *op, StringRef outputFile, bool disableThreads,
bool verifyPasses) {

std::string description;
std::string pipelineStr;
llvm::raw_string_ostream passOS(pipelineStr);
::printAsTextualPipeline(passOS, anchorName, passes);
appendReproducer(description, op, makeReproducerStreamFactory(outputFile),
pipelineStr, disableThreads, verifyPasses);
return description;
}

void PassManager::enableCrashReproducerGeneration(StringRef outputFile,
bool genLocalReproducer) {
enableCrashReproducerGeneration(makeReproducerStreamFactory(outputFile),
genLocalReproducer);
}

void PassManager::enableCrashReproducerGeneration(
Expand Down
5 changes: 2 additions & 3 deletions mlir/lib/Pass/PassDetail.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,8 @@ class OpToOpPassAdaptor

class PassCrashReproducerGenerator {
public:
PassCrashReproducerGenerator(
PassManager::ReproducerStreamFactory &streamFactory,
bool localReproducer);
PassCrashReproducerGenerator(ReproducerStreamFactory &streamFactory,
bool localReproducer);
~PassCrashReproducerGenerator();

/// Initialize the generator in preparation for reproducer generation. The
Expand Down
18 changes: 18 additions & 0 deletions mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,16 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {

static cl::list<std::string> passPlugins(
"load-pass-plugin", cl::desc("Load passes from plugin library"));

static cl::opt<std::string, /*ExternalStorage=*/true>
generateReproducerFile(
"mlir-generate-reproducer",
llvm::cl::desc(
"Generate an mlir reproducer at the provided filename"
" (no crash required)"),
cl::location(generateReproducerFileFlag), cl::init(""),
cl::value_desc("filename"));

/// Set the callback to load a pass plugin.
passPlugins.setCallback([&](const std::string &pluginPath) {
auto plugin = PassPlugin::load(pluginPath);
Expand Down Expand Up @@ -384,6 +394,14 @@ performActions(raw_ostream &os,
if (failed(pm.run(*op)))
return failure();

// Generate reproducers if requested
if (!config.getReproducerFilename().empty()) {
StringRef anchorName = pm.getAnyOpAnchorName();
const auto &passes = pm.getPasses();
makeReproducer(anchorName, passes, op.get(),
config.getReproducerFilename());
}

// Print the output.
TimingScope outputTiming = timing.nest("Output");
if (config.shouldEmitBytecode()) {
Expand Down
10 changes: 10 additions & 0 deletions mlir/test/Pass/crashless-reproducer.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// RUN: mlir-opt %s -pass-pipeline='builtin.module(builtin.module(test-module-pass))' --mlir-generate-reproducer=%t -verify-diagnostics
// RUN: cat %t | FileCheck -check-prefix=REPRO %s

module @inner_mod1 {
module @foo {}
}

// REPRO: module @inner_mod1
// REPRO: module @foo {
// REPRO: pipeline: "builtin.module(any(builtin.module(test-module-pass)))"