Skip to content

Commit dfdf757

Browse files
committed
Split the llvm::ThreadPool into an abstract base class and an implementation
This decouples the public API used to enqueue tasks and wait for completion from the actual implementation, and opens up the possibility for clients to set their own thread pool implementation for the pool.
1 parent 7106389 commit dfdf757

File tree

8 files changed

+112
-64
lines changed

8 files changed

+112
-64
lines changed

llvm/include/llvm/Support/ThreadPool.h

Lines changed: 96 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,9 @@ namespace llvm {
3232

3333
class ThreadPoolTaskGroup;
3434

35-
/// A ThreadPool for asynchronous parallel execution on a defined number of
36-
/// threads.
37-
///
38-
/// The pool keeps a vector of threads alive, waiting on a condition variable
39-
/// for some work to become available.
35+
36+
/// This defines the abstract base interface for a ThreadPool allowing
37+
/// asynchronous parallel execution on a defined number of threads.
4038
///
4139
/// It is possible to reuse one thread pool for different groups of tasks
4240
/// by grouping tasks using ThreadPoolTaskGroup. All tasks are processed using
@@ -49,16 +47,31 @@ class ThreadPoolTaskGroup;
4947
/// available threads are used up by tasks waiting for a task that has no thread
5048
/// left to run on (this includes waiting on the returned future). It should be
5149
/// generally safe to wait() for a group as long as groups do not form a cycle.
52-
class ThreadPool {
50+
class ThreadPoolInterface {
51+
// The actual method to enqueue a task to be defined by the concrete implementation.
52+
virtual void asyncEnqueue(std::function<void()> Task, ThreadPoolTaskGroup *Group) = 0;
53+
5354
public:
54-
/// Construct a pool using the hardware strategy \p S for mapping hardware
55-
/// execution resources (threads, cores, CPUs)
56-
/// Defaults to using the maximum execution resources in the system, but
57-
/// accounting for the affinity mask.
58-
ThreadPool(ThreadPoolStrategy S = hardware_concurrency());
5955

60-
/// Blocking destructor: the pool will wait for all the threads to complete.
61-
~ThreadPool();
56+
// Destroying the pool will drain the pending tasks and wait. The current thread may
57+
// participate in the execution of the pending tasks.
58+
virtual ~ThreadPoolInterface();
59+
60+
/// Blocking wait for all the threads to complete and the queue to be empty.
61+
/// It is an error to try to add new tasks while blocking on this call.
62+
/// Calling wait() from a task would deadlock waiting for itself.
63+
virtual void wait() = 0;
64+
65+
/// Blocking wait for only all the threads in the given group to complete.
66+
/// It is possible to wait even inside a task, but waiting (directly or
67+
/// indirectly) on itself will deadlock. If called from a task running on a
68+
/// worker thread, the call may process pending tasks while waiting in order
69+
/// not to waste the thread.
70+
virtual void wait(ThreadPoolTaskGroup &Group) = 0;
71+
72+
// Returns the maximum number of worker this pool can eventually grow to.
73+
virtual unsigned getMaxConcurrency() const = 0;
74+
6275

6376
/// Asynchronous submission of a task to the pool. The returned future can be
6477
/// used to wait for the task to finish and is *non-blocking* on destruction.
@@ -92,27 +105,33 @@ class ThreadPool {
92105
&Group);
93106
}
94107

95-
/// Blocking wait for all the threads to complete and the queue to be empty.
96-
/// It is an error to try to add new tasks while blocking on this call.
97-
/// Calling wait() from a task would deadlock waiting for itself.
98-
void wait();
108+
private:
99109

100-
/// Blocking wait for only all the threads in the given group to complete.
101-
/// It is possible to wait even inside a task, but waiting (directly or
102-
/// indirectly) on itself will deadlock. If called from a task running on a
103-
/// worker thread, the call may process pending tasks while waiting in order
104-
/// not to waste the thread.
105-
void wait(ThreadPoolTaskGroup &Group);
110+
/// Asynchronous submission of a task to the pool. The returned future can be
111+
/// used to wait for the task to finish and is *non-blocking* on destruction.
112+
template <typename ResTy>
113+
std::shared_future<ResTy> asyncImpl(std::function<ResTy()> Task,
114+
ThreadPoolTaskGroup *Group) {
106115

107-
// TODO: misleading legacy name warning!
108-
// Returns the maximum number of worker threads in the pool, not the current
109-
// number of threads!
110-
unsigned getThreadCount() const { return MaxThreadCount; }
116+
#if LLVM_ENABLE_THREADS
117+
/// Wrap the Task in a std::function<void()> that sets the result of the
118+
/// corresponding future.
119+
auto R = createTaskAndFuture(Task);
111120

112-
/// Returns true if the current thread is a worker thread of this thread pool.
113-
bool isWorkerThread() const;
121+
asyncEnqueue(std::move(R.first), Group);
122+
return R.second.share();
123+
124+
#else // LLVM_ENABLE_THREADS Disabled
125+
126+
// Get a Future with launch::deferred execution using std::async
127+
auto Future = std::async(std::launch::deferred, std::move(Task)).share();
128+
// Wrap the future so that both ThreadPool::wait() can operate and the
129+
// returned future can be sync'ed on.
130+
Tasks.emplace_back(std::make_pair([Future]() { Future.get(); }, Group));
131+
return Future;
132+
#endif
133+
}
114134

115-
private:
116135
/// Helpers to create a promise and a callable wrapper of \p Task that sets
117136
/// the result of the promise. Returns the callable and a future to access the
118137
/// result.
@@ -137,44 +156,70 @@ class ThreadPool {
137156
},
138157
std::move(F)};
139158
}
159+
};
160+
161+
/// A ThreadPool implementation using std::threads.
162+
///
163+
/// The pool keeps a vector of threads alive, waiting on a condition variable
164+
/// for some work to become available.
165+
class ThreadPool : public ThreadPoolInterface {
166+
public:
167+
/// Construct a pool using the hardware strategy \p S for mapping hardware
168+
/// execution resources (threads, cores, CPUs)
169+
/// Defaults to using the maximum execution resources in the system, but
170+
/// accounting for the affinity mask.
171+
ThreadPool(ThreadPoolStrategy S = hardware_concurrency());
172+
173+
/// Blocking destructor: the pool will wait for all the threads to complete.
174+
~ThreadPool() override;
175+
176+
177+
/// Blocking wait for all the threads to complete and the queue to be empty.
178+
/// It is an error to try to add new tasks while blocking on this call.
179+
/// Calling wait() from a task would deadlock waiting for itself.
180+
void wait() override;
181+
182+
/// Blocking wait for only all the threads in the given group to complete.
183+
/// It is possible to wait even inside a task, but waiting (directly or
184+
/// indirectly) on itself will deadlock. If called from a task running on a
185+
/// worker thread, the call may process pending tasks while waiting in order
186+
/// not to waste the thread.
187+
void wait(ThreadPoolTaskGroup &Group) override;
188+
189+
// TODO: misleading legacy name warning!
190+
// Returns the maximum number of worker threads in the pool, not the current
191+
// number of threads!
192+
unsigned getThreadCount() const { return MaxThreadCount; }
193+
unsigned getMaxConcurrency() const override { return MaxThreadCount; }
194+
195+
196+
/// Returns true if the current thread is a worker thread of this thread pool.
197+
bool isWorkerThread() const;
198+
199+
private:
140200

141201
/// Returns true if all tasks in the given group have finished (nullptr means
142202
/// all tasks regardless of their group). QueueLock must be locked.
143203
bool workCompletedUnlocked(ThreadPoolTaskGroup *Group) const;
144204

205+
145206
/// Asynchronous submission of a task to the pool. The returned future can be
146207
/// used to wait for the task to finish and is *non-blocking* on destruction.
147-
template <typename ResTy>
148-
std::shared_future<ResTy> asyncImpl(std::function<ResTy()> Task,
149-
ThreadPoolTaskGroup *Group) {
150-
208+
void asyncEnqueue(std::function<void()> Task,
209+
ThreadPoolTaskGroup *Group) override {
151210
#if LLVM_ENABLE_THREADS
152-
/// Wrap the Task in a std::function<void()> that sets the result of the
153-
/// corresponding future.
154-
auto R = createTaskAndFuture(Task);
155-
156211
int requestedThreads;
157212
{
158213
// Lock the queue and push the new task
159214
std::unique_lock<std::mutex> LockGuard(QueueLock);
160215

161216
// Don't allow enqueueing after disabling the pool
162217
assert(EnableFlag && "Queuing a thread during ThreadPool destruction");
163-
Tasks.emplace_back(std::make_pair(std::move(R.first), Group));
218+
Tasks.emplace_back(std::make_pair(std::move(Task), Group));
164219
requestedThreads = ActiveThreads + Tasks.size();
165220
}
166221
QueueCondition.notify_one();
167222
grow(requestedThreads);
168-
return R.second.share();
169-
170-
#else // LLVM_ENABLE_THREADS Disabled
171-
172-
// Get a Future with launch::deferred execution using std::async
173-
auto Future = std::async(std::launch::deferred, std::move(Task)).share();
174-
// Wrap the future so that both ThreadPool::wait() can operate and the
175-
// returned future can be sync'ed on.
176-
Tasks.emplace_back(std::make_pair([Future]() { Future.get(); }, Group));
177-
return Future;
178223
#endif
179224
}
180225

@@ -224,7 +269,7 @@ class ThreadPool {
224269
class ThreadPoolTaskGroup {
225270
public:
226271
/// The ThreadPool argument is the thread pool to forward calls to.
227-
ThreadPoolTaskGroup(ThreadPool &Pool) : Pool(Pool) {}
272+
ThreadPoolTaskGroup(ThreadPoolInterface &Pool) : Pool(Pool) {}
228273

229274
/// Blocking destructor: will wait for all the tasks in the group to complete
230275
/// by calling ThreadPool::wait().
@@ -241,7 +286,7 @@ class ThreadPoolTaskGroup {
241286
void wait() { Pool.wait(*this); }
242287

243288
private:
244-
ThreadPool &Pool;
289+
ThreadPoolInterface &Pool;
245290
};
246291

247292
} // namespace llvm

llvm/lib/Support/ThreadPool.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323

2424
using namespace llvm;
2525

26+
ThreadPoolInterface::~ThreadPoolInterface() = default;
27+
2628
#if LLVM_ENABLE_THREADS
2729

2830
// A note on thread groups: Tasks are by default in no group (represented

mlir/include/mlir/IR/MLIRContext.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
#include <vector>
1818

1919
namespace llvm {
20-
class ThreadPool;
20+
class ThreadPoolInterface;
2121
} // namespace llvm
2222

2323
namespace mlir {
@@ -162,7 +162,7 @@ class MLIRContext {
162162
/// The command line debugging flag `--mlir-disable-threading` will still
163163
/// prevent threading from being enabled and threading won't be enabled after
164164
/// this call in this case.
165-
void setThreadPool(llvm::ThreadPool &pool);
165+
void setThreadPool(llvm::ThreadPoolInterface &pool);
166166

167167
/// Return the number of threads used by the thread pool in this context. The
168168
/// number of computed hardware threads can change over the lifetime of a
@@ -175,7 +175,7 @@ class MLIRContext {
175175
/// multithreading be enabled within the context, and should generally not be
176176
/// used directly. Users should instead prefer the threading utilities within
177177
/// Threading.h.
178-
llvm::ThreadPool &getThreadPool();
178+
llvm::ThreadPoolInterface &getThreadPool();
179179

180180
/// Return true if we should attach the operation to diagnostics emitted via
181181
/// Operation::emit.

mlir/include/mlir/IR/Threading.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,9 @@ LogicalResult failableParallelForEach(MLIRContext *context, IteratorT begin,
6666
};
6767

6868
// Otherwise, process the elements in parallel.
69-
llvm::ThreadPool &threadPool = context->getThreadPool();
69+
llvm::ThreadPoolInterface &threadPool = context->getThreadPool();
7070
llvm::ThreadPoolTaskGroup tasksGroup(threadPool);
71-
size_t numActions = std::min(numElements, threadPool.getThreadCount());
71+
size_t numActions = std::min(numElements, threadPool.getMaxConcurrency());
7272
for (unsigned i = 0; i < numActions; ++i)
7373
tasksGroup.async(processFn);
7474
// If the current thread is a worker thread from the pool, then waiting for

mlir/lib/CAPI/IR/IR.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "mlir/IR/Visitors.h"
2929
#include "mlir/Interfaces/InferTypeOpInterface.h"
3030
#include "mlir/Parser/Parser.h"
31+
#include "llvm/Support/ThreadPool.h"
3132

3233
#include <cstddef>
3334
#include <memory>

mlir/lib/IR/MLIRContext.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -170,11 +170,11 @@ class MLIRContextImpl {
170170
/// It can't be nullptr when multi-threading is enabled. Otherwise if
171171
/// multi-threading is disabled, and the threadpool wasn't externally provided
172172
/// using `setThreadPool`, this will be nullptr.
173-
llvm::ThreadPool *threadPool = nullptr;
173+
llvm::ThreadPoolInterface *threadPool = nullptr;
174174

175175
/// In case where the thread pool is owned by the context, this ensures
176176
/// destruction with the context.
177-
std::unique_ptr<llvm::ThreadPool> ownedThreadPool;
177+
std::unique_ptr<llvm::ThreadPoolInterface> ownedThreadPool;
178178

179179
/// An allocator used for AbstractAttribute and AbstractType objects.
180180
llvm::BumpPtrAllocator abstractDialectSymbolAllocator;
@@ -626,7 +626,7 @@ void MLIRContext::disableMultithreading(bool disable) {
626626
}
627627
}
628628

629-
void MLIRContext::setThreadPool(llvm::ThreadPool &pool) {
629+
void MLIRContext::setThreadPool(llvm::ThreadPoolInterface &pool) {
630630
assert(!isMultithreadingEnabled() &&
631631
"expected multi-threading to be disabled when setting a ThreadPool");
632632
impl->threadPool = &pool;
@@ -638,13 +638,13 @@ unsigned MLIRContext::getNumThreads() {
638638
if (isMultithreadingEnabled()) {
639639
assert(impl->threadPool &&
640640
"multi-threading is enabled but threadpool not set");
641-
return impl->threadPool->getThreadCount();
641+
return impl->threadPool->getMaxConcurrency();
642642
}
643643
// No multithreading or active thread pool. Return 1 thread.
644644
return 1;
645645
}
646646

647-
llvm::ThreadPool &MLIRContext::getThreadPool() {
647+
llvm::ThreadPoolInterface &MLIRContext::getThreadPool() {
648648
assert(isMultithreadingEnabled() &&
649649
"expected multi-threading to be enabled within the context");
650650
assert(impl->threadPool &&

mlir/lib/Pass/Pass.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -748,7 +748,7 @@ void OpToOpPassAdaptor::runOnOperationAsyncImpl(bool verifyPasses) {
748748
// Create the async executors if they haven't been created, or if the main
749749
// pipeline has changed.
750750
if (asyncExecutors.empty() || hasSizeMismatch(asyncExecutors.front(), mgrs))
751-
asyncExecutors.assign(context->getThreadPool().getThreadCount(), mgrs);
751+
asyncExecutors.assign(context->getThreadPool().getMaxConcurrency(), mgrs);
752752

753753
// This struct represents the information for a single operation to be
754754
// scheduled on a pass manager.

mlir/lib/Tools/mlir-opt/MlirOptMain.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ static LogicalResult processBuffer(raw_ostream &os,
429429
std::unique_ptr<MemoryBuffer> ownedBuffer,
430430
const MlirOptMainConfig &config,
431431
DialectRegistry &registry,
432-
llvm::ThreadPool *threadPool) {
432+
llvm::ThreadPoolInterface *threadPool) {
433433
// Tell sourceMgr about this buffer, which is what the parser will pick up.
434434
auto sourceMgr = std::make_shared<SourceMgr>();
435435
sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
@@ -515,7 +515,7 @@ LogicalResult mlir::MlirOptMain(llvm::raw_ostream &outputStream,
515515
// up into small pieces and checks each independently.
516516
// We use an explicit threadpool to avoid creating and joining/destroying
517517
// threads for each of the split.
518-
ThreadPool *threadPool = nullptr;
518+
ThreadPoolInterface *threadPool = nullptr;
519519

520520
// Create a temporary context for the sake of checking if
521521
// --mlir-disable-threading was passed on the command line.

0 commit comments

Comments
 (0)