Skip to content

Commit 6bbbd7b

Browse files
committed
Update MLIRContext to allow injecting an external ThreadPool (NFC)
The context can be created with threading disabled, to avoid creating a thread pool that may be destroyed when injecting another one later. Differential Revision: https://reviews.llvm.org/D105302
1 parent 64a0241 commit 6bbbd7b

File tree

2 files changed

+72
-15
lines changed

2 files changed

+72
-15
lines changed

mlir/include/mlir/IR/MLIRContext.h

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,27 @@ class StorageUniquer;
3838
/// a very generic name ("Context") and because it is uncommon for clients to
3939
/// interact with it.
4040
///
41+
/// The context wrap some multi-threading facilities, and in particular by
42+
/// default it will implicitly create a thread pool.
43+
/// This can be undesirable if multiple context exists at the same time or if a
44+
/// process will be long-lived and create and destroy contexts.
45+
/// To control better thread spawning, an externally owned ThreadPool can be
46+
/// injected in the context. For example:
47+
///
48+
/// llvm::ThreadPool myThreadPool;
49+
/// while (auto *request = nextCompilationRequests()) {
50+
/// MLIRContext ctx(registry, MLIRContext::Threading::DISABLED);
51+
/// ctx.setThreadPool(myThreadPool);
52+
/// processRequest(request, cxt);
53+
/// }
54+
///
4155
class MLIRContext {
4256
public:
57+
enum class Threading { DISABLED, ENABLED };
4358
/// Create a new Context.
44-
explicit MLIRContext();
45-
explicit MLIRContext(const DialectRegistry &registry);
59+
explicit MLIRContext(Threading multithreading = Threading::ENABLED);
60+
explicit MLIRContext(const DialectRegistry &registry,
61+
Threading multithreading = Threading::ENABLED);
4662
~MLIRContext();
4763

4864
/// Return information about all IR dialects loaded in the context.
@@ -118,7 +134,15 @@ class MLIRContext {
118134
disableMultithreading(!enable);
119135
}
120136

121-
/// Return the thread pool owned by this context. This method requires that
137+
/// Set a new thread pool to be used in this context. This method requires
138+
/// that multithreading is disabled for this context prior to the call. This
139+
/// allows to share a thread pool across multiple contexts, as well as
140+
/// decoupling the lifetime of the threads from the contexts. The thread pool
141+
/// must outlive the context. Multi-threading will be enabled as part of this
142+
/// method.
143+
void setThreadPool(llvm::ThreadPool &pool);
144+
145+
/// Return the thread pool used by this context. This method requires that
122146
/// multithreading be enabled within the context, and should generally not be
123147
/// used directly. Users should instead prefer the threading utilities within
124148
/// Threading.h.

mlir/lib/IR/MLIRContext.cpp

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -261,8 +261,15 @@ class MLIRContextImpl {
261261
// Other
262262
//===--------------------------------------------------------------------===//
263263

264-
/// The thread pool to use when processing MLIR tasks in parallel.
265-
llvm::Optional<llvm::ThreadPool> threadPool;
264+
/// This points to the ThreadPool used when processing MLIR tasks in parallel.
265+
/// It can't be nullptr when multi-threading is enabled. Otherwise if
266+
/// multi-threading is disabled, and the threadpool wasn't externally provided
267+
/// using `setThreadPool`, this will be nullptr.
268+
llvm::ThreadPool *threadPool = nullptr;
269+
270+
/// In case where the thread pool is owned by the context, this ensures
271+
/// destruction with the context.
272+
std::unique_ptr<llvm::ThreadPool> ownedThreadPool;
266273

267274
/// This is a list of dialects that are created referring to this context.
268275
/// The MLIRContext owns the objects.
@@ -334,9 +341,13 @@ class MLIRContextImpl {
334341
StringAttr emptyStringAttr;
335342

336343
public:
337-
MLIRContextImpl() : identifiers(identifierAllocator) {
338-
if (threadingIsEnabled)
339-
threadPool.emplace();
344+
MLIRContextImpl(bool threadingIsEnabled)
345+
: threadingIsEnabled(threadingIsEnabled),
346+
identifiers(identifierAllocator) {
347+
if (threadingIsEnabled) {
348+
ownedThreadPool = std::make_unique<llvm::ThreadPool>();
349+
threadPool = ownedThreadPool.get();
350+
}
340351
}
341352
~MLIRContextImpl() {
342353
for (auto typeMapping : registeredTypes)
@@ -347,10 +358,11 @@ class MLIRContextImpl {
347358
};
348359
} // end namespace mlir
349360

350-
MLIRContext::MLIRContext() : MLIRContext(DialectRegistry()) {}
361+
MLIRContext::MLIRContext(Threading setting)
362+
: MLIRContext(DialectRegistry(), setting) {}
351363

352-
MLIRContext::MLIRContext(const DialectRegistry &registry)
353-
: impl(new MLIRContextImpl) {
364+
MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
365+
: impl(new MLIRContextImpl(setting == Threading::ENABLED)) {
354366
// Initialize values based on the command line flags if they were provided.
355367
if (clOptions.isConstructed()) {
356368
disableMultithreading(clOptions->disableThreading);
@@ -579,15 +591,36 @@ void MLIRContext::disableMultithreading(bool disable) {
579591

580592
// Destroy thread pool (stop all threads) if it is no longer needed, or create
581593
// a new one if multithreading was re-enabled.
582-
if (!impl->threadingIsEnabled)
583-
impl->threadPool.reset();
584-
else if (!impl->threadPool.hasValue())
585-
impl->threadPool.emplace();
594+
if (disable) {
595+
// If the thread pool is owned, explicitly set it to nullptr to avoid
596+
// keeping a dangling pointer around. If the thread pool is externally
597+
// owned, we don't do anything.
598+
if (impl->ownedThreadPool) {
599+
assert(impl->threadPool);
600+
impl->threadPool = nullptr;
601+
impl->ownedThreadPool.reset();
602+
}
603+
} else if (!impl->threadPool) {
604+
// The thread pool isn't externally provided.
605+
assert(!impl->ownedThreadPool);
606+
impl->ownedThreadPool = std::make_unique<llvm::ThreadPool>();
607+
impl->threadPool = impl->ownedThreadPool.get();
608+
}
609+
}
610+
611+
void MLIRContext::setThreadPool(llvm::ThreadPool &pool) {
612+
assert(!isMultithreadingEnabled() &&
613+
"expected multi-threading to be disabled when setting a ThreadPool");
614+
impl->threadPool = &pool;
615+
impl->ownedThreadPool.reset();
616+
enableMultithreading();
586617
}
587618

588619
llvm::ThreadPool &MLIRContext::getThreadPool() {
589620
assert(isMultithreadingEnabled() &&
590621
"expected multi-threading to be enabled within the context");
622+
assert(impl->threadPool &&
623+
"multi-threading is enabled but threadpool not set");
591624
return *impl->threadPool;
592625
}
593626

0 commit comments

Comments
 (0)