@@ -261,8 +261,15 @@ class MLIRContextImpl {
261
261
// Other
262
262
// ===--------------------------------------------------------------------===//
263
263
264
- // / The thread pool to use when processing MLIR tasks in parallel.
265
- llvm::Optional<llvm::ThreadPool> threadPool;
264
+ // / This points to the ThreadPool used when processing MLIR tasks in parallel.
265
+ // / It can't be nullptr when multi-threading is enabled. Otherwise if
266
+ // / multi-threading is disabled, and the threadpool wasn't externally provided
267
+ // / using `setThreadPool`, this will be nullptr.
268
+ llvm::ThreadPool *threadPool = nullptr ;
269
+
270
+ // / In case where the thread pool is owned by the context, this ensures
271
+ // / destruction with the context.
272
+ std::unique_ptr<llvm::ThreadPool> ownedThreadPool;
266
273
267
274
// / This is a list of dialects that are created referring to this context.
268
275
// / The MLIRContext owns the objects.
@@ -334,9 +341,13 @@ class MLIRContextImpl {
334
341
StringAttr emptyStringAttr;
335
342
336
343
public:
337
- MLIRContextImpl () : identifiers(identifierAllocator) {
338
- if (threadingIsEnabled)
339
- threadPool.emplace ();
344
+ MLIRContextImpl (bool threadingIsEnabled)
345
+ : threadingIsEnabled(threadingIsEnabled),
346
+ identifiers (identifierAllocator) {
347
+ if (threadingIsEnabled) {
348
+ ownedThreadPool = std::make_unique<llvm::ThreadPool>();
349
+ threadPool = ownedThreadPool.get ();
350
+ }
340
351
}
341
352
~MLIRContextImpl () {
342
353
for (auto typeMapping : registeredTypes)
@@ -347,10 +358,11 @@ class MLIRContextImpl {
347
358
};
348
359
} // end namespace mlir
349
360
350
- MLIRContext::MLIRContext () : MLIRContext(DialectRegistry()) {}
361
+ MLIRContext::MLIRContext (Threading setting)
362
+ : MLIRContext(DialectRegistry(), setting) {}
351
363
352
- MLIRContext::MLIRContext (const DialectRegistry ®istry)
353
- : impl(new MLIRContextImpl) {
364
+ MLIRContext::MLIRContext (const DialectRegistry ®istry, Threading setting )
365
+ : impl(new MLIRContextImpl(setting == Threading::ENABLED) ) {
354
366
// Initialize values based on the command line flags if they were provided.
355
367
if (clOptions.isConstructed ()) {
356
368
disableMultithreading (clOptions->disableThreading );
@@ -579,15 +591,36 @@ void MLIRContext::disableMultithreading(bool disable) {
579
591
580
592
// Destroy thread pool (stop all threads) if it is no longer needed, or create
581
593
// a new one if multithreading was re-enabled.
582
- if (!impl->threadingIsEnabled )
583
- impl->threadPool .reset ();
584
- else if (!impl->threadPool .hasValue ())
585
- impl->threadPool .emplace ();
594
+ if (disable) {
595
+ // If the thread pool is owned, explicitly set it to nullptr to avoid
596
+ // keeping a dangling pointer around. If the thread pool is externally
597
+ // owned, we don't do anything.
598
+ if (impl->ownedThreadPool ) {
599
+ assert (impl->threadPool );
600
+ impl->threadPool = nullptr ;
601
+ impl->ownedThreadPool .reset ();
602
+ }
603
+ } else if (!impl->threadPool ) {
604
+ // The thread pool isn't externally provided.
605
+ assert (!impl->ownedThreadPool );
606
+ impl->ownedThreadPool = std::make_unique<llvm::ThreadPool>();
607
+ impl->threadPool = impl->ownedThreadPool .get ();
608
+ }
609
+ }
610
+
611
+ void MLIRContext::setThreadPool (llvm::ThreadPool &pool) {
612
+ assert (!isMultithreadingEnabled () &&
613
+ " expected multi-threading to be disabled when setting a ThreadPool" );
614
+ impl->threadPool = &pool;
615
+ impl->ownedThreadPool .reset ();
616
+ enableMultithreading ();
586
617
}
587
618
588
619
llvm::ThreadPool &MLIRContext::getThreadPool () {
589
620
assert (isMultithreadingEnabled () &&
590
621
" expected multi-threading to be enabled within the context" );
622
+ assert (impl->threadPool &&
623
+ " multi-threading is enabled but threadpool not set" );
591
624
return *impl->threadPool ;
592
625
}
593
626
0 commit comments