25
25
#include < iostream>
26
26
#include < limits>
27
27
#include < map>
28
+ #include < memory>
29
+ #include < mutex>
28
30
#include < sstream>
29
31
#include < string>
30
32
#include < vector>
@@ -71,19 +73,93 @@ CONSTFIX char clGetDeviceFunctionPointerName[] =
71
73
72
74
#undef CONSTFIX
73
75
76
+ typedef CL_API_ENTRY cl_int (CL_API_CALL *clGetDeviceFunctionPointer_fn)(
77
+ cl_device_id device, cl_program program, const char *FuncName,
78
+ cl_ulong *ret_ptr);
79
+
80
+ typedef CL_API_ENTRY cl_int (CL_API_CALL *clSetProgramSpecializationConstant_fn)(
81
+ cl_program program, cl_uint spec_id, size_t spec_size,
82
+ const void *spec_value);
83
+
84
+ struct ExtFuncsPerContextT ;
85
+
86
+ namespace detail {
87
+ template <const char *FuncName, typename FuncT>
88
+ std::pair<FuncT &, bool &> get (ExtFuncsPerContextT &);
89
+ } // namespace detail
90
+
91
+ struct ExtFuncsPerContextT {
92
+ #define _EXT_FUNCTION_INTEL (t_pfx ) \
93
+ t_pfx##INTEL_fn t_pfx##Func = nullptr ; \
94
+ bool t_pfx##Initialized = false ;
95
+
96
+ #define _EXT_FUNCTION (t_pfx ) \
97
+ t_pfx##_fn t_pfx##Func = nullptr ; \
98
+ bool t_pfx##Initialized = false ;
99
+
100
+ #include " ext_functions.inc"
101
+
102
+ #undef _EXT_FUNCTION
103
+ #undef _EXT_FUNCTION_INTEL
104
+
105
+ std::mutex Mtx;
106
+
107
+ template <const char *FuncName, typename FuncT>
108
+ std::pair<FuncT &, bool &> get () {
109
+ return detail::get<FuncName, FuncT>(*this );
110
+ }
111
+ };
112
+
113
+ namespace detail {
114
+
115
+ #define _EXT_FUNCTION_COMMON (t_pfx, t_pfx_suff ) \
116
+ template <> \
117
+ std::pair<t_pfx_suff##_fn &, bool &> get<t_pfx##Name, t_pfx_suff##_fn>( \
118
+ ExtFuncsPerContextT & Funcs) { \
119
+ using FPtrT = t_pfx_suff##_fn; \
120
+ std::pair<FPtrT &, bool &> Ret{Funcs.t_pfx ##Func, \
121
+ Funcs.t_pfx ##Initialized}; \
122
+ return Ret; \
123
+ }
124
+ #define _EXT_FUNCTION_INTEL (t_pfx ) _EXT_FUNCTION_COMMON(t_pfx, t_pfx##INTEL)
125
+ #define _EXT_FUNCTION (t_pfx ) _EXT_FUNCTION_COMMON(t_pfx, t_pfx)
126
+
127
+ #include " ext_functions.inc"
128
+
129
+ #undef _EXT_FUNCTION
130
+ #undef _EXT_FUNCTION_INTEL
131
+ #undef _EXT_FUNCTION_COMMON
132
+ } // namespace detail
133
+
134
+ struct ExtFuncsCachesT {
135
+ std::map<pi_context, ExtFuncsPerContextT> Caches;
136
+ std::mutex Mtx;
137
+ };
138
+
139
+ ExtFuncsCachesT *ExtFuncsCaches = nullptr ;
140
+
74
141
// USM helper function to get an extension function pointer
75
142
template <const char *FuncName, typename T>
76
143
static pi_result getExtFuncFromContext (pi_context context, T *fptr) {
77
144
// TODO
78
145
// Potentially redo caching as PI interface changes.
79
- thread_local static std::map<pi_context, T> FuncPtrs;
146
+ ExtFuncsPerContextT *PerContext = nullptr ;
147
+ {
148
+ assert (ExtFuncsCaches);
149
+ std::lock_guard<std::mutex> Lock{ExtFuncsCaches->Mtx };
150
+
151
+ PerContext = &ExtFuncsCaches->Caches [context];
152
+ }
153
+
154
+ std::lock_guard<std::mutex> Lock{PerContext->Mtx };
155
+ std::pair<T &, bool &> FuncInitialized = PerContext->get <FuncName, T>();
80
156
81
157
// if cached, return cached FuncPtr
82
- if (auto F = FuncPtrs[context] ) {
158
+ if (FuncInitialized. second ) {
83
159
// if cached that extension is not available return nullptr and
84
160
// PI_INVALID_VALUE
85
- *fptr = F ;
86
- return F ? PI_SUCCESS : PI_INVALID_VALUE;
161
+ *fptr = FuncInitialized. first ;
162
+ return *fptr ? PI_SUCCESS : PI_INVALID_VALUE;
87
163
}
88
164
89
165
cl_uint deviceCount;
@@ -115,14 +191,17 @@ static pi_result getExtFuncFromContext(pi_context context, T *fptr) {
115
191
T FuncPtr =
116
192
(T)clGetExtensionFunctionAddressForPlatform (curPlatform, FuncName);
117
193
194
+ // We're about to store the cached value. Mark this cache entry initialized.
195
+ FuncInitialized.second = true ;
196
+
118
197
if (!FuncPtr) {
119
198
// Cache that the extension is not available
120
- FuncPtrs[context] = nullptr ;
199
+ FuncInitialized. first = nullptr ;
121
200
return PI_INVALID_VALUE;
122
201
}
123
202
203
+ FuncInitialized.first = FuncPtr;
124
204
*fptr = FuncPtr;
125
- FuncPtrs[context] = FuncPtr;
126
205
127
206
return cast<pi_result>(ret_err);
128
207
}
@@ -561,9 +640,6 @@ static bool is_in_separated_string(const std::string &str, char delimiter,
561
640
return false ;
562
641
}
563
642
564
- typedef CL_API_ENTRY cl_int (CL_API_CALL *clGetDeviceFunctionPointer_fn)(
565
- cl_device_id device, cl_program program, const char *FuncName,
566
- cl_ulong *ret_ptr);
567
643
pi_result piextGetDeviceFunctionPointer (pi_device device, pi_program program,
568
644
const char *func_name,
569
645
pi_uint64 *function_pointer_ret) {
@@ -1304,10 +1380,6 @@ pi_result piKernelSetExecInfo(pi_kernel kernel, pi_kernel_exec_info param_name,
1304
1380
}
1305
1381
}
1306
1382
1307
- typedef CL_API_ENTRY cl_int (CL_API_CALL *clSetProgramSpecializationConstant_fn)(
1308
- cl_program program, cl_uint spec_id, size_t spec_size,
1309
- const void *spec_value);
1310
-
1311
1383
pi_result piextProgramSetSpecializationConstant (pi_program prog,
1312
1384
pi_uint32 spec_id,
1313
1385
size_t spec_size,
@@ -1383,9 +1455,21 @@ pi_result piextKernelGetNativeHandle(pi_kernel kernel,
1383
1455
// pi_level_zero.cpp for reference) Currently this is just a NOOP.
1384
1456
pi_result piTearDown (void *PluginParameter) {
1385
1457
(void )PluginParameter;
1458
+ delete ExtFuncsCaches;
1459
+ ExtFuncsCaches = nullptr ;
1386
1460
return PI_SUCCESS;
1387
1461
}
1388
1462
1463
+ pi_result piContextRelease (pi_context Context) {
1464
+ {
1465
+ std::lock_guard<std::mutex> Lock{ExtFuncsCaches->Mtx };
1466
+
1467
+ ExtFuncsCaches->Caches .erase (Context);
1468
+ }
1469
+
1470
+ return cast<pi_result>(clReleaseContext (cast<cl_context>(Context)));
1471
+ }
1472
+
1389
1473
pi_result piPluginInit (pi_plugin *PluginInit) {
1390
1474
int CompareVersions = strcmp (PluginInit->PiVersion , SupportedVersion);
1391
1475
if (CompareVersions < 0 ) {
@@ -1397,6 +1481,8 @@ pi_result piPluginInit(pi_plugin *PluginInit) {
1397
1481
// PI interface supports higher version or the same version.
1398
1482
strncpy (PluginInit->PluginVersion , SupportedVersion, 4 );
1399
1483
1484
+ ExtFuncsCaches = new ExtFuncsCachesT;
1485
+
1400
1486
#define _PI_CL (pi_api, ocl_api ) \
1401
1487
(PluginInit->PiFunctionTable ).pi_api = (decltype (&::pi_api))(&ocl_api);
1402
1488
@@ -1420,7 +1506,7 @@ pi_result piPluginInit(pi_plugin *PluginInit) {
1420
1506
_PI_CL (piContextCreate, piContextCreate)
1421
1507
_PI_CL (piContextGetInfo, clGetContextInfo)
1422
1508
_PI_CL (piContextRetain, clRetainContext)
1423
- _PI_CL (piContextRelease, clReleaseContext )
1509
+ _PI_CL (piContextRelease, piContextRelease )
1424
1510
_PI_CL (piextContextGetNativeHandle, piextContextGetNativeHandle)
1425
1511
_PI_CL (piextContextCreateWithNativeHandle, piextContextCreateWithNativeHandle)
1426
1512
// Queue
0 commit comments