Skip to content

[SYCL] Fix static destruction order issue in OpenCL extension fptr cache #9254

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 2, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 125 additions & 71 deletions sycl/plugins/opencl/pi_opencl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <limits>
#include <map>
#include <memory>
#include <mutex>
#include <sstream>
#include <string>
#include <string_view>
Expand Down Expand Up @@ -186,16 +187,70 @@ static cl_int checkDeviceExtensions(cl_device_id dev,
return ret_err;
}

// Function-pointer types for extension entry points that are looked up by name
// at run time through getExtFuncFromContext. They are declared ahead of
// ExtFuncPtrCacheT so that the cache struct can name them.

// Signature of the function resolved under clGetDeviceFunctionPointerName
// (used by piextGetDeviceFunctionPointer).
typedef CL_API_ENTRY cl_int(CL_API_CALL *clGetDeviceFunctionPointer_fn)(
cl_device_id device, cl_program program, const char *FuncName,
cl_ulong *ret_ptr);

// Signature of the function resolved under clEnqueueWriteGlobalVariableName
// (used by piextEnqueueDeviceGlobalVariableWrite).
typedef CL_API_ENTRY cl_int(CL_API_CALL *clEnqueueWriteGlobalVariable_fn)(
cl_command_queue, cl_program, const char *, cl_bool, size_t, size_t,
const void *, cl_uint, const cl_event *, cl_event *);

// Signature of the function resolved under clEnqueueReadGlobalVariableName
// (used by piextEnqueueDeviceGlobalVariableRead).
typedef CL_API_ENTRY cl_int(CL_API_CALL *clEnqueueReadGlobalVariable_fn)(
cl_command_queue, cl_program, const char *, cl_bool, size_t, size_t, void *,
cl_uint, const cl_event *, cl_event *);

// Signature of the function resolved under
// clSetProgramSpecializationConstantName (used by
// piextProgramSetSpecializationConstant).
typedef CL_API_ENTRY cl_int(CL_API_CALL *clSetProgramSpecializationConstant_fn)(
cl_program program, cl_uint spec_id, size_t spec_size,
const void *spec_value);

// Cache for one extension function pointer type T, keyed by context.
//  - Map holds the pointer stored for each pi_context; getExtFuncFromContext
//    stores nullptr for a context whose lookup produced no function pointer,
//    so that the failed lookup is remembered as well.
//  - Mutex is locked by getExtFuncFromContext before it reads or writes Map.
template <typename T> struct FuncPtrCache {
std::map<pi_context, T> Map;
std::mutex Mutex;
};

// FIXME: There's currently no mechanism for cleaning up this cache, meaning
// that it is invalidated whenever a context is destroyed. This could lead to
// reusing an invalid function pointer if another context happens to have the
// same native handle.
struct ExtFuncPtrCacheT {
FuncPtrCache<clHostMemAllocINTEL_fn> clHostMemAllocINTELCache;
FuncPtrCache<clDeviceMemAllocINTEL_fn> clDeviceMemAllocINTELCache;
Comment on lines +216 to +217
Copy link
Contributor

@aelovikov-intel aelovikov-intel Apr 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use an std::tuple here so that a particular cache could be accessed via

  tuple.get<FuncPtrCache<ExtFuncTy>>()

?

There might be a minor issue with the fact that C++'s using Ty = doesn't create a new type, so we might need to change the typedefs above to something like

struct ExtFuncNameTy {
  using FnPtr = /* current typedef */;
};

in order to introduce a unique type alias, but I think it will be worth it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's exactly the issue, to go this route we would need to provide unique type aliases for all functions used including the ones defined in the OpenCL headers. I think it will end up being just as verbose as the current version.

FuncPtrCache<clSharedMemAllocINTEL_fn> clSharedMemAllocINTELCache;
FuncPtrCache<clGetDeviceFunctionPointer_fn> clGetDeviceFunctionPointerCache;
FuncPtrCache<clCreateBufferWithPropertiesINTEL_fn>
clCreateBufferWithPropertiesINTELCache;
FuncPtrCache<clMemBlockingFreeINTEL_fn> clMemBlockingFreeINTELCache;
FuncPtrCache<clSetKernelArgMemPointerINTEL_fn>
clSetKernelArgMemPointerINTELCache;
FuncPtrCache<clEnqueueMemsetINTEL_fn> clEnqueueMemsetINTELCache;
FuncPtrCache<clEnqueueMemcpyINTEL_fn> clEnqueueMemcpyINTELCache;
FuncPtrCache<clGetMemAllocInfoINTEL_fn> clGetMemAllocInfoINTELCache;
FuncPtrCache<clEnqueueWriteGlobalVariable_fn>
clEnqueueWriteGlobalVariableCache;
FuncPtrCache<clEnqueueReadGlobalVariable_fn> clEnqueueReadGlobalVariableCache;
FuncPtrCache<clEnqueueReadHostPipeINTEL_fn> clEnqueueReadHostPipeINTELCache;
FuncPtrCache<clEnqueueWriteHostPipeINTEL_fn> clEnqueueWriteHostPipeINTELCache;
FuncPtrCache<clSetProgramSpecializationConstant_fn>
clSetProgramSpecializationConstantCache;
};
// A raw pointer is used here since the lifetime of this cache has to be tied
// to piTearDown (which deletes it and resets the pointer to nullptr) to avoid
// issues with static destruction order (a user application might have static
// objects that indirectly access this cache in their destructor).
static ExtFuncPtrCacheT *ExtFuncPtrCache = new ExtFuncPtrCacheT();

// USM helper function to get an extension function pointer
template <const char *FuncName, typename T>
static pi_result getExtFuncFromContext(pi_context context, T *fptr) {
template <typename T>
static pi_result getExtFuncFromContext(pi_context context,
FuncPtrCache<T> &FPtrCache,
const char *FuncName, T *fptr) {
// TODO
// Potentially redo caching as PI interface changes.
thread_local static std::map<pi_context, T> FuncPtrs;

// if cached, return cached FuncPtr
auto It = FuncPtrs.find(context);
if (It != FuncPtrs.end()) {
std::lock_guard<std::mutex> CacheLock{FPtrCache.Mutex};
std::map<pi_context, T> &FPtrMap = FPtrCache.Map;
auto It = FPtrMap.find(context);
if (It != FPtrMap.end()) {
auto F = It->second;
// if cached that extension is not available return nullptr and
// PI_ERROR_INVALID_VALUE
Expand Down Expand Up @@ -234,12 +289,12 @@ static pi_result getExtFuncFromContext(pi_context context, T *fptr) {

if (!FuncPtr) {
// Cache that the extension is not available
FuncPtrs[context] = nullptr;
FPtrMap[context] = nullptr;
return PI_ERROR_INVALID_VALUE;
}

*fptr = FuncPtr;
FuncPtrs[context] = FuncPtr;
FPtrMap[context] = FuncPtr;

return cast<pi_result>(ret_err);
}
Expand All @@ -262,24 +317,27 @@ static pi_result USMSetIndirectAccess(pi_kernel kernel) {
return cast<pi_result>(CLErr);
}

getExtFuncFromContext<clHostMemAllocName, clHostMemAllocINTEL_fn>(
cast<pi_context>(CLContext), &HFunc);
getExtFuncFromContext<clHostMemAllocINTEL_fn>(
cast<pi_context>(CLContext), ExtFuncPtrCache->clHostMemAllocINTELCache,
clHostMemAllocName, &HFunc);
if (HFunc) {
clSetKernelExecInfo(cast<cl_kernel>(kernel),
CL_KERNEL_EXEC_INFO_INDIRECT_HOST_ACCESS_INTEL,
sizeof(cl_bool), &TrueVal);
}

getExtFuncFromContext<clDeviceMemAllocName, clDeviceMemAllocINTEL_fn>(
cast<pi_context>(CLContext), &DFunc);
getExtFuncFromContext<clDeviceMemAllocINTEL_fn>(
cast<pi_context>(CLContext), ExtFuncPtrCache->clDeviceMemAllocINTELCache,
clDeviceMemAllocName, &DFunc);
if (DFunc) {
clSetKernelExecInfo(cast<cl_kernel>(kernel),
CL_KERNEL_EXEC_INFO_INDIRECT_DEVICE_ACCESS_INTEL,
sizeof(cl_bool), &TrueVal);
}

getExtFuncFromContext<clSharedMemAllocName, clSharedMemAllocINTEL_fn>(
cast<pi_context>(CLContext), &SFunc);
getExtFuncFromContext<clSharedMemAllocINTEL_fn>(
cast<pi_context>(CLContext), ExtFuncPtrCache->clSharedMemAllocINTELCache,
clSharedMemAllocName, &SFunc);
if (SFunc) {
clSetKernelExecInfo(cast<cl_kernel>(kernel),
CL_KERNEL_EXEC_INFO_INDIRECT_SHARED_ACCESS_INTEL,
Expand Down Expand Up @@ -1090,9 +1148,6 @@ static bool is_in_separated_string(const std::string &str, char delimiter,
return false;
}

typedef CL_API_ENTRY cl_int(CL_API_CALL *clGetDeviceFunctionPointer_fn)(
cl_device_id device, cl_program program, const char *FuncName,
cl_ulong *ret_ptr);
pi_result piextGetDeviceFunctionPointer(pi_device device, pi_program program,
const char *func_name,
pi_uint64 *function_pointer_ret) {
Expand All @@ -1106,9 +1161,10 @@ pi_result piextGetDeviceFunctionPointer(pi_device device, pi_program program,
return cast<pi_result>(ret_err);

clGetDeviceFunctionPointer_fn FuncT = nullptr;
ret_err = getExtFuncFromContext<clGetDeviceFunctionPointerName,
clGetDeviceFunctionPointer_fn>(
cast<pi_context>(CLContext), &FuncT);
ret_err = getExtFuncFromContext<clGetDeviceFunctionPointer_fn>(
cast<pi_context>(CLContext),
ExtFuncPtrCache->clGetDeviceFunctionPointerCache,
clGetDeviceFunctionPointerName, &FuncT);

pi_result pi_ret_err = PI_SUCCESS;

Expand Down Expand Up @@ -1225,9 +1281,9 @@ pi_result piMemBufferCreate(pi_context context, pi_mem_flags flags, size_t size,
// ignore unsupported
clCreateBufferWithPropertiesINTEL_fn FuncPtr = nullptr;
// First we need to look up the function pointer
ret_err = getExtFuncFromContext<clCreateBufferWithPropertiesName,
clCreateBufferWithPropertiesINTEL_fn>(
context, &FuncPtr);
ret_err = getExtFuncFromContext<clCreateBufferWithPropertiesINTEL_fn>(
context, ExtFuncPtrCache->clCreateBufferWithPropertiesINTELCache,
clCreateBufferWithPropertiesName, &FuncPtr);
if (FuncPtr) {
*ret_mem = cast<pi_mem>(FuncPtr(cast<cl_context>(context), properties,
cast<cl_mem_flags>(flags), size, host_ptr,
Expand Down Expand Up @@ -1506,8 +1562,9 @@ pi_result piextUSMHostAlloc(void **result_ptr, pi_context context,

// First we need to look up the function pointer
clHostMemAllocINTEL_fn FuncPtr = nullptr;
RetVal = getExtFuncFromContext<clHostMemAllocName, clHostMemAllocINTEL_fn>(
context, &FuncPtr);
RetVal = getExtFuncFromContext<clHostMemAllocINTEL_fn>(
context, ExtFuncPtrCache->clHostMemAllocINTELCache, clHostMemAllocName,
&FuncPtr);

if (FuncPtr) {
Ptr = FuncPtr(cast<cl_context>(context),
Expand Down Expand Up @@ -1543,9 +1600,9 @@ pi_result piextUSMDeviceAlloc(void **result_ptr, pi_context context,

// First we need to look up the function pointer
clDeviceMemAllocINTEL_fn FuncPtr = nullptr;
RetVal =
getExtFuncFromContext<clDeviceMemAllocName, clDeviceMemAllocINTEL_fn>(
context, &FuncPtr);
RetVal = getExtFuncFromContext<clDeviceMemAllocINTEL_fn>(
context, ExtFuncPtrCache->clDeviceMemAllocINTELCache,
clDeviceMemAllocName, &FuncPtr);

if (FuncPtr) {
Ptr = FuncPtr(cast<cl_context>(context), cast<cl_device_id>(device),
Expand Down Expand Up @@ -1581,9 +1638,9 @@ pi_result piextUSMSharedAlloc(void **result_ptr, pi_context context,

// First we need to look up the function pointer
clSharedMemAllocINTEL_fn FuncPtr = nullptr;
RetVal =
getExtFuncFromContext<clSharedMemAllocName, clSharedMemAllocINTEL_fn>(
context, &FuncPtr);
RetVal = getExtFuncFromContext<clSharedMemAllocINTEL_fn>(
context, ExtFuncPtrCache->clSharedMemAllocINTELCache,
clSharedMemAllocName, &FuncPtr);

if (FuncPtr) {
Ptr = FuncPtr(cast<cl_context>(context), cast<cl_device_id>(device),
Expand All @@ -1609,9 +1666,9 @@ pi_result piextUSMFree(pi_context context, void *ptr) {
clMemBlockingFreeINTEL_fn FuncPtr = nullptr;

pi_result RetVal = PI_ERROR_INVALID_OPERATION;
RetVal =
getExtFuncFromContext<clMemBlockingFreeName, clMemBlockingFreeINTEL_fn>(
context, &FuncPtr);
RetVal = getExtFuncFromContext<clMemBlockingFreeINTEL_fn>(
context, ExtFuncPtrCache->clMemBlockingFreeINTELCache,
clMemBlockingFreeName, &FuncPtr);

if (FuncPtr) {
RetVal = cast<pi_result>(FuncPtr(cast<cl_context>(context), ptr));
Expand Down Expand Up @@ -1642,9 +1699,10 @@ pi_result piextKernelSetArgPointer(pi_kernel kernel, pi_uint32 arg_index,
}

clSetKernelArgMemPointerINTEL_fn FuncPtr = nullptr;
pi_result RetVal = getExtFuncFromContext<clSetKernelArgMemPointerName,
clSetKernelArgMemPointerINTEL_fn>(
cast<pi_context>(CLContext), &FuncPtr);
pi_result RetVal = getExtFuncFromContext<clSetKernelArgMemPointerINTEL_fn>(
cast<pi_context>(CLContext),
ExtFuncPtrCache->clSetKernelArgMemPointerINTELCache,
clSetKernelArgMemPointerName, &FuncPtr);

if (FuncPtr) {
// OpenCL passes pointers by value not by reference
Expand Down Expand Up @@ -1683,9 +1741,9 @@ pi_result piextUSMEnqueueMemset(pi_queue queue, void *ptr, pi_int32 value,
}

clEnqueueMemsetINTEL_fn FuncPtr = nullptr;
pi_result RetVal =
getExtFuncFromContext<clEnqueueMemsetName, clEnqueueMemsetINTEL_fn>(
cast<pi_context>(CLContext), &FuncPtr);
pi_result RetVal = getExtFuncFromContext<clEnqueueMemsetINTEL_fn>(
cast<pi_context>(CLContext), ExtFuncPtrCache->clEnqueueMemsetINTELCache,
clEnqueueMemsetName, &FuncPtr);

if (FuncPtr) {
RetVal = cast<pi_result>(FuncPtr(cast<cl_command_queue>(queue), ptr, value,
Expand Down Expand Up @@ -1723,9 +1781,9 @@ pi_result piextUSMEnqueueMemcpy(pi_queue queue, pi_bool blocking, void *dst_ptr,
}

clEnqueueMemcpyINTEL_fn FuncPtr = nullptr;
pi_result RetVal =
getExtFuncFromContext<clEnqueueMemcpyName, clEnqueueMemcpyINTEL_fn>(
cast<pi_context>(CLContext), &FuncPtr);
pi_result RetVal = getExtFuncFromContext<clEnqueueMemcpyINTEL_fn>(
cast<pi_context>(CLContext), ExtFuncPtrCache->clEnqueueMemcpyINTELCache,
clEnqueueMemcpyName, &FuncPtr);

if (FuncPtr) {
RetVal = cast<pi_result>(
Expand Down Expand Up @@ -1949,9 +2007,9 @@ pi_result piextUSMGetMemAllocInfo(pi_context context, const void *ptr,
size_t *param_value_size_ret) {

clGetMemAllocInfoINTEL_fn FuncPtr = nullptr;
pi_result RetVal =
getExtFuncFromContext<clGetMemAllocInfoName, clGetMemAllocInfoINTEL_fn>(
context, &FuncPtr);
pi_result RetVal = getExtFuncFromContext<clGetMemAllocInfoINTEL_fn>(
context, ExtFuncPtrCache->clGetMemAllocInfoINTELCache,
clGetMemAllocInfoName, &FuncPtr);

if (FuncPtr) {
RetVal = cast<pi_result>(FuncPtr(cast<cl_context>(context), ptr, param_name,
Expand All @@ -1962,14 +2020,6 @@ pi_result piextUSMGetMemAllocInfo(pi_context context, const void *ptr,
return RetVal;
}

typedef CL_API_ENTRY cl_int(CL_API_CALL *clEnqueueWriteGlobalVariable_fn)(
cl_command_queue, cl_program, const char *, cl_bool, size_t, size_t,
const void *, cl_uint, const cl_event *, cl_event *);

typedef CL_API_ENTRY cl_int(CL_API_CALL *clEnqueueReadGlobalVariable_fn)(
cl_command_queue, cl_program, const char *, cl_bool, size_t, size_t, void *,
cl_uint, const cl_event *, cl_event *);

/// API for writing data from host to a device global variable.
///
/// \param queue is the queue
Expand Down Expand Up @@ -1997,8 +2047,9 @@ pi_result piextEnqueueDeviceGlobalVariableWrite(
return cast<pi_result>(Res);

clEnqueueWriteGlobalVariable_fn F = nullptr;
Res = getExtFuncFromContext<clEnqueueWriteGlobalVariableName, decltype(F)>(
cast<pi_context>(Ctx), &F);
Res = getExtFuncFromContext<decltype(F)>(
cast<pi_context>(Ctx), ExtFuncPtrCache->clEnqueueWriteGlobalVariableCache,
clEnqueueWriteGlobalVariableName, &F);

if (!F || Res != CL_SUCCESS)
return PI_ERROR_INVALID_OPERATION;
Expand Down Expand Up @@ -2034,8 +2085,9 @@ pi_result piextEnqueueDeviceGlobalVariableRead(
return cast<pi_result>(Res);

clEnqueueReadGlobalVariable_fn F = nullptr;
Res = getExtFuncFromContext<clEnqueueReadGlobalVariableName, decltype(F)>(
cast<pi_context>(Ctx), &F);
Res = getExtFuncFromContext<decltype(F)>(
cast<pi_context>(Ctx), ExtFuncPtrCache->clEnqueueReadGlobalVariableCache,
clEnqueueReadGlobalVariableName, &F);

if (!F || Res != CL_SUCCESS)
return PI_ERROR_INVALID_OPERATION;
Expand All @@ -2060,9 +2112,10 @@ pi_result piextEnqueueReadHostPipe(pi_queue queue, pi_program program,
}

clEnqueueReadHostPipeINTEL_fn FuncPtr = nullptr;
pi_result RetVal = getExtFuncFromContext<clEnqueueReadHostPipeName,
clEnqueueReadHostPipeINTEL_fn>(
cast<pi_context>(CLContext), &FuncPtr);
pi_result RetVal = getExtFuncFromContext<clEnqueueReadHostPipeINTEL_fn>(
cast<pi_context>(CLContext),
ExtFuncPtrCache->clEnqueueReadHostPipeINTELCache,
clEnqueueReadHostPipeName, &FuncPtr);

if (FuncPtr) {
RetVal = cast<pi_result>(FuncPtr(
Expand All @@ -2089,9 +2142,10 @@ pi_result piextEnqueueWriteHostPipe(pi_queue queue, pi_program program,
}

clEnqueueWriteHostPipeINTEL_fn FuncPtr = nullptr;
pi_result RetVal = getExtFuncFromContext<clEnqueueWriteHostPipeName,
clEnqueueWriteHostPipeINTEL_fn>(
cast<pi_context>(CLContext), &FuncPtr);
pi_result RetVal = getExtFuncFromContext<clEnqueueWriteHostPipeINTEL_fn>(
cast<pi_context>(CLContext),
ExtFuncPtrCache->clEnqueueWriteHostPipeINTELCache,
clEnqueueWriteHostPipeName, &FuncPtr);

if (FuncPtr) {
RetVal = cast<pi_result>(FuncPtr(
Expand Down Expand Up @@ -2126,10 +2180,6 @@ pi_result piKernelSetExecInfo(pi_kernel kernel, pi_kernel_exec_info param_name,
}
}

typedef CL_API_ENTRY cl_int(CL_API_CALL *clSetProgramSpecializationConstant_fn)(
cl_program program, cl_uint spec_id, size_t spec_size,
const void *spec_value);

pi_result piextProgramSetSpecializationConstant(pi_program prog,
pi_uint32 spec_id,
size_t spec_size,
Expand All @@ -2144,8 +2194,10 @@ pi_result piextProgramSetSpecializationConstant(pi_program prog,
return cast<pi_result>(Res);

clSetProgramSpecializationConstant_fn F = nullptr;
Res = getExtFuncFromContext<clSetProgramSpecializationConstantName,
decltype(F)>(cast<pi_context>(Ctx), &F);
Res = getExtFuncFromContext<decltype(F)>(
cast<pi_context>(Ctx),
ExtFuncPtrCache->clSetProgramSpecializationConstantCache,
clSetProgramSpecializationConstantName, &F);

if (!F || Res != CL_SUCCESS)
return PI_ERROR_INVALID_OPERATION;
Expand Down Expand Up @@ -2213,9 +2265,11 @@ pi_result piextKernelGetNativeHandle(pi_kernel kernel,
// called safely. But this is not transitive. If the PI plugin in turn
// dynamically loaded a different DLL, that may have been unloaded.
// TODO: add a global variable lifetime management code here (see
// pi_level_zero.cpp for reference) Currently this is just a NOOP.
// pi_level_zero.cpp for reference).
pi_result piTearDown(void *PluginParameter) {
  // The parameter is not consulted; the cast only marks it as intentionally
  // unused.
  static_cast<void>(PluginParameter);

  // Release the extension function pointer cache. The global is detached
  // first and the detached object freed afterwards, so the global is null
  // once teardown has run.
  auto *Cache = ExtFuncPtrCache;
  ExtFuncPtrCache = nullptr;
  delete Cache;

  return PI_SUCCESS;
}

Expand Down