Skip to content

Commit 2caf555

Browse files
author
sergei
authored
[SYCL] Make initialization of host task thread pool thread-safe (#4985)
1 parent 2fa6376 commit 2caf555

File tree

8 files changed

+62
-46
lines changed

8 files changed

+62
-46
lines changed

sycl/source/detail/config.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,4 @@ CONFIG(SYCL_CACHE_MIN_DEVICE_IMAGE_SIZE, 16, __SYCL_CACHE_MIN_DEVICE_IMAGE_SIZE)
3535
CONFIG(SYCL_CACHE_MAX_DEVICE_IMAGE_SIZE, 16, __SYCL_CACHE_MAX_DEVICE_IMAGE_SIZE)
3636
CONFIG(INTEL_ENABLE_OFFLOAD_ANNOTATIONS, 1, __SYCL_INTEL_ENABLE_OFFLOAD_ANNOTATIONS)
3737
CONFIG(SYCL_ENABLE_DEFAULT_CONTEXTS, 1, __SYCL_ENABLE_DEFAULT_CONTEXTS)
38+
CONFIG(SYCL_QUEUE_THREAD_POOL_SIZE, 4, __SYCL_QUEUE_THREAD_POOL_SIZE)

sycl/source/detail/config.hpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <algorithm>
1919
#include <array>
2020
#include <cstdlib>
21+
#include <mutex>
2122
#include <string>
2223
#include <utility>
2324

@@ -320,6 +321,39 @@ template <> class SYCLConfig<SYCL_ENABLE_DEFAULT_CONTEXTS> {
320321
}
321322
};
322323

324+
template <> class SYCLConfig<SYCL_QUEUE_THREAD_POOL_SIZE> {
325+
using BaseT = SYCLConfigBase<SYCL_QUEUE_THREAD_POOL_SIZE>;
326+
327+
public:
328+
static int get() {
329+
static int Value = [] {
330+
const char *ValueStr = BaseT::getRawValue();
331+
332+
int Result = 1;
333+
334+
if (ValueStr)
335+
try {
336+
Result = std::stoi(ValueStr);
337+
} catch (...) {
338+
throw invalid_parameter_error(
339+
"Invalid value for SYCL_QUEUE_THREAD_POOL_SIZE environment "
340+
"variable: value should be a number",
341+
PI_INVALID_VALUE);
342+
}
343+
344+
if (Result < 1)
345+
throw invalid_parameter_error(
346+
"Invalid value for SYCL_QUEUE_THREAD_POOL_SIZE environment "
347+
"variable: value should be larger than zero",
348+
PI_INVALID_VALUE);
349+
350+
return Result;
351+
}();
352+
353+
return Value;
354+
}
355+
};
356+
323357
} // namespace detail
324358
} // namespace sycl
325359
} // __SYCL_INLINE_NAMESPACE(cl)

sycl/source/detail/global_handler.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@
99
#include <CL/sycl/detail/device_filter.hpp>
1010
#include <CL/sycl/detail/pi.hpp>
1111
#include <CL/sycl/detail/spinlock.hpp>
12+
#include <detail/config.hpp>
1213
#include <detail/global_handler.hpp>
1314
#include <detail/platform_impl.hpp>
1415
#include <detail/plugin.hpp>
1516
#include <detail/program_manager/program_manager.hpp>
1617
#include <detail/scheduler/scheduler.hpp>
18+
#include <detail/thread_pool.hpp>
1719
#include <detail/xpti_registry.hpp>
1820

1921
#ifdef _WIN32
@@ -89,6 +91,13 @@ std::mutex &GlobalHandler::getHandlerExtendedMembersMutex() {
8991
return getOrCreate(MHandlerExtendedMembersMutex);
9092
}
9193

94+
ThreadPool &GlobalHandler::getHostTaskThreadPool() {
95+
int Size = SYCLConfig<SYCL_QUEUE_THREAD_POOL_SIZE>::get();
96+
ThreadPool &TP = getOrCreate(MHostTaskThreadPool, Size);
97+
98+
return TP;
99+
}
100+
92101
void releaseDefaultContexts() {
93102
// Release shared-pointers to SYCL objects.
94103
#ifndef _WIN32
@@ -112,6 +121,11 @@ void GlobalHandler::registerDefaultContextReleaseHandler() {
112121
}
113122

114123
void shutdown() {
124+
// Ensure neither host task is working so that no default context is accessed
125+
// upon its release
126+
if (GlobalHandler::instance().MHostTaskThreadPool.Inst)
127+
GlobalHandler::instance().MHostTaskThreadPool.Inst->finishAndWait();
128+
115129
// If default contexts are requested after the first default contexts have
116130
// been released there may be a new default context. These must be released
117131
// prior to closing the plugins.

sycl/source/detail/global_handler.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class Sync;
2525
class plugin;
2626
class device_filter_list;
2727
class XPTIRegistry;
28+
class ThreadPool;
2829

2930
using PlatformImplPtr = std::shared_ptr<platform_impl>;
3031
using ContextImplPtr = std::shared_ptr<context_impl>;
@@ -67,6 +68,7 @@ class GlobalHandler {
6768
device_filter_list &getDeviceFilterList(const std::string &InitValue);
6869
XPTIRegistry &getXPTIRegistry();
6970
std::mutex &getHandlerExtendedMembersMutex();
71+
ThreadPool &getHostTaskThreadPool();
7072

7173
static void registerDefaultContextReleaseHandler();
7274

@@ -101,6 +103,8 @@ class GlobalHandler {
101103
InstWithLock<XPTIRegistry> MXPTIRegistry;
102104
// The mutex for synchronizing accesses to handlers extended members
103105
InstWithLock<std::mutex> MHandlerExtendedMembersMutex;
106+
// Thread pool for host task and event callbacks execution
107+
InstWithLock<ThreadPool> MHostTaskThreadPool;
104108
};
105109
} // namespace detail
106110
} // namespace sycl

sycl/source/detail/queue_impl.cpp

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -322,30 +322,6 @@ void queue_impl::wait(const detail::code_location &CodeLoc) {
322322
#endif
323323
}
324324

325-
void queue_impl::initHostTaskAndEventCallbackThreadPool() {
326-
if (MHostTaskThreadPool)
327-
return;
328-
329-
int Size = 1;
330-
331-
if (const char *Val = std::getenv("SYCL_QUEUE_THREAD_POOL_SIZE"))
332-
try {
333-
Size = std::stoi(Val);
334-
} catch (...) {
335-
throw invalid_parameter_error(
336-
"Invalid value for SYCL_QUEUE_THREAD_POOL_SIZE environment variable",
337-
PI_INVALID_VALUE);
338-
}
339-
340-
if (Size < 1)
341-
throw invalid_parameter_error(
342-
"Invalid value for SYCL_QUEUE_THREAD_POOL_SIZE environment variable",
343-
PI_INVALID_VALUE);
344-
345-
MHostTaskThreadPool.reset(new ThreadPool(Size));
346-
MHostTaskThreadPool->start();
347-
}
348-
349325
pi_native_handle queue_impl::getNative() const {
350326
const detail::plugin &Plugin = getPlugin();
351327
if (Plugin.getBackend() == backend::opencl)

sycl/source/detail/queue_impl.hpp

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include <detail/context_impl.hpp>
2525
#include <detail/device_impl.hpp>
2626
#include <detail/event_impl.hpp>
27+
#include <detail/global_handler.hpp>
2728
#include <detail/kernel_impl.hpp>
2829
#include <detail/plugin.hpp>
2930
#include <detail/scheduler/scheduler.hpp>
@@ -397,16 +398,7 @@ class queue_impl {
397398
}
398399

399400
ThreadPool &getThreadPool() {
400-
if (!MHostTaskThreadPool)
401-
initHostTaskAndEventCallbackThreadPool();
402-
403-
return *MHostTaskThreadPool;
404-
}
405-
406-
void stopThreadPool() {
407-
if (MHostTaskThreadPool) {
408-
MHostTaskThreadPool->finishAndWait();
409-
}
401+
return GlobalHandler::instance().getHostTaskThreadPool();
410402
}
411403

412404
/// Gets the native handle of the SYCL queue.
@@ -495,8 +487,6 @@ class queue_impl {
495487
void instrumentationEpilog(void *TelementryEvent, std::string &Name,
496488
int32_t StreamID, uint64_t IId);
497489

498-
void initHostTaskAndEventCallbackThreadPool();
499-
500490
/// queue_impl.addEvent tracks events with weak pointers
501491
/// but some events have no other owners. addSharedEvent()
502492
/// follows events with a shared pointer.
@@ -535,10 +525,6 @@ class queue_impl {
535525
// Assume OOO support by default.
536526
bool MSupportOOO = true;
537527

538-
// Thread pool for host task and event callbacks execution.
539-
// The thread pool is instantiated upon the very first call to getThreadPool()
540-
std::unique_ptr<ThreadPool> MHostTaskThreadPool;
541-
542528
// Buffer to store assert failure descriptor
543529
buffer<AssertHappened, 1> MAssertHappenedBuffer;
544530

sycl/source/detail/scheduler/scheduler.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,6 @@ Scheduler::Scheduler() {
364364
}
365365

366366
Scheduler::~Scheduler() {
367-
DefaultHostQueue->stopThreadPool();
368367
// By specification there are several possible sync points: buffer
369368
// destruction, wait() method of a queue or event. Stream doesn't introduce
370369
// any synchronization point. It is guaranteed that stream is flushed and

sycl/source/detail/thread_pool.hpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,6 @@ class ThreadPool {
5151
}
5252
}
5353

54-
public:
55-
ThreadPool(unsigned int ThreadCount = 1) : MThreadCount(ThreadCount) {}
56-
57-
~ThreadPool() { finishAndWait(); }
58-
5954
void start() {
6055
MLaunchedThreads.reserve(MThreadCount);
6156

@@ -65,6 +60,13 @@ class ThreadPool {
6560
MLaunchedThreads.emplace_back([this] { worker(); });
6661
}
6762

63+
public:
64+
ThreadPool(unsigned int ThreadCount = 1) : MThreadCount(ThreadCount) {
65+
start();
66+
}
67+
68+
~ThreadPool() { finishAndWait(); }
69+
6870
void finishAndWait() {
6971
MStop.store(true);
7072

0 commit comments

Comments
 (0)