Skip to content

Commit 2eed402

Browse files
author
sergei
authored
[SYCL] Clear extensions functions cache upon context release (#5282)
This is to eliminate reuse of invalid cached values after context being released. Signed-off-by: Sergey Kanaev <[email protected]>
1 parent 6b2635e commit 2eed402

File tree

3 files changed

+121
-14
lines changed

3 files changed

+121
-14
lines changed

sycl/plugins/opencl/ext_functions.inc

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#ifndef _EXT_FUNCTION_INTEL
2+
#error Undefined _EXT_FUNCTION_INTEL macro expansion
3+
#endif
4+
5+
#ifndef _EXT_FUNCTION
6+
#error Undefined _EXT_FUNCTION macro expansion
7+
#endif
8+
9+
_EXT_FUNCTION_INTEL(clHostMemAlloc)
10+
_EXT_FUNCTION_INTEL(clDeviceMemAlloc)
11+
_EXT_FUNCTION_INTEL(clSharedMemAlloc)
12+
_EXT_FUNCTION_INTEL(clCreateBufferWithProperties)
13+
_EXT_FUNCTION_INTEL(clMemBlockingFree)
14+
_EXT_FUNCTION_INTEL(clMemFree)
15+
_EXT_FUNCTION_INTEL(clSetKernelArgMemPointer)
16+
_EXT_FUNCTION_INTEL(clEnqueueMemset)
17+
_EXT_FUNCTION_INTEL(clEnqueueMemcpy)
18+
_EXT_FUNCTION_INTEL(clGetMemAllocInfo)
19+
_EXT_FUNCTION(clGetDeviceFunctionPointer)
20+
_EXT_FUNCTION(clSetProgramSpecializationConstant)

sycl/plugins/opencl/pi_opencl.cpp

Lines changed: 100 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
#include <iostream>
2626
#include <limits>
2727
#include <map>
28+
#include <memory>
29+
#include <mutex>
2830
#include <sstream>
2931
#include <string>
3032
#include <vector>
@@ -71,19 +73,93 @@ CONSTFIX char clGetDeviceFunctionPointerName[] =
7173

7274
#undef CONSTFIX
7375

76+
typedef CL_API_ENTRY cl_int(CL_API_CALL *clGetDeviceFunctionPointer_fn)(
77+
cl_device_id device, cl_program program, const char *FuncName,
78+
cl_ulong *ret_ptr);
79+
80+
typedef CL_API_ENTRY cl_int(CL_API_CALL *clSetProgramSpecializationConstant_fn)(
81+
cl_program program, cl_uint spec_id, size_t spec_size,
82+
const void *spec_value);
83+
84+
struct ExtFuncsPerContextT;
85+
86+
namespace detail {
87+
template <const char *FuncName, typename FuncT>
88+
std::pair<FuncT &, bool &> get(ExtFuncsPerContextT &);
89+
} // namespace detail
90+
91+
struct ExtFuncsPerContextT {
92+
#define _EXT_FUNCTION_INTEL(t_pfx) \
93+
t_pfx##INTEL_fn t_pfx##Func = nullptr; \
94+
bool t_pfx##Initialized = false;
95+
96+
#define _EXT_FUNCTION(t_pfx) \
97+
t_pfx##_fn t_pfx##Func = nullptr; \
98+
bool t_pfx##Initialized = false;
99+
100+
#include "ext_functions.inc"
101+
102+
#undef _EXT_FUNCTION
103+
#undef _EXT_FUNCTION_INTEL
104+
105+
std::mutex Mtx;
106+
107+
template <const char *FuncName, typename FuncT>
108+
std::pair<FuncT &, bool &> get() {
109+
return detail::get<FuncName, FuncT>(*this);
110+
}
111+
};
112+
113+
namespace detail {
114+
115+
#define _EXT_FUNCTION_COMMON(t_pfx, t_pfx_suff) \
116+
template <> \
117+
std::pair<t_pfx_suff##_fn &, bool &> get<t_pfx##Name, t_pfx_suff##_fn>( \
118+
ExtFuncsPerContextT & Funcs) { \
119+
using FPtrT = t_pfx_suff##_fn; \
120+
std::pair<FPtrT &, bool &> Ret{Funcs.t_pfx##Func, \
121+
Funcs.t_pfx##Initialized}; \
122+
return Ret; \
123+
}
124+
#define _EXT_FUNCTION_INTEL(t_pfx) _EXT_FUNCTION_COMMON(t_pfx, t_pfx##INTEL)
125+
#define _EXT_FUNCTION(t_pfx) _EXT_FUNCTION_COMMON(t_pfx, t_pfx)
126+
127+
#include "ext_functions.inc"
128+
129+
#undef _EXT_FUNCTION
130+
#undef _EXT_FUNCTION_INTEL
131+
#undef _EXT_FUNCTION_COMMON
132+
} // namespace detail
133+
134+
struct ExtFuncsCachesT {
135+
std::map<pi_context, ExtFuncsPerContextT> Caches;
136+
std::mutex Mtx;
137+
};
138+
139+
ExtFuncsCachesT *ExtFuncsCaches = nullptr;
140+
74141
// USM helper function to get an extension function pointer
75142
template <const char *FuncName, typename T>
76143
static pi_result getExtFuncFromContext(pi_context context, T *fptr) {
77144
// TODO
78145
// Potentially redo caching as PI interface changes.
79-
thread_local static std::map<pi_context, T> FuncPtrs;
146+
ExtFuncsPerContextT *PerContext = nullptr;
147+
{
148+
assert(ExtFuncsCaches);
149+
std::lock_guard<std::mutex> Lock{ExtFuncsCaches->Mtx};
150+
151+
PerContext = &ExtFuncsCaches->Caches[context];
152+
}
153+
154+
std::lock_guard<std::mutex> Lock{PerContext->Mtx};
155+
std::pair<T &, bool &> FuncInitialized = PerContext->get<FuncName, T>();
80156

81157
// if cached, return cached FuncPtr
82-
if (auto F = FuncPtrs[context]) {
158+
if (FuncInitialized.second) {
83159
// if cached that extension is not available return nullptr and
84160
// PI_INVALID_VALUE
85-
*fptr = F;
86-
return F ? PI_SUCCESS : PI_INVALID_VALUE;
161+
*fptr = FuncInitialized.first;
162+
return *fptr ? PI_SUCCESS : PI_INVALID_VALUE;
87163
}
88164

89165
cl_uint deviceCount;
@@ -115,14 +191,17 @@ static pi_result getExtFuncFromContext(pi_context context, T *fptr) {
115191
T FuncPtr =
116192
(T)clGetExtensionFunctionAddressForPlatform(curPlatform, FuncName);
117193

194+
// We're about to store the cached value. Mark this cache entry initialized.
195+
FuncInitialized.second = true;
196+
118197
if (!FuncPtr) {
119198
// Cache that the extension is not available
120-
FuncPtrs[context] = nullptr;
199+
FuncInitialized.first = nullptr;
121200
return PI_INVALID_VALUE;
122201
}
123202

203+
FuncInitialized.first = FuncPtr;
124204
*fptr = FuncPtr;
125-
FuncPtrs[context] = FuncPtr;
126205

127206
return cast<pi_result>(ret_err);
128207
}
@@ -561,9 +640,6 @@ static bool is_in_separated_string(const std::string &str, char delimiter,
561640
return false;
562641
}
563642

564-
typedef CL_API_ENTRY cl_int(CL_API_CALL *clGetDeviceFunctionPointer_fn)(
565-
cl_device_id device, cl_program program, const char *FuncName,
566-
cl_ulong *ret_ptr);
567643
pi_result piextGetDeviceFunctionPointer(pi_device device, pi_program program,
568644
const char *func_name,
569645
pi_uint64 *function_pointer_ret) {
@@ -1304,10 +1380,6 @@ pi_result piKernelSetExecInfo(pi_kernel kernel, pi_kernel_exec_info param_name,
13041380
}
13051381
}
13061382

1307-
typedef CL_API_ENTRY cl_int(CL_API_CALL *clSetProgramSpecializationConstant_fn)(
1308-
cl_program program, cl_uint spec_id, size_t spec_size,
1309-
const void *spec_value);
1310-
13111383
pi_result piextProgramSetSpecializationConstant(pi_program prog,
13121384
pi_uint32 spec_id,
13131385
size_t spec_size,
@@ -1383,9 +1455,21 @@ pi_result piextKernelGetNativeHandle(pi_kernel kernel,
13831455
// pi_level_zero.cpp for reference) Currently this is just a NOOP.
13841456
pi_result piTearDown(void *PluginParameter) {
13851457
(void)PluginParameter;
1458+
delete ExtFuncsCaches;
1459+
ExtFuncsCaches = nullptr;
13861460
return PI_SUCCESS;
13871461
}
13881462

1463+
pi_result piContextRelease(pi_context Context) {
1464+
{
1465+
std::lock_guard<std::mutex> Lock{ExtFuncsCaches->Mtx};
1466+
1467+
ExtFuncsCaches->Caches.erase(Context);
1468+
}
1469+
1470+
return cast<pi_result>(clReleaseContext(cast<cl_context>(Context)));
1471+
}
1472+
13891473
pi_result piPluginInit(pi_plugin *PluginInit) {
13901474
int CompareVersions = strcmp(PluginInit->PiVersion, SupportedVersion);
13911475
if (CompareVersions < 0) {
@@ -1397,6 +1481,8 @@ pi_result piPluginInit(pi_plugin *PluginInit) {
13971481
// PI interface supports higher version or the same version.
13981482
strncpy(PluginInit->PluginVersion, SupportedVersion, 4);
13991483

1484+
ExtFuncsCaches = new ExtFuncsCachesT;
1485+
14001486
#define _PI_CL(pi_api, ocl_api) \
14011487
(PluginInit->PiFunctionTable).pi_api = (decltype(&::pi_api))(&ocl_api);
14021488

@@ -1420,7 +1506,7 @@ pi_result piPluginInit(pi_plugin *PluginInit) {
14201506
_PI_CL(piContextCreate, piContextCreate)
14211507
_PI_CL(piContextGetInfo, clGetContextInfo)
14221508
_PI_CL(piContextRetain, clRetainContext)
1423-
_PI_CL(piContextRelease, clReleaseContext)
1509+
_PI_CL(piContextRelease, piContextRelease)
14241510
_PI_CL(piextContextGetNativeHandle, piextContextGetNativeHandle)
14251511
_PI_CL(piextContextCreateWithNativeHandle, piextContextCreateWithNativeHandle)
14261512
// Queue

sycl/test/abi/pi_opencl_symbol_check.dump

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
# UNSUPPORTED: libcxx
99

1010
piContextCreate
11+
piContextRelease
1112
piDeviceGetInfo
1213
piDevicesGet
1314
piEnqueueMemBufferMap

0 commit comments

Comments
 (0)