Skip to content

Commit 20937f4

Browse files
committed
Cache Extension Function Pointers per context using thread local storage
Signed-off-by: James Brodman <[email protected]>
1 parent df86491 commit 20937f4

File tree

1 file changed

+71
-59
lines changed

1 file changed

+71
-59
lines changed

sycl/plugins/opencl/pi_opencl.cpp

Lines changed: 71 additions & 59 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010

1111
#include <cassert>
1212
#include <cstring>
13+
#include <map>
1314
#include <string>
1415
#include <vector>
1516

@@ -32,6 +33,14 @@ template <class To, class From> To cast(From value) {
3233
// USM helper function to get an extension function pointer
3334
template <typename T>
3435
pi_result getExtFuncFromContext(pi_context context, const char *func, T *fptr) {
36+
thread_local static std::map<pi_context, T> FuncPtrs;
37+
38+
// if cached, return cached FuncPtr
39+
if (auto F = FuncPtrs[context]) {
40+
*fptr = F;
41+
return PI_SUCCESS;
42+
}
43+
3544
size_t deviceCount;
3645
cl_int ret_err = clGetContextInfo(
3746
cast<cl_context>(context), CL_CONTEXT_DEVICES, 0, nullptr, &deviceCount);
@@ -64,6 +73,8 @@ pi_result getExtFuncFromContext(pi_context context, const char *func, T *fptr) {
6473
}
6574

6675
*fptr = FuncPtr;
76+
FuncPtrs[context] = FuncPtr;
77+
6778
return cast<pi_result>(ret_err);
6879
}
6980

@@ -511,16 +522,14 @@ pi_result OCL(piextUSMHostAlloc)(void **result_ptr, pi_context context,
511522
pi_result RetVal = PI_INVALID_OPERATION;
512523

513524
// First we need to look up the function pointer
514-
// It would be good if we could store it for future reuse
515-
// Are statics ok?
516-
clHostMemAllocINTEL_fn FuncPtr;
525+
clHostMemAllocINTEL_fn FuncPtr = nullptr;
517526
RetVal = getExtFuncFromContext<clHostMemAllocINTEL_fn>(
518527
context, "clHostMemAllocINTEL", &FuncPtr);
519528

520-
if (RetVal == PI_SUCCESS) {
529+
if (FuncPtr) {
521530
Ptr = FuncPtr(cast<cl_context>(context),
522-
cast<cl_mem_properties_intel *>(properties), size, alignment,
523-
cast<cl_int *>(&RetVal));
531+
cast<cl_mem_properties_intel *>(properties), size, alignment,
532+
cast<cl_int *>(&RetVal));
524533
}
525534

526535
*result_ptr = Ptr;
@@ -545,13 +554,11 @@ pi_result OCL(piextUSMDeviceAlloc)(void **result_ptr, pi_context context,
545554
pi_result RetVal = PI_INVALID_OPERATION;
546555

547556
// First we need to look up the function pointer
548-
// It would be good if we could store it for future reuse
549-
// Are statics ok?
550-
clDeviceMemAllocINTEL_fn FuncPtr;
557+
clDeviceMemAllocINTEL_fn FuncPtr = nullptr;
551558
RetVal = getExtFuncFromContext<clDeviceMemAllocINTEL_fn>(
552559
context, "clDeviceMemAllocINTEL", &FuncPtr);
553560

554-
if (RetVal == PI_SUCCESS) {
561+
if (FuncPtr) {
555562
Ptr = FuncPtr(cast<cl_context>(context), cast<cl_device_id>(device),
556563
cast<cl_mem_properties_intel *>(properties), size, alignment,
557564
cast<cl_int *>(&RetVal));
@@ -579,13 +586,11 @@ pi_result OCL(piextUSMSharedAlloc)(void **result_ptr, pi_context context,
579586
pi_result RetVal = PI_INVALID_OPERATION;
580587

581588
// First we need to look up the function pointer
582-
// It would be good if we could store it for future reuse
583-
// Are statics ok?
584-
clSharedMemAllocINTEL_fn FuncPtr;
589+
clSharedMemAllocINTEL_fn FuncPtr = nullptr;
585590
RetVal = getExtFuncFromContext<clSharedMemAllocINTEL_fn>(
586591
context, "clSharedMemAllocINTEL", &FuncPtr);
587592

588-
if (RetVal == PI_SUCCESS) {
593+
if (FuncPtr) {
589594
Ptr = FuncPtr(cast<cl_context>(context), cast<cl_device_id>(device),
590595
cast<cl_mem_properties_intel *>(properties), size, alignment,
591596
cast<cl_int *>(&RetVal));
@@ -602,14 +607,16 @@ pi_result OCL(piextUSMSharedAlloc)(void **result_ptr, pi_context context,
602607
/// @param ptr is the memory to be freed
603608
pi_result OCL(piextUSMFree)(pi_context context, void *ptr) {
604609

605-
clMemFreeINTEL_fn FuncPtr;
606-
pi_result RetVal = getExtFuncFromContext<clMemFreeINTEL_fn>(
607-
context, "clMemFreeINTEL", &FuncPtr);
610+
clMemFreeINTEL_fn FuncPtr = nullptr;
611+
pi_result RetVal = PI_INVALID_OPERATION;
612+
RetVal = getExtFuncFromContext<clMemFreeINTEL_fn>(context, "clMemFreeINTEL",
613+
&FuncPtr);
608614

609-
if (RetVal != PI_SUCCESS)
610-
return RetVal;
615+
if (FuncPtr) {
616+
RetVal = cast<pi_result>(FuncPtr(cast<cl_context>(context), ptr));
617+
}
611618

612-
return cast<pi_result>(FuncPtr(cast<cl_context>(context), ptr));
619+
return RetVal;
613620
}
614621

615622
/// Sets up pointer arguments for CL kernels. An extra indirection
@@ -634,18 +641,19 @@ pi_result OCL(piextUSMKernelSetArgMemPointer)(pi_kernel kernel,
634641
return cast<pi_result>(CLErr);
635642
}
636643

637-
clSetKernelArgMemPointerINTEL_fn FuncPtr;
644+
clSetKernelArgMemPointerINTEL_fn FuncPtr = nullptr;
638645
pi_result RetVal = getExtFuncFromContext<clSetKernelArgMemPointerINTEL_fn>(
639646
cast<pi_context>(CLContext), "clSetKernelArgMemPointerINTEL", &FuncPtr);
640647

641-
if (RetVal != PI_SUCCESS)
642-
return RetVal;
648+
if (FuncPtr) {
649+
// OpenCL passes pointers by value not by reference
650+
// This means we need to deref the arg to get the pointer value
651+
auto PtrToPtr = reinterpret_cast<const intptr_t *>(arg_value);
652+
auto DerefPtr = reinterpret_cast<void *>(*PtrToPtr);
653+
RetVal = cast<pi_result>(FuncPtr(cast<cl_kernel>(kernel), arg_index, DerefPtr));
654+
}
643655

644-
// OpenCL passes pointers by value not by reference
645-
// This means we need to deref the arg to get the pointer value
646-
auto PtrToPtr = reinterpret_cast<const intptr_t *>(arg_value);
647-
auto DerefPtr = reinterpret_cast<void *>(*PtrToPtr);
648-
return cast<pi_result>(FuncPtr(cast<cl_kernel>(kernel), arg_index, DerefPtr));
656+
return RetVal;
649657
}
650658

651659
/// Enables indirect access of pointers in kernels.
@@ -657,35 +665,35 @@ pi_result OCL(piextUSMKernelSetIndirectAccess)(pi_kernel kernel) {
657665
// We test that each alloc type is supported before we actually try to
658666
// set KernelExecInfo.
659667
cl_bool TrueVal = CL_TRUE;
660-
pi_result RetVal;
661-
clHostMemAllocINTEL_fn HFunc;
662-
clSharedMemAllocINTEL_fn SFunc;
663-
clDeviceMemAllocINTEL_fn DFunc;
668+
clHostMemAllocINTEL_fn HFunc = nullptr;
669+
clSharedMemAllocINTEL_fn SFunc = nullptr;
670+
clDeviceMemAllocINTEL_fn DFunc = nullptr;
664671
cl_context CLContext;
665672
cl_int CLErr = clGetKernelInfo(cast<cl_kernel>(kernel), CL_KERNEL_CONTEXT,
666673
sizeof(cl_context), &CLContext, nullptr);
667674
if (CLErr != CL_SUCCESS) {
668675
return cast<pi_result>(CLErr);
669676
}
670677

671-
// This would be really good to cache
672-
RetVal = getExtFuncFromContext<clHostMemAllocINTEL_fn>(
673-
cast<pi_context>(CLContext), "clHostMemAllocINTEL", &HFunc);
674-
if (RetVal == PI_SUCCESS) {
678+
getExtFuncFromContext<clHostMemAllocINTEL_fn>(cast<pi_context>(CLContext),
679+
"clHostMemAllocINTEL", &HFunc);
680+
if (HFunc) {
675681
clSetKernelExecInfo(cast<cl_kernel>(kernel),
676682
CL_KERNEL_EXEC_INFO_INDIRECT_HOST_ACCESS_INTEL,
677683
sizeof(cl_bool), &TrueVal);
678684
}
679-
RetVal = getExtFuncFromContext<clDeviceMemAllocINTEL_fn>(
685+
686+
getExtFuncFromContext<clDeviceMemAllocINTEL_fn>(
680687
cast<pi_context>(CLContext), "clDeviceMemAllocINTEL", &DFunc);
681-
if (RetVal == PI_SUCCESS) {
688+
if (DFunc) {
682689
clSetKernelExecInfo(cast<cl_kernel>(kernel),
683690
CL_KERNEL_EXEC_INFO_INDIRECT_DEVICE_ACCESS_INTEL,
684691
sizeof(cl_bool), &TrueVal);
685692
}
686-
RetVal = getExtFuncFromContext<clSharedMemAllocINTEL_fn>(
693+
694+
getExtFuncFromContext<clSharedMemAllocINTEL_fn>(
687695
cast<pi_context>(CLContext), "clSharedMemAllocINTEL", &SFunc);
688-
if (RetVal == PI_SUCCESS) {
696+
if (SFunc) {
689697
clSetKernelExecInfo(cast<cl_kernel>(kernel),
690698
CL_KERNEL_EXEC_INFO_INDIRECT_SHARED_ACCESS_INTEL,
691699
sizeof(cl_bool), &TrueVal);
@@ -718,16 +726,18 @@ pi_result OCL(piextUSMEnqueueMemset)(pi_queue queue, void *ptr, pi_int32 value,
718726
return cast<pi_result>(CLErr);
719727
}
720728

721-
clEnqueueMemsetINTEL_fn FuncPtr;
729+
clEnqueueMemsetINTEL_fn FuncPtr = nullptr;
722730
pi_result RetVal = getExtFuncFromContext<clEnqueueMemsetINTEL_fn>(
723731
cast<pi_context>(CLContext), "clEnqueueMemsetINTEL", &FuncPtr);
724732

725-
if (RetVal != PI_SUCCESS)
726-
return RetVal;
733+
if (FuncPtr) {
734+
RetVal = cast<pi_result>(FuncPtr(cast<cl_command_queue>(queue), ptr, value,
735+
count, num_events_in_waitlist,
736+
cast<const cl_event *>(events_waitlist),
737+
cast<cl_event *>(event)));
738+
}
727739

728-
return cast<pi_result>(FuncPtr(
729-
cast<cl_command_queue>(queue), ptr, value, count, num_events_in_waitlist,
730-
cast<const cl_event *>(events_waitlist), cast<cl_event *>(event)));
740+
return RetVal;
731741
}
732742

733743
/// USM Memcpy API
@@ -756,17 +766,18 @@ pi_result OCL(piextUSMEnqueueMemcpy)(pi_queue queue, pi_bool blocking,
756766
return cast<pi_result>(CLErr);
757767
}
758768

759-
clEnqueueMemcpyINTEL_fn FuncPtr;
769+
clEnqueueMemcpyINTEL_fn FuncPtr = nullptr;
760770
pi_result RetVal = getExtFuncFromContext<clEnqueueMemcpyINTEL_fn>(
761771
cast<pi_context>(CLContext), "clEnqueueMemcpyINTEL", &FuncPtr);
762772

763-
if (RetVal != PI_SUCCESS)
764-
return RetVal;
773+
if (FuncPtr) {
774+
RetVal = cast<pi_result>(
775+
FuncPtr(cast<cl_command_queue>(queue), blocking, dst_ptr, src_ptr, size,
776+
num_events_in_waitlist, cast<const cl_event *>(events_waitlist),
777+
cast<cl_event *>(event)));
778+
}
765779

766-
return cast<pi_result>(FuncPtr(cast<cl_command_queue>(queue), blocking,
767-
dst_ptr, src_ptr, size, num_events_in_waitlist,
768-
cast<const cl_event *>(events_waitlist),
769-
cast<cl_event *>(event)));
780+
return RetVal;
770781
}
771782

772783
/// Hint to migrate memory to the device
@@ -879,16 +890,17 @@ pi_result OCL(piextUSMGetMemAllocInfo)(pi_context context, const void *ptr,
879890
void *param_value,
880891
size_t *param_value_size_ret) {
881892

882-
clGetMemAllocInfoINTEL_fn FuncPtr;
893+
clGetMemAllocInfoINTEL_fn FuncPtr = nullptr;
883894
pi_result RetVal = getExtFuncFromContext<clGetMemAllocInfoINTEL_fn>(
884895
context, "clGetMemAllocInfoINTEL", &FuncPtr);
885896

886-
if (RetVal != PI_SUCCESS)
887-
return RetVal;
897+
if (FuncPtr) {
898+
RetVal = cast<pi_result>(FuncPtr(cast<cl_context>(context), ptr, param_name,
899+
param_value_size, param_value,
900+
param_value_size_ret));
901+
}
888902

889-
return cast<pi_result>(FuncPtr(cast<cl_context>(context), ptr, param_name,
890-
param_value_size, param_value,
891-
param_value_size_ret));
903+
return RetVal;
892904
}
893905

894906
pi_result piPluginInit(pi_plugin *PluginInit) {

0 commit comments

Comments
 (0)