Skip to content

Revert "[SYCL] Clear extensions functions cache upon context release (#5282)" #5433

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 0 additions & 20 deletions sycl/plugins/opencl/ext_functions.inc

This file was deleted.

114 changes: 14 additions & 100 deletions sycl/plugins/opencl/pi_opencl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@
#include <iostream>
#include <limits>
#include <map>
#include <memory>
#include <mutex>
#include <sstream>
#include <string>
#include <vector>
Expand Down Expand Up @@ -73,93 +71,19 @@ CONSTFIX char clGetDeviceFunctionPointerName[] =

#undef CONSTFIX

typedef CL_API_ENTRY cl_int(CL_API_CALL *clGetDeviceFunctionPointer_fn)(
cl_device_id device, cl_program program, const char *FuncName,
cl_ulong *ret_ptr);

typedef CL_API_ENTRY cl_int(CL_API_CALL *clSetProgramSpecializationConstant_fn)(
cl_program program, cl_uint spec_id, size_t spec_size,
const void *spec_value);

struct ExtFuncsPerContextT;

namespace detail {
template <const char *FuncName, typename FuncT>
std::pair<FuncT &, bool &> get(ExtFuncsPerContextT &);
} // namespace detail

struct ExtFuncsPerContextT {
#define _EXT_FUNCTION_INTEL(t_pfx) \
t_pfx##INTEL_fn t_pfx##Func = nullptr; \
bool t_pfx##Initialized = false;

#define _EXT_FUNCTION(t_pfx) \
t_pfx##_fn t_pfx##Func = nullptr; \
bool t_pfx##Initialized = false;

#include "ext_functions.inc"

#undef _EXT_FUNCTION
#undef _EXT_FUNCTION_INTEL

std::mutex Mtx;

template <const char *FuncName, typename FuncT>
std::pair<FuncT &, bool &> get() {
return detail::get<FuncName, FuncT>(*this);
}
};

namespace detail {

#define _EXT_FUNCTION_COMMON(t_pfx, t_pfx_suff) \
template <> \
std::pair<t_pfx_suff##_fn &, bool &> get<t_pfx##Name, t_pfx_suff##_fn>( \
ExtFuncsPerContextT & Funcs) { \
using FPtrT = t_pfx_suff##_fn; \
std::pair<FPtrT &, bool &> Ret{Funcs.t_pfx##Func, \
Funcs.t_pfx##Initialized}; \
return Ret; \
}
#define _EXT_FUNCTION_INTEL(t_pfx) _EXT_FUNCTION_COMMON(t_pfx, t_pfx##INTEL)
#define _EXT_FUNCTION(t_pfx) _EXT_FUNCTION_COMMON(t_pfx, t_pfx)

#include "ext_functions.inc"

#undef _EXT_FUNCTION
#undef _EXT_FUNCTION_INTEL
#undef _EXT_FUNCTION_COMMON
} // namespace detail

struct ExtFuncsCachesT {
std::map<pi_context, ExtFuncsPerContextT> Caches;
std::mutex Mtx;
};

ExtFuncsCachesT *ExtFuncsCaches = nullptr;

// USM helper function to get an extension function pointer
template <const char *FuncName, typename T>
static pi_result getExtFuncFromContext(pi_context context, T *fptr) {
// TODO
// Potentially redo caching as PI interface changes.
ExtFuncsPerContextT *PerContext = nullptr;
{
assert(ExtFuncsCaches);
std::lock_guard<std::mutex> Lock{ExtFuncsCaches->Mtx};

PerContext = &ExtFuncsCaches->Caches[context];
}

std::lock_guard<std::mutex> Lock{PerContext->Mtx};
std::pair<T &, bool &> FuncInitialized = PerContext->get<FuncName, T>();
thread_local static std::map<pi_context, T> FuncPtrs;

// if cached, return cached FuncPtr
if (FuncInitialized.second) {
if (auto F = FuncPtrs[context]) {
// if cached that extension is not available return nullptr and
// PI_INVALID_VALUE
*fptr = FuncInitialized.first;
return *fptr ? PI_SUCCESS : PI_INVALID_VALUE;
*fptr = F;
return F ? PI_SUCCESS : PI_INVALID_VALUE;
}

cl_uint deviceCount;
Expand Down Expand Up @@ -191,17 +115,14 @@ static pi_result getExtFuncFromContext(pi_context context, T *fptr) {
T FuncPtr =
(T)clGetExtensionFunctionAddressForPlatform(curPlatform, FuncName);

// We're about to store the cached value. Mark this cache entry initialized.
FuncInitialized.second = true;

if (!FuncPtr) {
// Cache that the extension is not available
FuncInitialized.first = nullptr;
FuncPtrs[context] = nullptr;
return PI_INVALID_VALUE;
}

FuncInitialized.first = FuncPtr;
*fptr = FuncPtr;
FuncPtrs[context] = FuncPtr;

return cast<pi_result>(ret_err);
}
Expand Down Expand Up @@ -641,6 +562,9 @@ static bool is_in_separated_string(const std::string &str, char delimiter,
return false;
}

typedef CL_API_ENTRY cl_int(CL_API_CALL *clGetDeviceFunctionPointer_fn)(
cl_device_id device, cl_program program, const char *FuncName,
cl_ulong *ret_ptr);
pi_result piextGetDeviceFunctionPointer(pi_device device, pi_program program,
const char *func_name,
pi_uint64 *function_pointer_ret) {
Expand Down Expand Up @@ -1381,6 +1305,10 @@ pi_result piKernelSetExecInfo(pi_kernel kernel, pi_kernel_exec_info param_name,
}
}

typedef CL_API_ENTRY cl_int(CL_API_CALL *clSetProgramSpecializationConstant_fn)(
cl_program program, cl_uint spec_id, size_t spec_size,
const void *spec_value);

pi_result piextProgramSetSpecializationConstant(pi_program prog,
pi_uint32 spec_id,
size_t spec_size,
Expand Down Expand Up @@ -1456,21 +1384,9 @@ pi_result piextKernelGetNativeHandle(pi_kernel kernel,
// pi_level_zero.cpp for reference) Currently this is just a NOOP.
pi_result piTearDown(void *PluginParameter) {
(void)PluginParameter;
delete ExtFuncsCaches;
ExtFuncsCaches = nullptr;
return PI_SUCCESS;
}

pi_result piContextRelease(pi_context Context) {
{
std::lock_guard<std::mutex> Lock{ExtFuncsCaches->Mtx};

ExtFuncsCaches->Caches.erase(Context);
}

return cast<pi_result>(clReleaseContext(cast<cl_context>(Context)));
}

pi_result piPluginInit(pi_plugin *PluginInit) {
int CompareVersions = strcmp(PluginInit->PiVersion, SupportedVersion);
if (CompareVersions < 0) {
Expand All @@ -1482,8 +1398,6 @@ pi_result piPluginInit(pi_plugin *PluginInit) {
// PI interface supports higher version or the same version.
strncpy(PluginInit->PluginVersion, SupportedVersion, 4);

ExtFuncsCaches = new ExtFuncsCachesT;

#define _PI_CL(pi_api, ocl_api) \
(PluginInit->PiFunctionTable).pi_api = (decltype(&::pi_api))(&ocl_api);

Expand All @@ -1507,7 +1421,7 @@ pi_result piPluginInit(pi_plugin *PluginInit) {
_PI_CL(piContextCreate, piContextCreate)
_PI_CL(piContextGetInfo, clGetContextInfo)
_PI_CL(piContextRetain, clRetainContext)
_PI_CL(piContextRelease, piContextRelease)
_PI_CL(piContextRelease, clReleaseContext)
_PI_CL(piextContextGetNativeHandle, piextContextGetNativeHandle)
_PI_CL(piextContextCreateWithNativeHandle, piextContextCreateWithNativeHandle)
// Queue
Expand Down
1 change: 0 additions & 1 deletion sycl/test/abi/pi_opencl_symbol_check.dump
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
# UNSUPPORTED: libcxx

piContextCreate
piContextRelease
piDeviceGetInfo
piDevicesGet
piEnqueueMemBufferMap
Expand Down