Skip to content

[SYCL][NFC] Minor stylistic changes related to getExtFuncFromContext #9281

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 3, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 52 additions & 51 deletions sycl/plugins/opencl/pi_opencl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,24 +187,26 @@ static cl_int checkDeviceExtensions(cl_device_id dev,
return ret_err;
}

typedef CL_API_ENTRY cl_int(CL_API_CALL *clGetDeviceFunctionPointer_fn)(
cl_device_id device, cl_program program, const char *FuncName,
cl_ulong *ret_ptr);
using clGetDeviceFunctionPointer_fn = CL_API_ENTRY
cl_int(CL_API_CALL *)(cl_device_id device, cl_program program,
const char *FuncName, cl_ulong *ret_ptr);

typedef CL_API_ENTRY cl_int(CL_API_CALL *clEnqueueWriteGlobalVariable_fn)(
cl_command_queue, cl_program, const char *, cl_bool, size_t, size_t,
const void *, cl_uint, const cl_event *, cl_event *);
using clEnqueueWriteGlobalVariable_fn = CL_API_ENTRY
cl_int(CL_API_CALL *)(cl_command_queue, cl_program, const char *, cl_bool,
size_t, size_t, const void *, cl_uint, const cl_event *,
cl_event *);

typedef CL_API_ENTRY cl_int(CL_API_CALL *clEnqueueReadGlobalVariable_fn)(
cl_command_queue, cl_program, const char *, cl_bool, size_t, size_t, void *,
cl_uint, const cl_event *, cl_event *);
using clEnqueueReadGlobalVariable_fn = CL_API_ENTRY
cl_int(CL_API_CALL *)(cl_command_queue, cl_program, const char *, cl_bool,
size_t, size_t, void *, cl_uint, const cl_event *,
cl_event *);

typedef CL_API_ENTRY cl_int(CL_API_CALL *clSetProgramSpecializationConstant_fn)(
cl_program program, cl_uint spec_id, size_t spec_size,
const void *spec_value);
using clSetProgramSpecializationConstant_fn = CL_API_ENTRY
cl_int(CL_API_CALL *)(cl_program program, cl_uint spec_id, size_t spec_size,
const void *spec_value);

template <typename T> struct FuncPtrCache {
std::map<pi_context, T> Map;
std::map<cl_context, T> Map;
std::mutex Mutex;
};

Expand Down Expand Up @@ -241,14 +243,14 @@ static ExtFuncPtrCacheT *ExtFuncPtrCache = new ExtFuncPtrCacheT();

// USM helper function to get an extension function pointer
template <typename T>
static pi_result getExtFuncFromContext(pi_context context,
static pi_result getExtFuncFromContext(cl_context context,
FuncPtrCache<T> &FPtrCache,
const char *FuncName, T *fptr) {
// TODO
// Potentially redo caching as PI interface changes.
// if cached, return cached FuncPtr
std::lock_guard<std::mutex> CacheLock{FPtrCache.Mutex};
std::map<pi_context, T> &FPtrMap = FPtrCache.Map;
std::map<cl_context, T> &FPtrMap = FPtrCache.Map;
auto It = FPtrMap.find(context);
if (It != FPtrMap.end()) {
auto F = It->second;
Expand All @@ -259,16 +261,15 @@ static pi_result getExtFuncFromContext(pi_context context,
}

cl_uint deviceCount;
cl_int ret_err =
clGetContextInfo(cast<cl_context>(context), CL_CONTEXT_NUM_DEVICES,
sizeof(cl_uint), &deviceCount, nullptr);
cl_int ret_err = clGetContextInfo(context, CL_CONTEXT_NUM_DEVICES,
sizeof(cl_uint), &deviceCount, nullptr);

if (ret_err != CL_SUCCESS || deviceCount < 1) {
return PI_ERROR_INVALID_CONTEXT;
}

std::vector<cl_device_id> devicesInCtx(deviceCount);
ret_err = clGetContextInfo(cast<cl_context>(context), CL_CONTEXT_DEVICES,
ret_err = clGetContextInfo(context, CL_CONTEXT_DEVICES,
deviceCount * sizeof(cl_device_id),
devicesInCtx.data(), nullptr);

Expand Down Expand Up @@ -318,16 +319,16 @@ static pi_result USMSetIndirectAccess(pi_kernel kernel) {
}

getExtFuncFromContext<clHostMemAllocINTEL_fn>(
cast<pi_context>(CLContext), ExtFuncPtrCache->clHostMemAllocINTELCache,
clHostMemAllocName, &HFunc);
CLContext, ExtFuncPtrCache->clHostMemAllocINTELCache, clHostMemAllocName,
&HFunc);
if (HFunc) {
clSetKernelExecInfo(cast<cl_kernel>(kernel),
CL_KERNEL_EXEC_INFO_INDIRECT_HOST_ACCESS_INTEL,
sizeof(cl_bool), &TrueVal);
}

getExtFuncFromContext<clDeviceMemAllocINTEL_fn>(
cast<pi_context>(CLContext), ExtFuncPtrCache->clDeviceMemAllocINTELCache,
CLContext, ExtFuncPtrCache->clDeviceMemAllocINTELCache,
clDeviceMemAllocName, &DFunc);
if (DFunc) {
clSetKernelExecInfo(cast<cl_kernel>(kernel),
Expand All @@ -336,7 +337,7 @@ static pi_result USMSetIndirectAccess(pi_kernel kernel) {
}

getExtFuncFromContext<clSharedMemAllocINTEL_fn>(
cast<pi_context>(CLContext), ExtFuncPtrCache->clSharedMemAllocINTELCache,
CLContext, ExtFuncPtrCache->clSharedMemAllocINTELCache,
clSharedMemAllocName, &SFunc);
if (SFunc) {
clSetKernelExecInfo(cast<cl_kernel>(kernel),
Expand Down Expand Up @@ -1172,8 +1173,7 @@ pi_result piextGetDeviceFunctionPointer(pi_device device, pi_program program,

clGetDeviceFunctionPointer_fn FuncT = nullptr;
ret_err = getExtFuncFromContext<clGetDeviceFunctionPointer_fn>(
cast<pi_context>(CLContext),
ExtFuncPtrCache->clGetDeviceFunctionPointerCache,
CLContext, ExtFuncPtrCache->clGetDeviceFunctionPointerCache,
clGetDeviceFunctionPointerName, &FuncT);

pi_result pi_ret_err = PI_SUCCESS;
Expand Down Expand Up @@ -1290,14 +1290,15 @@ pi_result piMemBufferCreate(pi_context context, pi_mem_flags flags, size_t size,
// TODO: need to check if all properties are supported by OpenCL RT and
// ignore unsupported
clCreateBufferWithPropertiesINTEL_fn FuncPtr = nullptr;
cl_context CLContext = cast<cl_context>(context);
// First we need to look up the function pointer
ret_err = getExtFuncFromContext<clCreateBufferWithPropertiesINTEL_fn>(
context, ExtFuncPtrCache->clCreateBufferWithPropertiesINTELCache,
CLContext, ExtFuncPtrCache->clCreateBufferWithPropertiesINTELCache,
clCreateBufferWithPropertiesName, &FuncPtr);
if (FuncPtr) {
*ret_mem = cast<pi_mem>(FuncPtr(cast<cl_context>(context), properties,
cast<cl_mem_flags>(flags), size, host_ptr,
cast<cl_int *>(&ret_err)));
*ret_mem =
cast<pi_mem>(FuncPtr(CLContext, properties, cast<cl_mem_flags>(flags),
size, host_ptr, cast<cl_int *>(&ret_err)));
return ret_err;
}
}
Expand Down Expand Up @@ -1572,14 +1573,14 @@ pi_result piextUSMHostAlloc(void **result_ptr, pi_context context,

// First we need to look up the function pointer
clHostMemAllocINTEL_fn FuncPtr = nullptr;
cl_context CLContext = cast<cl_context>(context);
RetVal = getExtFuncFromContext<clHostMemAllocINTEL_fn>(
context, ExtFuncPtrCache->clHostMemAllocINTELCache, clHostMemAllocName,
CLContext, ExtFuncPtrCache->clHostMemAllocINTELCache, clHostMemAllocName,
&FuncPtr);

if (FuncPtr) {
Ptr = FuncPtr(cast<cl_context>(context),
cast<cl_mem_properties_intel *>(properties), size, alignment,
cast<cl_int *>(&RetVal));
Ptr = FuncPtr(CLContext, cast<cl_mem_properties_intel *>(properties), size,
alignment, cast<cl_int *>(&RetVal));
}

*result_ptr = Ptr;
Expand Down Expand Up @@ -1610,12 +1611,13 @@ pi_result piextUSMDeviceAlloc(void **result_ptr, pi_context context,

// First we need to look up the function pointer
clDeviceMemAllocINTEL_fn FuncPtr = nullptr;
cl_context CLContext = cast<cl_context>(context);
RetVal = getExtFuncFromContext<clDeviceMemAllocINTEL_fn>(
context, ExtFuncPtrCache->clDeviceMemAllocINTELCache,
CLContext, ExtFuncPtrCache->clDeviceMemAllocINTELCache,
clDeviceMemAllocName, &FuncPtr);

if (FuncPtr) {
Ptr = FuncPtr(cast<cl_context>(context), cast<cl_device_id>(device),
Ptr = FuncPtr(CLContext, cast<cl_device_id>(device),
cast<cl_mem_properties_intel *>(properties), size, alignment,
cast<cl_int *>(&RetVal));
}
Expand Down Expand Up @@ -1648,8 +1650,9 @@ pi_result piextUSMSharedAlloc(void **result_ptr, pi_context context,

// First we need to look up the function pointer
clSharedMemAllocINTEL_fn FuncPtr = nullptr;
cl_context CLContext = cast<cl_context>(context);
RetVal = getExtFuncFromContext<clSharedMemAllocINTEL_fn>(
context, ExtFuncPtrCache->clSharedMemAllocINTELCache,
CLContext, ExtFuncPtrCache->clSharedMemAllocINTELCache,
clSharedMemAllocName, &FuncPtr);

if (FuncPtr) {
Expand All @@ -1675,13 +1678,14 @@ pi_result piextUSMFree(pi_context context, void *ptr) {
// might be still running.
clMemBlockingFreeINTEL_fn FuncPtr = nullptr;

cl_context CLContext = cast<cl_context>(context);
pi_result RetVal = PI_ERROR_INVALID_OPERATION;
RetVal = getExtFuncFromContext<clMemBlockingFreeINTEL_fn>(
context, ExtFuncPtrCache->clMemBlockingFreeINTELCache,
CLContext, ExtFuncPtrCache->clMemBlockingFreeINTELCache,
clMemBlockingFreeName, &FuncPtr);

if (FuncPtr) {
RetVal = cast<pi_result>(FuncPtr(cast<cl_context>(context), ptr));
RetVal = cast<pi_result>(FuncPtr(CLContext, ptr));
}

return RetVal;
Expand Down Expand Up @@ -1710,8 +1714,7 @@ pi_result piextKernelSetArgPointer(pi_kernel kernel, pi_uint32 arg_index,

clSetKernelArgMemPointerINTEL_fn FuncPtr = nullptr;
pi_result RetVal = getExtFuncFromContext<clSetKernelArgMemPointerINTEL_fn>(
cast<pi_context>(CLContext),
ExtFuncPtrCache->clSetKernelArgMemPointerINTELCache,
CLContext, ExtFuncPtrCache->clSetKernelArgMemPointerINTELCache,
clSetKernelArgMemPointerName, &FuncPtr);

if (FuncPtr) {
Expand Down Expand Up @@ -1752,7 +1755,7 @@ pi_result piextUSMEnqueueMemset(pi_queue queue, void *ptr, pi_int32 value,

clEnqueueMemsetINTEL_fn FuncPtr = nullptr;
pi_result RetVal = getExtFuncFromContext<clEnqueueMemsetINTEL_fn>(
cast<pi_context>(CLContext), ExtFuncPtrCache->clEnqueueMemsetINTELCache,
CLContext, ExtFuncPtrCache->clEnqueueMemsetINTELCache,
clEnqueueMemsetName, &FuncPtr);

if (FuncPtr) {
Expand Down Expand Up @@ -1792,7 +1795,7 @@ pi_result piextUSMEnqueueMemcpy(pi_queue queue, pi_bool blocking, void *dst_ptr,

clEnqueueMemcpyINTEL_fn FuncPtr = nullptr;
pi_result RetVal = getExtFuncFromContext<clEnqueueMemcpyINTEL_fn>(
cast<pi_context>(CLContext), ExtFuncPtrCache->clEnqueueMemcpyINTELCache,
CLContext, ExtFuncPtrCache->clEnqueueMemcpyINTELCache,
clEnqueueMemcpyName, &FuncPtr);

if (FuncPtr) {
Expand Down Expand Up @@ -2017,8 +2020,9 @@ pi_result piextUSMGetMemAllocInfo(pi_context context, const void *ptr,
size_t *param_value_size_ret) {

clGetMemAllocInfoINTEL_fn FuncPtr = nullptr;
cl_context CLContext = cast<cl_context>(context);
pi_result RetVal = getExtFuncFromContext<clGetMemAllocInfoINTEL_fn>(
context, ExtFuncPtrCache->clGetMemAllocInfoINTELCache,
CLContext, ExtFuncPtrCache->clGetMemAllocInfoINTELCache,
clGetMemAllocInfoName, &FuncPtr);

if (FuncPtr) {
Expand Down Expand Up @@ -2058,7 +2062,7 @@ pi_result piextEnqueueDeviceGlobalVariableWrite(

clEnqueueWriteGlobalVariable_fn F = nullptr;
Res = getExtFuncFromContext<decltype(F)>(
cast<pi_context>(Ctx), ExtFuncPtrCache->clEnqueueWriteGlobalVariableCache,
Ctx, ExtFuncPtrCache->clEnqueueWriteGlobalVariableCache,
clEnqueueWriteGlobalVariableName, &F);

if (!F || Res != CL_SUCCESS)
Expand Down Expand Up @@ -2096,7 +2100,7 @@ pi_result piextEnqueueDeviceGlobalVariableRead(

clEnqueueReadGlobalVariable_fn F = nullptr;
Res = getExtFuncFromContext<decltype(F)>(
cast<pi_context>(Ctx), ExtFuncPtrCache->clEnqueueReadGlobalVariableCache,
Ctx, ExtFuncPtrCache->clEnqueueReadGlobalVariableCache,
clEnqueueReadGlobalVariableName, &F);

if (!F || Res != CL_SUCCESS)
Expand All @@ -2123,8 +2127,7 @@ pi_result piextEnqueueReadHostPipe(pi_queue queue, pi_program program,

clEnqueueReadHostPipeINTEL_fn FuncPtr = nullptr;
pi_result RetVal = getExtFuncFromContext<clEnqueueReadHostPipeINTEL_fn>(
cast<pi_context>(CLContext),
ExtFuncPtrCache->clEnqueueReadHostPipeINTELCache,
CLContext, ExtFuncPtrCache->clEnqueueReadHostPipeINTELCache,
clEnqueueReadHostPipeName, &FuncPtr);

if (FuncPtr) {
Expand Down Expand Up @@ -2153,8 +2156,7 @@ pi_result piextEnqueueWriteHostPipe(pi_queue queue, pi_program program,

clEnqueueWriteHostPipeINTEL_fn FuncPtr = nullptr;
pi_result RetVal = getExtFuncFromContext<clEnqueueWriteHostPipeINTEL_fn>(
cast<pi_context>(CLContext),
ExtFuncPtrCache->clEnqueueWriteHostPipeINTELCache,
CLContext, ExtFuncPtrCache->clEnqueueWriteHostPipeINTELCache,
clEnqueueWriteHostPipeName, &FuncPtr);

if (FuncPtr) {
Expand Down Expand Up @@ -2205,8 +2207,7 @@ pi_result piextProgramSetSpecializationConstant(pi_program prog,

clSetProgramSpecializationConstant_fn F = nullptr;
Res = getExtFuncFromContext<decltype(F)>(
cast<pi_context>(Ctx),
ExtFuncPtrCache->clSetProgramSpecializationConstantCache,
Ctx, ExtFuncPtrCache->clSetProgramSpecializationConstantCache,
clSetProgramSpecializationConstantName, &F);

if (!F || Res != CL_SUCCESS)
Expand Down