Skip to content

Commit 379a094

Browse files
[SYCL] Fix static destruction order issue in OpenCL extension fptr cache (#9254)
OpenCL plugin uses several static maps for caching extension function pointers retrieved from the backend. If a user application has a static variable that indirectly calls one of those functions in its destructor, the corresponding map might have already been destroyed. This patch fixes the problem by tying the lifetime of those maps to piTearDown.
1 parent 29e629e commit 379a094

File tree

1 file changed

+125
-71
lines changed

1 file changed

+125
-71
lines changed

sycl/plugins/opencl/pi_opencl.cpp

Lines changed: 125 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include <limits>
2828
#include <map>
2929
#include <memory>
30+
#include <mutex>
3031
#include <sstream>
3132
#include <string>
3233
#include <string_view>
@@ -186,16 +187,70 @@ static cl_int checkDeviceExtensions(cl_device_id dev,
186187
return ret_err;
187188
}
188189

190+
typedef CL_API_ENTRY cl_int(CL_API_CALL *clGetDeviceFunctionPointer_fn)(
191+
cl_device_id device, cl_program program, const char *FuncName,
192+
cl_ulong *ret_ptr);
193+
194+
typedef CL_API_ENTRY cl_int(CL_API_CALL *clEnqueueWriteGlobalVariable_fn)(
195+
cl_command_queue, cl_program, const char *, cl_bool, size_t, size_t,
196+
const void *, cl_uint, const cl_event *, cl_event *);
197+
198+
typedef CL_API_ENTRY cl_int(CL_API_CALL *clEnqueueReadGlobalVariable_fn)(
199+
cl_command_queue, cl_program, const char *, cl_bool, size_t, size_t, void *,
200+
cl_uint, const cl_event *, cl_event *);
201+
202+
typedef CL_API_ENTRY cl_int(CL_API_CALL *clSetProgramSpecializationConstant_fn)(
203+
cl_program program, cl_uint spec_id, size_t spec_size,
204+
const void *spec_value);
205+
206+
template <typename T> struct FuncPtrCache {
207+
std::map<pi_context, T> Map;
208+
std::mutex Mutex;
209+
};
210+
211+
// FIXME: There's currently no mechanism for cleaning up this cache, meaning
212+
// that it is invalidated whenever a context is destroyed. This could lead to
213+
// reusing an invalid function pointer if another context happens to have the
214+
// same native handle.
215+
struct ExtFuncPtrCacheT {
216+
FuncPtrCache<clHostMemAllocINTEL_fn> clHostMemAllocINTELCache;
217+
FuncPtrCache<clDeviceMemAllocINTEL_fn> clDeviceMemAllocINTELCache;
218+
FuncPtrCache<clSharedMemAllocINTEL_fn> clSharedMemAllocINTELCache;
219+
FuncPtrCache<clGetDeviceFunctionPointer_fn> clGetDeviceFunctionPointerCache;
220+
FuncPtrCache<clCreateBufferWithPropertiesINTEL_fn>
221+
clCreateBufferWithPropertiesINTELCache;
222+
FuncPtrCache<clMemBlockingFreeINTEL_fn> clMemBlockingFreeINTELCache;
223+
FuncPtrCache<clSetKernelArgMemPointerINTEL_fn>
224+
clSetKernelArgMemPointerINTELCache;
225+
FuncPtrCache<clEnqueueMemsetINTEL_fn> clEnqueueMemsetINTELCache;
226+
FuncPtrCache<clEnqueueMemcpyINTEL_fn> clEnqueueMemcpyINTELCache;
227+
FuncPtrCache<clGetMemAllocInfoINTEL_fn> clGetMemAllocInfoINTELCache;
228+
FuncPtrCache<clEnqueueWriteGlobalVariable_fn>
229+
clEnqueueWriteGlobalVariableCache;
230+
FuncPtrCache<clEnqueueReadGlobalVariable_fn> clEnqueueReadGlobalVariableCache;
231+
FuncPtrCache<clEnqueueReadHostPipeINTEL_fn> clEnqueueReadHostPipeINTELCache;
232+
FuncPtrCache<clEnqueueWriteHostPipeINTEL_fn> clEnqueueWriteHostPipeINTELCache;
233+
FuncPtrCache<clSetProgramSpecializationConstant_fn>
234+
clSetProgramSpecializationConstantCache;
235+
};
236+
// A raw pointer is used here since the lifetime of this map has to be tied to
237+
// piTearDown to avoid issues with static destruction order (a user application
238+
// might have static objects that indirectly access this cache in their
239+
// destructor).
240+
static ExtFuncPtrCacheT *ExtFuncPtrCache = new ExtFuncPtrCacheT();
241+
189242
// USM helper function to get an extension function pointer
190-
template <const char *FuncName, typename T>
191-
static pi_result getExtFuncFromContext(pi_context context, T *fptr) {
243+
template <typename T>
244+
static pi_result getExtFuncFromContext(pi_context context,
245+
FuncPtrCache<T> &FPtrCache,
246+
const char *FuncName, T *fptr) {
192247
// TODO
193248
// Potentially redo caching as PI interface changes.
194-
thread_local static std::map<pi_context, T> FuncPtrs;
195-
196249
// if cached, return cached FuncPtr
197-
auto It = FuncPtrs.find(context);
198-
if (It != FuncPtrs.end()) {
250+
std::lock_guard<std::mutex> CacheLock{FPtrCache.Mutex};
251+
std::map<pi_context, T> &FPtrMap = FPtrCache.Map;
252+
auto It = FPtrMap.find(context);
253+
if (It != FPtrMap.end()) {
199254
auto F = It->second;
200255
// if cached that extension is not available return nullptr and
201256
// PI_ERROR_INVALID_VALUE
@@ -234,12 +289,12 @@ static pi_result getExtFuncFromContext(pi_context context, T *fptr) {
234289

235290
if (!FuncPtr) {
236291
// Cache that the extension is not available
237-
FuncPtrs[context] = nullptr;
292+
FPtrMap[context] = nullptr;
238293
return PI_ERROR_INVALID_VALUE;
239294
}
240295

241296
*fptr = FuncPtr;
242-
FuncPtrs[context] = FuncPtr;
297+
FPtrMap[context] = FuncPtr;
243298

244299
return cast<pi_result>(ret_err);
245300
}
@@ -262,24 +317,27 @@ static pi_result USMSetIndirectAccess(pi_kernel kernel) {
262317
return cast<pi_result>(CLErr);
263318
}
264319

265-
getExtFuncFromContext<clHostMemAllocName, clHostMemAllocINTEL_fn>(
266-
cast<pi_context>(CLContext), &HFunc);
320+
getExtFuncFromContext<clHostMemAllocINTEL_fn>(
321+
cast<pi_context>(CLContext), ExtFuncPtrCache->clHostMemAllocINTELCache,
322+
clHostMemAllocName, &HFunc);
267323
if (HFunc) {
268324
clSetKernelExecInfo(cast<cl_kernel>(kernel),
269325
CL_KERNEL_EXEC_INFO_INDIRECT_HOST_ACCESS_INTEL,
270326
sizeof(cl_bool), &TrueVal);
271327
}
272328

273-
getExtFuncFromContext<clDeviceMemAllocName, clDeviceMemAllocINTEL_fn>(
274-
cast<pi_context>(CLContext), &DFunc);
329+
getExtFuncFromContext<clDeviceMemAllocINTEL_fn>(
330+
cast<pi_context>(CLContext), ExtFuncPtrCache->clDeviceMemAllocINTELCache,
331+
clDeviceMemAllocName, &DFunc);
275332
if (DFunc) {
276333
clSetKernelExecInfo(cast<cl_kernel>(kernel),
277334
CL_KERNEL_EXEC_INFO_INDIRECT_DEVICE_ACCESS_INTEL,
278335
sizeof(cl_bool), &TrueVal);
279336
}
280337

281-
getExtFuncFromContext<clSharedMemAllocName, clSharedMemAllocINTEL_fn>(
282-
cast<pi_context>(CLContext), &SFunc);
338+
getExtFuncFromContext<clSharedMemAllocINTEL_fn>(
339+
cast<pi_context>(CLContext), ExtFuncPtrCache->clSharedMemAllocINTELCache,
340+
clSharedMemAllocName, &SFunc);
283341
if (SFunc) {
284342
clSetKernelExecInfo(cast<cl_kernel>(kernel),
285343
CL_KERNEL_EXEC_INFO_INDIRECT_SHARED_ACCESS_INTEL,
@@ -1100,9 +1158,6 @@ static bool is_in_separated_string(const std::string &str, char delimiter,
11001158
return false;
11011159
}
11021160

1103-
typedef CL_API_ENTRY cl_int(CL_API_CALL *clGetDeviceFunctionPointer_fn)(
1104-
cl_device_id device, cl_program program, const char *FuncName,
1105-
cl_ulong *ret_ptr);
11061161
pi_result piextGetDeviceFunctionPointer(pi_device device, pi_program program,
11071162
const char *func_name,
11081163
pi_uint64 *function_pointer_ret) {
@@ -1116,9 +1171,10 @@ pi_result piextGetDeviceFunctionPointer(pi_device device, pi_program program,
11161171
return cast<pi_result>(ret_err);
11171172

11181173
clGetDeviceFunctionPointer_fn FuncT = nullptr;
1119-
ret_err = getExtFuncFromContext<clGetDeviceFunctionPointerName,
1120-
clGetDeviceFunctionPointer_fn>(
1121-
cast<pi_context>(CLContext), &FuncT);
1174+
ret_err = getExtFuncFromContext<clGetDeviceFunctionPointer_fn>(
1175+
cast<pi_context>(CLContext),
1176+
ExtFuncPtrCache->clGetDeviceFunctionPointerCache,
1177+
clGetDeviceFunctionPointerName, &FuncT);
11221178

11231179
pi_result pi_ret_err = PI_SUCCESS;
11241180

@@ -1235,9 +1291,9 @@ pi_result piMemBufferCreate(pi_context context, pi_mem_flags flags, size_t size,
12351291
// ignore unsupported
12361292
clCreateBufferWithPropertiesINTEL_fn FuncPtr = nullptr;
12371293
// First we need to look up the function pointer
1238-
ret_err = getExtFuncFromContext<clCreateBufferWithPropertiesName,
1239-
clCreateBufferWithPropertiesINTEL_fn>(
1240-
context, &FuncPtr);
1294+
ret_err = getExtFuncFromContext<clCreateBufferWithPropertiesINTEL_fn>(
1295+
context, ExtFuncPtrCache->clCreateBufferWithPropertiesINTELCache,
1296+
clCreateBufferWithPropertiesName, &FuncPtr);
12411297
if (FuncPtr) {
12421298
*ret_mem = cast<pi_mem>(FuncPtr(cast<cl_context>(context), properties,
12431299
cast<cl_mem_flags>(flags), size, host_ptr,
@@ -1516,8 +1572,9 @@ pi_result piextUSMHostAlloc(void **result_ptr, pi_context context,
15161572

15171573
// First we need to look up the function pointer
15181574
clHostMemAllocINTEL_fn FuncPtr = nullptr;
1519-
RetVal = getExtFuncFromContext<clHostMemAllocName, clHostMemAllocINTEL_fn>(
1520-
context, &FuncPtr);
1575+
RetVal = getExtFuncFromContext<clHostMemAllocINTEL_fn>(
1576+
context, ExtFuncPtrCache->clHostMemAllocINTELCache, clHostMemAllocName,
1577+
&FuncPtr);
15211578

15221579
if (FuncPtr) {
15231580
Ptr = FuncPtr(cast<cl_context>(context),
@@ -1553,9 +1610,9 @@ pi_result piextUSMDeviceAlloc(void **result_ptr, pi_context context,
15531610

15541611
// First we need to look up the function pointer
15551612
clDeviceMemAllocINTEL_fn FuncPtr = nullptr;
1556-
RetVal =
1557-
getExtFuncFromContext<clDeviceMemAllocName, clDeviceMemAllocINTEL_fn>(
1558-
context, &FuncPtr);
1613+
RetVal = getExtFuncFromContext<clDeviceMemAllocINTEL_fn>(
1614+
context, ExtFuncPtrCache->clDeviceMemAllocINTELCache,
1615+
clDeviceMemAllocName, &FuncPtr);
15591616

15601617
if (FuncPtr) {
15611618
Ptr = FuncPtr(cast<cl_context>(context), cast<cl_device_id>(device),
@@ -1591,9 +1648,9 @@ pi_result piextUSMSharedAlloc(void **result_ptr, pi_context context,
15911648

15921649
// First we need to look up the function pointer
15931650
clSharedMemAllocINTEL_fn FuncPtr = nullptr;
1594-
RetVal =
1595-
getExtFuncFromContext<clSharedMemAllocName, clSharedMemAllocINTEL_fn>(
1596-
context, &FuncPtr);
1651+
RetVal = getExtFuncFromContext<clSharedMemAllocINTEL_fn>(
1652+
context, ExtFuncPtrCache->clSharedMemAllocINTELCache,
1653+
clSharedMemAllocName, &FuncPtr);
15971654

15981655
if (FuncPtr) {
15991656
Ptr = FuncPtr(cast<cl_context>(context), cast<cl_device_id>(device),
@@ -1619,9 +1676,9 @@ pi_result piextUSMFree(pi_context context, void *ptr) {
16191676
clMemBlockingFreeINTEL_fn FuncPtr = nullptr;
16201677

16211678
pi_result RetVal = PI_ERROR_INVALID_OPERATION;
1622-
RetVal =
1623-
getExtFuncFromContext<clMemBlockingFreeName, clMemBlockingFreeINTEL_fn>(
1624-
context, &FuncPtr);
1679+
RetVal = getExtFuncFromContext<clMemBlockingFreeINTEL_fn>(
1680+
context, ExtFuncPtrCache->clMemBlockingFreeINTELCache,
1681+
clMemBlockingFreeName, &FuncPtr);
16251682

16261683
if (FuncPtr) {
16271684
RetVal = cast<pi_result>(FuncPtr(cast<cl_context>(context), ptr));
@@ -1652,9 +1709,10 @@ pi_result piextKernelSetArgPointer(pi_kernel kernel, pi_uint32 arg_index,
16521709
}
16531710

16541711
clSetKernelArgMemPointerINTEL_fn FuncPtr = nullptr;
1655-
pi_result RetVal = getExtFuncFromContext<clSetKernelArgMemPointerName,
1656-
clSetKernelArgMemPointerINTEL_fn>(
1657-
cast<pi_context>(CLContext), &FuncPtr);
1712+
pi_result RetVal = getExtFuncFromContext<clSetKernelArgMemPointerINTEL_fn>(
1713+
cast<pi_context>(CLContext),
1714+
ExtFuncPtrCache->clSetKernelArgMemPointerINTELCache,
1715+
clSetKernelArgMemPointerName, &FuncPtr);
16581716

16591717
if (FuncPtr) {
16601718
// OpenCL passes pointers by value not by reference
@@ -1693,9 +1751,9 @@ pi_result piextUSMEnqueueMemset(pi_queue queue, void *ptr, pi_int32 value,
16931751
}
16941752

16951753
clEnqueueMemsetINTEL_fn FuncPtr = nullptr;
1696-
pi_result RetVal =
1697-
getExtFuncFromContext<clEnqueueMemsetName, clEnqueueMemsetINTEL_fn>(
1698-
cast<pi_context>(CLContext), &FuncPtr);
1754+
pi_result RetVal = getExtFuncFromContext<clEnqueueMemsetINTEL_fn>(
1755+
cast<pi_context>(CLContext), ExtFuncPtrCache->clEnqueueMemsetINTELCache,
1756+
clEnqueueMemsetName, &FuncPtr);
16991757

17001758
if (FuncPtr) {
17011759
RetVal = cast<pi_result>(FuncPtr(cast<cl_command_queue>(queue), ptr, value,
@@ -1733,9 +1791,9 @@ pi_result piextUSMEnqueueMemcpy(pi_queue queue, pi_bool blocking, void *dst_ptr,
17331791
}
17341792

17351793
clEnqueueMemcpyINTEL_fn FuncPtr = nullptr;
1736-
pi_result RetVal =
1737-
getExtFuncFromContext<clEnqueueMemcpyName, clEnqueueMemcpyINTEL_fn>(
1738-
cast<pi_context>(CLContext), &FuncPtr);
1794+
pi_result RetVal = getExtFuncFromContext<clEnqueueMemcpyINTEL_fn>(
1795+
cast<pi_context>(CLContext), ExtFuncPtrCache->clEnqueueMemcpyINTELCache,
1796+
clEnqueueMemcpyName, &FuncPtr);
17391797

17401798
if (FuncPtr) {
17411799
RetVal = cast<pi_result>(
@@ -1959,9 +2017,9 @@ pi_result piextUSMGetMemAllocInfo(pi_context context, const void *ptr,
19592017
size_t *param_value_size_ret) {
19602018

19612019
clGetMemAllocInfoINTEL_fn FuncPtr = nullptr;
1962-
pi_result RetVal =
1963-
getExtFuncFromContext<clGetMemAllocInfoName, clGetMemAllocInfoINTEL_fn>(
1964-
context, &FuncPtr);
2020+
pi_result RetVal = getExtFuncFromContext<clGetMemAllocInfoINTEL_fn>(
2021+
context, ExtFuncPtrCache->clGetMemAllocInfoINTELCache,
2022+
clGetMemAllocInfoName, &FuncPtr);
19652023

19662024
if (FuncPtr) {
19672025
RetVal = cast<pi_result>(FuncPtr(cast<cl_context>(context), ptr, param_name,
@@ -1972,14 +2030,6 @@ pi_result piextUSMGetMemAllocInfo(pi_context context, const void *ptr,
19722030
return RetVal;
19732031
}
19742032

1975-
typedef CL_API_ENTRY cl_int(CL_API_CALL *clEnqueueWriteGlobalVariable_fn)(
1976-
cl_command_queue, cl_program, const char *, cl_bool, size_t, size_t,
1977-
const void *, cl_uint, const cl_event *, cl_event *);
1978-
1979-
typedef CL_API_ENTRY cl_int(CL_API_CALL *clEnqueueReadGlobalVariable_fn)(
1980-
cl_command_queue, cl_program, const char *, cl_bool, size_t, size_t, void *,
1981-
cl_uint, const cl_event *, cl_event *);
1982-
19832033
/// API for writing data from host to a device global variable.
19842034
///
19852035
/// \param queue is the queue
@@ -2007,8 +2057,9 @@ pi_result piextEnqueueDeviceGlobalVariableWrite(
20072057
return cast<pi_result>(Res);
20082058

20092059
clEnqueueWriteGlobalVariable_fn F = nullptr;
2010-
Res = getExtFuncFromContext<clEnqueueWriteGlobalVariableName, decltype(F)>(
2011-
cast<pi_context>(Ctx), &F);
2060+
Res = getExtFuncFromContext<decltype(F)>(
2061+
cast<pi_context>(Ctx), ExtFuncPtrCache->clEnqueueWriteGlobalVariableCache,
2062+
clEnqueueWriteGlobalVariableName, &F);
20122063

20132064
if (!F || Res != CL_SUCCESS)
20142065
return PI_ERROR_INVALID_OPERATION;
@@ -2044,8 +2095,9 @@ pi_result piextEnqueueDeviceGlobalVariableRead(
20442095
return cast<pi_result>(Res);
20452096

20462097
clEnqueueReadGlobalVariable_fn F = nullptr;
2047-
Res = getExtFuncFromContext<clEnqueueReadGlobalVariableName, decltype(F)>(
2048-
cast<pi_context>(Ctx), &F);
2098+
Res = getExtFuncFromContext<decltype(F)>(
2099+
cast<pi_context>(Ctx), ExtFuncPtrCache->clEnqueueReadGlobalVariableCache,
2100+
clEnqueueReadGlobalVariableName, &F);
20492101

20502102
if (!F || Res != CL_SUCCESS)
20512103
return PI_ERROR_INVALID_OPERATION;
@@ -2070,9 +2122,10 @@ pi_result piextEnqueueReadHostPipe(pi_queue queue, pi_program program,
20702122
}
20712123

20722124
clEnqueueReadHostPipeINTEL_fn FuncPtr = nullptr;
2073-
pi_result RetVal = getExtFuncFromContext<clEnqueueReadHostPipeName,
2074-
clEnqueueReadHostPipeINTEL_fn>(
2075-
cast<pi_context>(CLContext), &FuncPtr);
2125+
pi_result RetVal = getExtFuncFromContext<clEnqueueReadHostPipeINTEL_fn>(
2126+
cast<pi_context>(CLContext),
2127+
ExtFuncPtrCache->clEnqueueReadHostPipeINTELCache,
2128+
clEnqueueReadHostPipeName, &FuncPtr);
20762129

20772130
if (FuncPtr) {
20782131
RetVal = cast<pi_result>(FuncPtr(
@@ -2099,9 +2152,10 @@ pi_result piextEnqueueWriteHostPipe(pi_queue queue, pi_program program,
20992152
}
21002153

21012154
clEnqueueWriteHostPipeINTEL_fn FuncPtr = nullptr;
2102-
pi_result RetVal = getExtFuncFromContext<clEnqueueWriteHostPipeName,
2103-
clEnqueueWriteHostPipeINTEL_fn>(
2104-
cast<pi_context>(CLContext), &FuncPtr);
2155+
pi_result RetVal = getExtFuncFromContext<clEnqueueWriteHostPipeINTEL_fn>(
2156+
cast<pi_context>(CLContext),
2157+
ExtFuncPtrCache->clEnqueueWriteHostPipeINTELCache,
2158+
clEnqueueWriteHostPipeName, &FuncPtr);
21052159

21062160
if (FuncPtr) {
21072161
RetVal = cast<pi_result>(FuncPtr(
@@ -2136,10 +2190,6 @@ pi_result piKernelSetExecInfo(pi_kernel kernel, pi_kernel_exec_info param_name,
21362190
}
21372191
}
21382192

2139-
typedef CL_API_ENTRY cl_int(CL_API_CALL *clSetProgramSpecializationConstant_fn)(
2140-
cl_program program, cl_uint spec_id, size_t spec_size,
2141-
const void *spec_value);
2142-
21432193
pi_result piextProgramSetSpecializationConstant(pi_program prog,
21442194
pi_uint32 spec_id,
21452195
size_t spec_size,
@@ -2154,8 +2204,10 @@ pi_result piextProgramSetSpecializationConstant(pi_program prog,
21542204
return cast<pi_result>(Res);
21552205

21562206
clSetProgramSpecializationConstant_fn F = nullptr;
2157-
Res = getExtFuncFromContext<clSetProgramSpecializationConstantName,
2158-
decltype(F)>(cast<pi_context>(Ctx), &F);
2207+
Res = getExtFuncFromContext<decltype(F)>(
2208+
cast<pi_context>(Ctx),
2209+
ExtFuncPtrCache->clSetProgramSpecializationConstantCache,
2210+
clSetProgramSpecializationConstantName, &F);
21592211

21602212
if (!F || Res != CL_SUCCESS)
21612213
return PI_ERROR_INVALID_OPERATION;
@@ -2223,9 +2275,11 @@ pi_result piextKernelGetNativeHandle(pi_kernel kernel,
22232275
// called safely. But this is not transitive. If the PI plugin in turn
22242276
// dynamically loaded a different DLL, that may have been unloaded.
22252277
// TODO: add a global variable lifetime management code here (see
2226-
// pi_level_zero.cpp for reference) Currently this is just a NOOP.
2278+
// pi_level_zero.cpp for reference).
22272279
pi_result piTearDown(void *PluginParameter) {
22282280
(void)PluginParameter;
2281+
delete ExtFuncPtrCache;
2282+
ExtFuncPtrCache = nullptr;
22292283
return PI_SUCCESS;
22302284
}
22312285

0 commit comments

Comments
 (0)