Skip to content

Commit fca1736

Browse files
jbrodmanromanovvlad
authored andcommitted
Fix USM function pointers caching (#1008)
Fix function pointer caching to properly distinguish different functions as func ptr types are not unique. Signed-off-by: James Brodman <[email protected]>
1 parent fba2e06 commit fca1736

File tree

1 file changed

+46
-28
lines changed

1 file changed

+46
-28
lines changed

sycl/plugins/opencl/pi_opencl.cpp

Lines changed: 46 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,21 @@ template <class To, class From> To cast(From value) {
3030
return (To)(value);
3131
}
3232

33+
// Names of USM functions that are queried from OpenCL
34+
const char clHostMemAllocName[] = "clHostMemAllocINTEL";
35+
const char clDeviceMemAllocName[] = "clDeviceMemAllocINTEL";
36+
const char clSharedMemAllocName[] = "clSharedMemAllocINTEL";
37+
const char clMemFreeName[] = "clMemFreeINTEL";
38+
const char clSetKernelArgMemPointerName[] = "clSetKernelArgMemPointerINTEL";
39+
const char clEnqueueMemsetName[] = "clEnqueueMemsetINTEL";
40+
const char clEnqueueMemcpyName[] = "clEnqueueMemcpyINTEL";
41+
const char clEnqueueMigrateMemName[] = "clEnqueueMigrateMemINTEL";
42+
const char clEnqueueMemAdviseName[] = "clEnqueueMemAdviseINTEL";
43+
const char clGetMemAllocInfoName[] = "clGetMemAllocInfoINTEL";
44+
3345
// USM helper function to get an extension function pointer
34-
template <typename T>
35-
pi_result getExtFuncFromContext(pi_context context, const char *func, T *fptr) {
46+
template <const char *FuncName, typename T>
47+
static pi_result getExtFuncFromContext(pi_context context, T *fptr) {
3648
// TODO
3749
// Potentially redo caching as PI interface changes.
3850
thread_local static std::map<pi_context, T> FuncPtrs;
@@ -68,11 +80,11 @@ pi_result getExtFuncFromContext(pi_context context, const char *func, T *fptr) {
6880
return PI_INVALID_CONTEXT;
6981
}
7082

71-
T FuncPtr = (T) clGetExtensionFunctionAddressForPlatform(curPlatform,
72-
func);
73-
if (!FuncPtr) {
83+
T FuncPtr =
84+
(T)clGetExtensionFunctionAddressForPlatform(curPlatform, FuncName);
85+
86+
if (!FuncPtr)
7487
return PI_INVALID_VALUE;
75-
}
7688

7789
*fptr = FuncPtr;
7890
FuncPtrs[context] = FuncPtr;
@@ -98,24 +110,24 @@ static pi_result USMSetIndirectAccess(pi_kernel kernel) {
98110
return cast<pi_result>(CLErr);
99111
}
100112

101-
getExtFuncFromContext<clHostMemAllocINTEL_fn>(cast<pi_context>(CLContext),
102-
"clHostMemAllocINTEL", &HFunc);
113+
getExtFuncFromContext<clHostMemAllocName, clHostMemAllocINTEL_fn>(
114+
cast<pi_context>(CLContext), &HFunc);
103115
if (HFunc) {
104116
clSetKernelExecInfo(cast<cl_kernel>(kernel),
105117
CL_KERNEL_EXEC_INFO_INDIRECT_HOST_ACCESS_INTEL,
106118
sizeof(cl_bool), &TrueVal);
107119
}
108120

109-
getExtFuncFromContext<clDeviceMemAllocINTEL_fn>(
110-
cast<pi_context>(CLContext), "clDeviceMemAllocINTEL", &DFunc);
121+
getExtFuncFromContext<clDeviceMemAllocName, clDeviceMemAllocINTEL_fn>(
122+
cast<pi_context>(CLContext), &DFunc);
111123
if (DFunc) {
112124
clSetKernelExecInfo(cast<cl_kernel>(kernel),
113125
CL_KERNEL_EXEC_INFO_INDIRECT_DEVICE_ACCESS_INTEL,
114126
sizeof(cl_bool), &TrueVal);
115127
}
116128

117-
getExtFuncFromContext<clSharedMemAllocINTEL_fn>(
118-
cast<pi_context>(CLContext), "clSharedMemAllocINTEL", &SFunc);
129+
getExtFuncFromContext<clSharedMemAllocName, clSharedMemAllocINTEL_fn>(
130+
cast<pi_context>(CLContext), &SFunc);
119131
if (SFunc) {
120132
clSetKernelExecInfo(cast<cl_kernel>(kernel),
121133
CL_KERNEL_EXEC_INFO_INDIRECT_SHARED_ACCESS_INTEL,
@@ -569,8 +581,8 @@ pi_result OCL(piextUSMHostAlloc)(void **result_ptr, pi_context context,
569581

570582
// First we need to look up the function pointer
571583
clHostMemAllocINTEL_fn FuncPtr = nullptr;
572-
RetVal = getExtFuncFromContext<clHostMemAllocINTEL_fn>(
573-
context, "clHostMemAllocINTEL", &FuncPtr);
584+
RetVal = getExtFuncFromContext<clHostMemAllocName, clHostMemAllocINTEL_fn>(
585+
context, &FuncPtr);
574586

575587
if (FuncPtr) {
576588
Ptr = FuncPtr(cast<cl_context>(context),
@@ -601,8 +613,9 @@ pi_result OCL(piextUSMDeviceAlloc)(void **result_ptr, pi_context context,
601613

602614
// First we need to look up the function pointer
603615
clDeviceMemAllocINTEL_fn FuncPtr = nullptr;
604-
RetVal = getExtFuncFromContext<clDeviceMemAllocINTEL_fn>(
605-
context, "clDeviceMemAllocINTEL", &FuncPtr);
616+
RetVal =
617+
getExtFuncFromContext<clDeviceMemAllocName, clDeviceMemAllocINTEL_fn>(
618+
context, &FuncPtr);
606619

607620
if (FuncPtr) {
608621
Ptr = FuncPtr(cast<cl_context>(context), cast<cl_device_id>(device),
@@ -633,8 +646,9 @@ pi_result OCL(piextUSMSharedAlloc)(void **result_ptr, pi_context context,
633646

634647
// First we need to look up the function pointer
635648
clSharedMemAllocINTEL_fn FuncPtr = nullptr;
636-
RetVal = getExtFuncFromContext<clSharedMemAllocINTEL_fn>(
637-
context, "clSharedMemAllocINTEL", &FuncPtr);
649+
RetVal =
650+
getExtFuncFromContext<clSharedMemAllocName, clSharedMemAllocINTEL_fn>(
651+
context, &FuncPtr);
638652

639653
if (FuncPtr) {
640654
Ptr = FuncPtr(cast<cl_context>(context), cast<cl_device_id>(device),
@@ -655,8 +669,8 @@ pi_result OCL(piextUSMFree)(pi_context context, void *ptr) {
655669

656670
clMemFreeINTEL_fn FuncPtr = nullptr;
657671
pi_result RetVal = PI_INVALID_OPERATION;
658-
RetVal = getExtFuncFromContext<clMemFreeINTEL_fn>(context, "clMemFreeINTEL",
659-
&FuncPtr);
672+
RetVal = getExtFuncFromContext<clMemFreeName, clMemFreeINTEL_fn>(context,
673+
&FuncPtr);
660674

661675
if (FuncPtr) {
662676
RetVal = cast<pi_result>(FuncPtr(cast<cl_context>(context), ptr));
@@ -687,8 +701,9 @@ pi_result OCL(piextKernelSetArgPointer)(pi_kernel kernel, pi_uint32 arg_index,
687701
}
688702

689703
clSetKernelArgMemPointerINTEL_fn FuncPtr = nullptr;
690-
pi_result RetVal = getExtFuncFromContext<clSetKernelArgMemPointerINTEL_fn>(
691-
cast<pi_context>(CLContext), "clSetKernelArgMemPointerINTEL", &FuncPtr);
704+
pi_result RetVal = getExtFuncFromContext<clSetKernelArgMemPointerName,
705+
clSetKernelArgMemPointerINTEL_fn>(
706+
cast<pi_context>(CLContext), &FuncPtr);
692707

693708
if (FuncPtr) {
694709
// OpenCL passes pointers by value not by reference
@@ -727,8 +742,9 @@ pi_result OCL(piextUSMEnqueueMemset)(pi_queue queue, void *ptr, pi_int32 value,
727742
}
728743

729744
clEnqueueMemsetINTEL_fn FuncPtr = nullptr;
730-
pi_result RetVal = getExtFuncFromContext<clEnqueueMemsetINTEL_fn>(
731-
cast<pi_context>(CLContext), "clEnqueueMemsetINTEL", &FuncPtr);
745+
pi_result RetVal =
746+
getExtFuncFromContext<clEnqueueMemsetName, clEnqueueMemsetINTEL_fn>(
747+
cast<pi_context>(CLContext), &FuncPtr);
732748

733749
if (FuncPtr) {
734750
RetVal = cast<pi_result>(FuncPtr(cast<cl_command_queue>(queue), ptr, value,
@@ -767,8 +783,9 @@ pi_result OCL(piextUSMEnqueueMemcpy)(pi_queue queue, pi_bool blocking,
767783
}
768784

769785
clEnqueueMemcpyINTEL_fn FuncPtr = nullptr;
770-
pi_result RetVal = getExtFuncFromContext<clEnqueueMemcpyINTEL_fn>(
771-
cast<pi_context>(CLContext), "clEnqueueMemcpyINTEL", &FuncPtr);
786+
pi_result RetVal =
787+
getExtFuncFromContext<clEnqueueMemcpyName, clEnqueueMemcpyINTEL_fn>(
788+
cast<pi_context>(CLContext), &FuncPtr);
772789

773790
if (FuncPtr) {
774791
RetVal = cast<pi_result>(
@@ -893,8 +910,9 @@ pi_result OCL(piextUSMGetMemAllocInfo)(pi_context context, const void *ptr,
893910
size_t *param_value_size_ret) {
894911

895912
clGetMemAllocInfoINTEL_fn FuncPtr = nullptr;
896-
pi_result RetVal = getExtFuncFromContext<clGetMemAllocInfoINTEL_fn>(
897-
context, "clGetMemAllocInfoINTEL", &FuncPtr);
913+
pi_result RetVal =
914+
getExtFuncFromContext<clGetMemAllocInfoName, clGetMemAllocInfoINTEL_fn>(
915+
context, &FuncPtr);
898916

899917
if (FuncPtr) {
900918
RetVal = cast<pi_result>(FuncPtr(cast<cl_context>(context), ptr, param_name,

0 commit comments

Comments
 (0)