Skip to content

Commit a32300a

Browse files
committed
Make the --mlir-disable-threading command line option overrides the C++ API usage
This seems in-line with the intent and how we build tools around it. Update the description for the flag accordingly. Also use an injected thread pool in MLIROptMain, now we will create threads up-front and reuse them across split buffers. Differential Revision: https://reviews.llvm.org/D109802
1 parent 500d4c4 commit a32300a

File tree

3 files changed

+41
-14
lines changed

3 files changed

+41
-14
lines changed

mlir/include/mlir/IR/MLIRContext.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,8 @@ class MLIRContext {
129129
bool isMultithreadingEnabled();
130130

131131
/// Set the flag specifying if multi-threading is disabled by the context.
132+
/// The command line debugging flag `--mlir-disable-threading` is overriding
133+
/// this call and making it a no-op!
132134
void disableMultithreading(bool disable = true);
133135
void enableMultithreading(bool enable = true) {
134136
disableMultithreading(!enable);
@@ -140,6 +142,9 @@ class MLIRContext {
140142
/// decoupling the lifetime of the threads from the contexts. The thread pool
141143
/// must outlive the context. Multi-threading will be enabled as part of this
142144
/// method.
145+
/// The command line debugging flag `--mlir-disable-threading` will still
146+
/// prevent threading from being enabled and threading won't be enabled after
147+
/// this call in this case.
143148
void setThreadPool(llvm::ThreadPool &pool);
144149

145150
/// Return the thread pool used by this context. This method requires that

mlir/lib/IR/MLIRContext.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ namespace {
5757
struct MLIRContextOptions {
5858
llvm::cl::opt<bool> disableThreading{
5959
"mlir-disable-threading",
60-
llvm::cl::desc("Disabling multi-threading within MLIR")};
60+
llvm::cl::desc("Disable multi-threading within MLIR, overrides any "
61+
"further call to MLIRContext::enableMultiThreading()")};
6162

6263
llvm::cl::opt<bool> printOpOnDiagnostic{
6364
"mlir-print-op-on-diagnostic",
@@ -74,6 +75,14 @@ struct MLIRContextOptions {
7475

7576
static llvm::ManagedStatic<MLIRContextOptions> clOptions;
7677

78+
static bool isThreadingGloballyDisabled() {
79+
#if LLVM_ENABLE_THREADS != 0
80+
return clOptions.isConstructed() && clOptions->disableThreading;
81+
#else
82+
return true;
83+
#endif
84+
}
85+
7786
/// Register a set of useful command-line options that can be used to configure
7887
/// various flags within the MLIRContext. These flags are used when constructing
7988
/// an MLIR context for initialization.
@@ -362,10 +371,10 @@ MLIRContext::MLIRContext(Threading setting)
362371
: MLIRContext(DialectRegistry(), setting) {}
363372

364373
MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
365-
: impl(new MLIRContextImpl(setting == Threading::ENABLED)) {
374+
: impl(new MLIRContextImpl(setting == Threading::ENABLED &&
375+
!isThreadingGloballyDisabled())) {
366376
// Initialize values based on the command line flags if they were provided.
367377
if (clOptions.isConstructed()) {
368-
disableMultithreading(clOptions->disableThreading);
369378
printOpOnDiagnostic(clOptions->printOpOnDiagnostic);
370379
printStackTraceOnDiagnostic(clOptions->printStackTraceOnDiagnostic);
371380
}
@@ -582,6 +591,11 @@ bool MLIRContext::isMultithreadingEnabled() {
582591

583592
/// Set the flag specifying if multi-threading is disabled by the context.
584593
void MLIRContext::disableMultithreading(bool disable) {
594+
// This API can be overridden by the global debugging flag
595+
// --mlir-disable-threading
596+
if (isThreadingGloballyDisabled())
597+
return;
598+
585599
impl->threadingIsEnabled = !disable;
586600

587601
// Update the threading mode for each of the uniquers.

mlir/lib/Support/MlirOptMain.cpp

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "llvm/Support/Regex.h"
3333
#include "llvm/Support/SourceMgr.h"
3434
#include "llvm/Support/StringSaver.h"
35+
#include "llvm/Support/ThreadPool.h"
3536
#include "llvm/Support/ToolOutputFile.h"
3637

3738
using namespace mlir;
@@ -93,19 +94,22 @@ static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics,
9394

9495
/// Parses the memory buffer. If successfully, run a series of passes against
9596
/// it and print the result.
96-
static LogicalResult processBuffer(raw_ostream &os,
97-
std::unique_ptr<MemoryBuffer> ownedBuffer,
98-
bool verifyDiagnostics, bool verifyPasses,
99-
bool allowUnregisteredDialects,
100-
bool preloadDialectsInContext,
101-
const PassPipelineCLParser &passPipeline,
102-
DialectRegistry &registry) {
97+
static LogicalResult
98+
processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
99+
bool verifyDiagnostics, bool verifyPasses,
100+
bool allowUnregisteredDialects, bool preloadDialectsInContext,
101+
const PassPipelineCLParser &passPipeline,
102+
DialectRegistry &registry, llvm::ThreadPool &threadPool) {
103103
// Tell sourceMgr about this buffer, which is what the parser will pick up.
104104
SourceMgr sourceMgr;
105105
sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
106106

107+
// Create a context just for the current buffer. Disable threading on creation
108+
// since we'll inject the thread-pool separately.
109+
MLIRContext context(registry, MLIRContext::Threading::DISABLED);
110+
context.setThreadPool(threadPool);
111+
107112
// Parse the input file.
108-
MLIRContext context(registry);
109113
if (preloadDialectsInContext)
110114
context.loadAllAvailableDialects();
111115
context.allowUnregisteredDialects(allowUnregisteredDialects);
@@ -143,20 +147,24 @@ LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
143147
bool preloadDialectsInContext) {
144148
// The split-input-file mode is a very specific mode that slices the file
145149
// up into small pieces and checks each independently.
150+
// We use an explicit threadpool to avoid creating and joining/destroying
151+
// threads for each of the split.
152+
llvm::ThreadPool threadPool;
146153
if (splitInputFile)
147154
return splitAndProcessBuffer(
148155
std::move(buffer),
149156
[&](std::unique_ptr<MemoryBuffer> chunkBuffer, raw_ostream &os) {
150157
return processBuffer(os, std::move(chunkBuffer), verifyDiagnostics,
151158
verifyPasses, allowUnregisteredDialects,
152-
preloadDialectsInContext, passPipeline,
153-
registry);
159+
preloadDialectsInContext, passPipeline, registry,
160+
threadPool);
154161
},
155162
outputStream);
156163

157164
return processBuffer(outputStream, std::move(buffer), verifyDiagnostics,
158165
verifyPasses, allowUnregisteredDialects,
159-
preloadDialectsInContext, passPipeline, registry);
166+
preloadDialectsInContext, passPipeline, registry,
167+
threadPool);
160168
}
161169

162170
LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,

0 commit comments

Comments
 (0)