Skip to content

Commit 4e73f17

Browse files
committed
Split the llvm::ThreadPool into an abstract base class and an implementation
This decouples the public API used to enqueue tasks and wait for completion from the actual implementation, and opens up the possibility for clients to set their own thread pool implementation for the pool. https://discourse.llvm.org/t/construct-threadpool-from-vector-of-existing-threads/76883
1 parent 1d5e3b2 commit 4e73f17

File tree

7 files changed

+106
-65
lines changed

7 files changed

+106
-65
lines changed

llvm/include/llvm/Support/ThreadPool.h

Lines changed: 93 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,8 @@ namespace llvm {
3232

3333
class ThreadPoolTaskGroup;
3434

35-
/// A ThreadPool for asynchronous parallel execution on a defined number of
36-
/// threads.
37-
///
38-
/// The pool keeps a vector of threads alive, waiting on a condition variable
39-
/// for some work to become available.
35+
/// This defines the abstract base interface for a ThreadPool allowing
36+
/// asynchronous parallel execution on a defined number of threads.
4037
///
4138
/// It is possible to reuse one thread pool for different groups of tasks
4239
/// by grouping tasks using ThreadPoolTaskGroup. All tasks are processed using
@@ -49,16 +46,31 @@ class ThreadPoolTaskGroup;
4946
/// available threads are used up by tasks waiting for a task that has no thread
5047
/// left to run on (this includes waiting on the returned future). It should be
5148
/// generally safe to wait() for a group as long as groups do not form a cycle.
52-
class ThreadPool {
49+
class ThreadPoolInterface {
50+
/// The actual method to enqueue a task to be defined by the concrete
51+
/// implementation.
52+
virtual void asyncEnqueue(std::function<void()> Task,
53+
ThreadPoolTaskGroup *Group) = 0;
54+
5355
public:
54-
/// Construct a pool using the hardware strategy \p S for mapping hardware
55-
/// execution resources (threads, cores, CPUs)
56-
/// Defaults to using the maximum execution resources in the system, but
57-
/// accounting for the affinity mask.
58-
ThreadPool(ThreadPoolStrategy S = hardware_concurrency());
56+
/// Destroying the pool will drain the pending tasks and wait. The current
57+
/// thread may participate in the execution of the pending tasks.
58+
virtual ~ThreadPoolInterface();
5959

60-
/// Blocking destructor: the pool will wait for all the threads to complete.
61-
~ThreadPool();
60+
/// Blocking wait for all the threads to complete and the queue to be empty.
61+
/// It is an error to try to add new tasks while blocking on this call.
62+
/// Calling wait() from a task would deadlock waiting for itself.
63+
virtual void wait() = 0;
64+
65+
/// Blocking wait for only all the threads in the given group to complete.
66+
/// It is possible to wait even inside a task, but waiting (directly or
67+
/// indirectly) on itself will deadlock. If called from a task running on a
68+
/// worker thread, the call may process pending tasks while waiting in order
69+
/// not to waste the thread.
70+
virtual void wait(ThreadPoolTaskGroup &Group) = 0;
71+
72+
/// Returns the maximum number of worker this pool can eventually grow to.
73+
virtual unsigned getMaxConcurrency() const = 0;
6274

6375
/// Asynchronous submission of a task to the pool. The returned future can be
6476
/// used to wait for the task to finish and is *non-blocking* on destruction.
@@ -92,30 +104,32 @@ class ThreadPool {
92104
&Group);
93105
}
94106

95-
/// Blocking wait for all the threads to complete and the queue to be empty.
96-
/// It is an error to try to add new tasks while blocking on this call.
97-
/// Calling wait() from a task would deadlock waiting for itself.
98-
void wait();
107+
private:
108+
/// Asynchronous submission of a task to the pool. The returned future can be
109+
/// used to wait for the task to finish and is *non-blocking* on destruction.
110+
template <typename ResTy>
111+
std::shared_future<ResTy> asyncImpl(std::function<ResTy()> Task,
112+
ThreadPoolTaskGroup *Group) {
99113

100-
/// Blocking wait for only all the threads in the given group to complete.
101-
/// It is possible to wait even inside a task, but waiting (directly or
102-
/// indirectly) on itself will deadlock. If called from a task running on a
103-
/// worker thread, the call may process pending tasks while waiting in order
104-
/// not to waste the thread.
105-
void wait(ThreadPoolTaskGroup &Group);
114+
#if LLVM_ENABLE_THREADS
115+
/// Wrap the Task in a std::function<void()> that sets the result of the
116+
/// corresponding future.
117+
auto R = createTaskAndFuture(Task);
106118

107-
// Returns the maximum number of worker threads in the pool, not the current
108-
// number of threads!
109-
unsigned getMaxConcurrency() const { return MaxThreadCount; }
119+
asyncEnqueue(std::move(R.first), Group);
120+
return R.second.share();
110121

111-
// TODO: misleading legacy name warning!
112-
LLVM_DEPRECATED("Use getMaxConcurrency instead", "getMaxConcurrency")
113-
unsigned getThreadCount() const { return MaxThreadCount; }
122+
#else // LLVM_ENABLE_THREADS Disabled
114123

115-
/// Returns true if the current thread is a worker thread of this thread pool.
116-
bool isWorkerThread() const;
124+
// Get a Future with launch::deferred execution using std::async
125+
auto Future = std::async(std::launch::deferred, std::move(Task)).share();
126+
// Wrap the future so that both ThreadPool::wait() can operate and the
127+
// returned future can be sync'ed on.
128+
asyncEnqueue([Future]() { Future.get(); }, Group);
129+
return Future;
130+
#endif
131+
}
117132

118-
private:
119133
/// Helpers to create a promise and a callable wrapper of \p Task that sets
120134
/// the result of the promise. Returns the callable and a future to access the
121135
/// result.
@@ -140,50 +154,74 @@ class ThreadPool {
140154
},
141155
std::move(F)};
142156
}
157+
};
158+
159+
/// A ThreadPool implementation using std::threads.
160+
///
161+
/// The pool keeps a vector of threads alive, waiting on a condition variable
162+
/// for some work to become available.
163+
class ThreadPool : public ThreadPoolInterface {
164+
public:
165+
/// Construct a pool using the hardware strategy \p S for mapping hardware
166+
/// execution resources (threads, cores, CPUs)
167+
/// Defaults to using the maximum execution resources in the system, but
168+
/// accounting for the affinity mask.
169+
ThreadPool(ThreadPoolStrategy S = hardware_concurrency());
170+
171+
/// Blocking destructor: the pool will wait for all the threads to complete.
172+
~ThreadPool() override;
173+
174+
/// Blocking wait for all the threads to complete and the queue to be empty.
175+
/// It is an error to try to add new tasks while blocking on this call.
176+
/// Calling wait() from a task would deadlock waiting for itself.
177+
void wait() override;
178+
179+
/// Blocking wait for only all the threads in the given group to complete.
180+
/// It is possible to wait even inside a task, but waiting (directly or
181+
/// indirectly) on itself will deadlock. If called from a task running on a
182+
/// worker thread, the call may process pending tasks while waiting in order
183+
/// not to waste the thread.
184+
void wait(ThreadPoolTaskGroup &Group) override;
143185

186+
/// Returns the maximum number of worker threads in the pool, not the current
187+
/// number of threads!
188+
unsigned getMaxConcurrency() const override { return MaxThreadCount; }
189+
190+
// TODO: Remove, misleading legacy name warning!
191+
LLVM_DEPRECATED("Use getMaxConcurrency instead", "getMaxConcurrency")
192+
unsigned getThreadCount() const { return MaxThreadCount; }
193+
194+
/// Returns true if the current thread is a worker thread of this thread pool.
195+
bool isWorkerThread() const;
196+
197+
private:
144198
/// Returns true if all tasks in the given group have finished (nullptr means
145199
/// all tasks regardless of their group). QueueLock must be locked.
146200
bool workCompletedUnlocked(ThreadPoolTaskGroup *Group) const;
147201

148202
/// Asynchronous submission of a task to the pool. The returned future can be
149203
/// used to wait for the task to finish and is *non-blocking* on destruction.
150-
template <typename ResTy>
151-
std::shared_future<ResTy> asyncImpl(std::function<ResTy()> Task,
152-
ThreadPoolTaskGroup *Group) {
153-
204+
void asyncEnqueue(std::function<void()> Task,
205+
ThreadPoolTaskGroup *Group) override {
154206
#if LLVM_ENABLE_THREADS
155-
/// Wrap the Task in a std::function<void()> that sets the result of the
156-
/// corresponding future.
157-
auto R = createTaskAndFuture(Task);
158-
159207
int requestedThreads;
160208
{
161209
// Lock the queue and push the new task
162210
std::unique_lock<std::mutex> LockGuard(QueueLock);
163211

164212
// Don't allow enqueueing after disabling the pool
165213
assert(EnableFlag && "Queuing a thread during ThreadPool destruction");
166-
Tasks.emplace_back(std::make_pair(std::move(R.first), Group));
214+
Tasks.emplace_back(std::make_pair(std::move(Task), Group));
167215
requestedThreads = ActiveThreads + Tasks.size();
168216
}
169217
QueueCondition.notify_one();
170218
grow(requestedThreads);
171-
return R.second.share();
172-
173-
#else // LLVM_ENABLE_THREADS Disabled
174-
175-
// Get a Future with launch::deferred execution using std::async
176-
auto Future = std::async(std::launch::deferred, std::move(Task)).share();
177-
// Wrap the future so that both ThreadPool::wait() can operate and the
178-
// returned future can be sync'ed on.
179-
Tasks.emplace_back(std::make_pair([Future]() { Future.get(); }, Group));
180-
return Future;
181219
#endif
182220
}
183221

184222
#if LLVM_ENABLE_THREADS
185-
// Grow to ensure that we have at least `requested` Threads, but do not go
186-
// over MaxThreadCount.
223+
/// Grow to ensure that we have at least `requested` Threads, but do not go
224+
/// over MaxThreadCount.
187225
void grow(int requested);
188226

189227
void processTasks(ThreadPoolTaskGroup *WaitingForGroup);
@@ -227,7 +265,7 @@ class ThreadPool {
227265
class ThreadPoolTaskGroup {
228266
public:
229267
/// The ThreadPool argument is the thread pool to forward calls to.
230-
ThreadPoolTaskGroup(ThreadPool &Pool) : Pool(Pool) {}
268+
ThreadPoolTaskGroup(ThreadPoolInterface &Pool) : Pool(Pool) {}
231269

232270
/// Blocking destructor: will wait for all the tasks in the group to complete
233271
/// by calling ThreadPool::wait().
@@ -244,7 +282,7 @@ class ThreadPoolTaskGroup {
244282
void wait() { Pool.wait(*this); }
245283

246284
private:
247-
ThreadPool &Pool;
285+
ThreadPoolInterface &Pool;
248286
};
249287

250288
} // namespace llvm

llvm/lib/Support/ThreadPool.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323

2424
using namespace llvm;
2525

26+
ThreadPoolInterface::~ThreadPoolInterface() = default;
27+
2628
#if LLVM_ENABLE_THREADS
2729

2830
// A note on thread groups: Tasks are by default in no group (represented

mlir/include/mlir/IR/MLIRContext.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
#include <vector>
1818

1919
namespace llvm {
20-
class ThreadPool;
20+
class ThreadPoolInterface;
2121
} // namespace llvm
2222

2323
namespace mlir {
@@ -162,7 +162,7 @@ class MLIRContext {
162162
/// The command line debugging flag `--mlir-disable-threading` will still
163163
/// prevent threading from being enabled and threading won't be enabled after
164164
/// this call in this case.
165-
void setThreadPool(llvm::ThreadPool &pool);
165+
void setThreadPool(llvm::ThreadPoolInterface &pool);
166166

167167
/// Return the number of threads used by the thread pool in this context. The
168168
/// number of computed hardware threads can change over the lifetime of a
@@ -175,7 +175,7 @@ class MLIRContext {
175175
/// multithreading be enabled within the context, and should generally not be
176176
/// used directly. Users should instead prefer the threading utilities within
177177
/// Threading.h.
178-
llvm::ThreadPool &getThreadPool();
178+
llvm::ThreadPoolInterface &getThreadPool();
179179

180180
/// Return true if we should attach the operation to diagnostics emitted via
181181
/// Operation::emit.

mlir/include/mlir/IR/Threading.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ LogicalResult failableParallelForEach(MLIRContext *context, IteratorT begin,
6666
};
6767

6868
// Otherwise, process the elements in parallel.
69-
llvm::ThreadPool &threadPool = context->getThreadPool();
69+
llvm::ThreadPoolInterface &threadPool = context->getThreadPool();
7070
llvm::ThreadPoolTaskGroup tasksGroup(threadPool);
7171
size_t numActions = std::min(numElements, threadPool.getMaxConcurrency());
7272
for (unsigned i = 0; i < numActions; ++i)

mlir/lib/CAPI/IR/IR.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "mlir/IR/Visitors.h"
2929
#include "mlir/Interfaces/InferTypeOpInterface.h"
3030
#include "mlir/Parser/Parser.h"
31+
#include "llvm/Support/ThreadPool.h"
3132

3233
#include <cstddef>
3334
#include <memory>

mlir/lib/IR/MLIRContext.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -170,11 +170,11 @@ class MLIRContextImpl {
170170
/// It can't be nullptr when multi-threading is enabled. Otherwise if
171171
/// multi-threading is disabled, and the threadpool wasn't externally provided
172172
/// using `setThreadPool`, this will be nullptr.
173-
llvm::ThreadPool *threadPool = nullptr;
173+
llvm::ThreadPoolInterface *threadPool = nullptr;
174174

175175
/// In case where the thread pool is owned by the context, this ensures
176176
/// destruction with the context.
177-
std::unique_ptr<llvm::ThreadPool> ownedThreadPool;
177+
std::unique_ptr<llvm::ThreadPoolInterface> ownedThreadPool;
178178

179179
/// An allocator used for AbstractAttribute and AbstractType objects.
180180
llvm::BumpPtrAllocator abstractDialectSymbolAllocator;
@@ -626,7 +626,7 @@ void MLIRContext::disableMultithreading(bool disable) {
626626
}
627627
}
628628

629-
void MLIRContext::setThreadPool(llvm::ThreadPool &pool) {
629+
void MLIRContext::setThreadPool(llvm::ThreadPoolInterface &pool) {
630630
assert(!isMultithreadingEnabled() &&
631631
"expected multi-threading to be disabled when setting a ThreadPool");
632632
impl->threadPool = &pool;
@@ -644,7 +644,7 @@ unsigned MLIRContext::getNumThreads() {
644644
return 1;
645645
}
646646

647-
llvm::ThreadPool &MLIRContext::getThreadPool() {
647+
llvm::ThreadPoolInterface &MLIRContext::getThreadPool() {
648648
assert(isMultithreadingEnabled() &&
649649
"expected multi-threading to be enabled within the context");
650650
assert(impl->threadPool &&

mlir/lib/Tools/mlir-opt/MlirOptMain.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ static LogicalResult processBuffer(raw_ostream &os,
429429
std::unique_ptr<MemoryBuffer> ownedBuffer,
430430
const MlirOptMainConfig &config,
431431
DialectRegistry &registry,
432-
llvm::ThreadPool *threadPool) {
432+
llvm::ThreadPoolInterface *threadPool) {
433433
// Tell sourceMgr about this buffer, which is what the parser will pick up.
434434
auto sourceMgr = std::make_shared<SourceMgr>();
435435
sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
@@ -515,7 +515,7 @@ LogicalResult mlir::MlirOptMain(llvm::raw_ostream &outputStream,
515515
// up into small pieces and checks each independently.
516516
// We use an explicit threadpool to avoid creating and joining/destroying
517517
// threads for each of the split.
518-
ThreadPool *threadPool = nullptr;
518+
ThreadPoolInterface *threadPool = nullptr;
519519

520520
// Create a temporary context for the sake of checking if
521521
// --mlir-disable-threading was passed on the command line.

0 commit comments

Comments
 (0)