Skip to content

Commit a0062cf

Browse files
[SYCL] Make ur::getAdapter return raw adapter pointer instead of shared_ptr (#19102)
Currently, `ur::getAdapter` returns a shared pointer to the adapter. However, that's unnecessary. This PR makes `ur::getAdapter` return a raw pointer instead and lets the global handler manage the lifetime of the adapter object.
1 parent a984bee commit a0062cf

File tree

11 files changed

+44
-52
lines changed

11 files changed

+44
-52
lines changed

sycl/source/detail/adapter.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,6 @@ class Adapter {
239239
UrFuncPtrMapT UrFuncPtrs;
240240
}; // class Adapter
241241

242-
using AdapterPtr = std::shared_ptr<Adapter>;
243-
244242
} // namespace detail
245243
} // namespace _V1
246244
} // namespace sycl

sycl/source/detail/context_impl.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -366,8 +366,7 @@ void GetCapabilitiesIntersectionSet(const std::vector<sycl::device> &Devices,
366366
// convenient to be able to reference them without extra `detail::`.
367367
inline auto get_ur_handles(sycl::detail::context_impl &Ctx) {
368368
ur_context_handle_t urCtx = Ctx.getHandleRef();
369-
const sycl::detail::Adapter *Adapter = Ctx.getAdapter().get();
370-
return std::tuple{urCtx, Adapter};
369+
return std::tuple{urCtx, Ctx.getAdapter()};
371370
}
372371
inline auto get_ur_handles(const sycl::context &syclContext) {
373372
return get_ur_handles(*sycl::detail::getSyclObjImpl(syclContext));

sycl/source/detail/global_handler.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,8 +230,8 @@ std::mutex &GlobalHandler::getFilterMutex() {
230230
return FilterMutex;
231231
}
232232

233-
std::vector<AdapterPtr> &GlobalHandler::getAdapters() {
234-
static std::vector<AdapterPtr> &adapters = getOrCreate(MAdapters);
233+
std::vector<Adapter *> &GlobalHandler::getAdapters() {
234+
static std::vector<Adapter *> &adapters = getOrCreate(MAdapters);
235235
enableOnCrashStackPrinting();
236236
return adapters;
237237
}
@@ -314,6 +314,7 @@ void GlobalHandler::unloadAdapters() {
314314
if (MAdapters.Inst) {
315315
for (const auto &Adapter : getAdapters()) {
316316
Adapter->release();
317+
delete Adapter;
317318
}
318319
}
319320

@@ -387,6 +388,10 @@ void shutdown_late() {
387388
Handler->MScheduler.Inst.reset(nullptr);
388389
Handler->MProgramManager.Inst.reset(nullptr);
389390

391+
// Cache stores handles to the adapter, so clear it before
392+
// releasing adapters.
393+
Handler->MKernelNameBasedCaches.Inst.reset(nullptr);
394+
390395
// Clear the adapters and reset the instance if it was there.
391396
Handler->unloadAdapters();
392397
if (Handler->MAdapters.Inst)

sycl/source/detail/global_handler.hpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ class ThreadPool;
3030
struct KernelNameBasedCacheT;
3131

3232
using ContextImplPtr = std::shared_ptr<context_impl>;
33-
using AdapterPtr = std::shared_ptr<Adapter>;
3433

3534
/// Wrapper class for global data structures with non-trivial destructors.
3635
///
@@ -71,7 +70,7 @@ class GlobalHandler {
7170
std::mutex &getPlatformToDefaultContextCacheMutex();
7271
std::mutex &getPlatformMapMutex();
7372
std::mutex &getFilterMutex();
74-
std::vector<AdapterPtr> &getAdapters();
73+
std::vector<Adapter *> &getAdapters();
7574
ods_target_list &getOneapiDeviceSelectorTargets(const std::string &InitValue);
7675
XPTIRegistry &getXPTIRegistry();
7776
ThreadPool &getHostTaskThreadPool();
@@ -126,7 +125,7 @@ class GlobalHandler {
126125
InstWithLock<std::mutex> MPlatformToDefaultContextCacheMutex;
127126
InstWithLock<std::mutex> MPlatformMapMutex;
128127
InstWithLock<std::mutex> MFilterMutex;
129-
InstWithLock<std::vector<AdapterPtr>> MAdapters;
128+
InstWithLock<std::vector<Adapter *>> MAdapters;
130129
InstWithLock<ods_target_list> MOneapiDeviceSelectorTargets;
131130
InstWithLock<XPTIRegistry> MXPTIRegistry;
132131
// Thread pool for host task and event callbacks execution

sycl/source/detail/kernel_program_cache.hpp

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -112,25 +112,21 @@ class KernelProgramCache {
112112
};
113113

114114
struct ProgramBuildResult : public BuildResult<ur_program_handle_t> {
115-
std::weak_ptr<Adapter> AdapterWeakPtr;
116-
ProgramBuildResult(const AdapterPtr &Adapter) : AdapterWeakPtr(Adapter) {
115+
AdapterPtr MAdapter;
116+
ProgramBuildResult(const AdapterPtr &AAdapter) : MAdapter(AAdapter) {
117117
Val = nullptr;
118118
}
119-
ProgramBuildResult(const AdapterPtr &Adapter, BuildState InitialState)
120-
: AdapterWeakPtr(Adapter) {
119+
ProgramBuildResult(const AdapterPtr &AAdapter, BuildState InitialState)
120+
: MAdapter(AAdapter) {
121121
Val = nullptr;
122122
this->State.store(InitialState);
123123
}
124124
~ProgramBuildResult() {
125125
try {
126126
if (Val) {
127-
AdapterPtr AdapterSharedPtr = AdapterWeakPtr.lock();
128-
if (AdapterSharedPtr) {
129-
ur_result_t Err =
130-
AdapterSharedPtr->call_nocheck<UrApiKind::urProgramRelease>(
131-
Val);
132-
__SYCL_CHECK_UR_CODE_NO_EXC(Err, AdapterSharedPtr->getBackend());
133-
}
127+
ur_result_t Err =
128+
MAdapter->call_nocheck<UrApiKind::urProgramRelease>(Val);
129+
__SYCL_CHECK_UR_CODE_NO_EXC(Err, MAdapter->getBackend());
134130
}
135131
} catch (std::exception &e) {
136132
__SYCL_REPORT_EXCEPTION_TO_STREAM("exception in ~ProgramBuildResult",
@@ -202,20 +198,16 @@ class KernelProgramCache {
202198
using KernelArgMaskPairT =
203199
std::pair<ur_kernel_handle_t, const KernelArgMask *>;
204200
struct KernelBuildResult : public BuildResult<KernelArgMaskPairT> {
205-
std::weak_ptr<Adapter> AdapterWeakPtr;
206-
KernelBuildResult(const AdapterPtr &Adapter) : AdapterWeakPtr(Adapter) {
201+
AdapterPtr MAdapter;
202+
KernelBuildResult(const AdapterPtr &AAdapter) : MAdapter(AAdapter) {
207203
Val.first = nullptr;
208204
}
209205
~KernelBuildResult() {
210206
try {
211207
if (Val.first) {
212-
AdapterPtr AdapterSharedPtr = AdapterWeakPtr.lock();
213-
if (AdapterSharedPtr) {
214-
ur_result_t Err =
215-
AdapterSharedPtr->call_nocheck<UrApiKind::urKernelRelease>(
216-
Val.first);
217-
__SYCL_CHECK_UR_CODE_NO_EXC(Err, AdapterSharedPtr->getBackend());
218-
}
208+
ur_result_t Err =
209+
MAdapter->call_nocheck<UrApiKind::urKernelRelease>(Val.first);
210+
__SYCL_CHECK_UR_CODE_NO_EXC(Err, MAdapter->getBackend());
219211
}
220212
} catch (std::exception &e) {
221213
__SYCL_REPORT_EXCEPTION_TO_STREAM("exception in ~KernelBuildResult", e);

sycl/source/detail/platform_impl.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ std::vector<platform> platform_impl::get_platforms() {
168168

169169
// See which platform we want to be served by which adapter.
170170
// There should be just one adapter serving each backend.
171-
std::vector<AdapterPtr> &Adapters = sycl::detail::ur::initializeUr();
171+
std::vector<AdapterPtr> &Adapters = ur::initializeUr();
172172
std::vector<std::pair<platform, AdapterPtr>> PlatformsWithAdapter;
173173

174174
// Then check backend-specific adapters
@@ -504,7 +504,7 @@ platform_impl::get_devices(info::device_type DeviceType) const {
504504
// analysis. Doing adjustment by simple copy of last device num from
505505
// previous platform.
506506
// Needs non const adapter reference.
507-
std::vector<AdapterPtr> &Adapters = sycl::detail::ur::initializeUr();
507+
std::vector<AdapterPtr> &Adapters = ur::initializeUr();
508508
auto It = std::find_if(Adapters.begin(), Adapters.end(),
509509
[&Platform = MPlatform](AdapterPtr &Adapter) {
510510
return Adapter->containsUrPlatform(Platform);

sycl/source/detail/platform_impl.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,7 @@ class platform_impl : public std::enable_shared_from_this<platform_impl> {
3939
//
4040
// Platforms can only be created under `GlobalHandler`'s ownership via
4141
// `platform_impl::getOrMakePlatformImpl` method.
42-
explicit platform_impl(ur_platform_handle_t APlatform,
43-
const std::shared_ptr<Adapter> &AAdapter)
42+
explicit platform_impl(ur_platform_handle_t APlatform, Adapter *AAdapter)
4443
: MPlatform(APlatform), MAdapter(AAdapter) {
4544
// Find out backend of the platform
4645
ur_backend_t UrBackend = UR_BACKEND_UNKNOWN;

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1186,7 +1186,7 @@ FastKernelCacheValPtr ProgramManager::getOrCreateKernel(
11861186
// nullptr for the mutex.
11871187
auto [Kernel, ArgMask] = BuildF();
11881188
return std::make_shared<FastKernelCacheVal>(
1189-
Kernel, nullptr, ArgMask, Program, *ContextImpl.getAdapter().get());
1189+
Kernel, nullptr, ArgMask, Program, *ContextImpl.getAdapter());
11901190
}
11911191

11921192
auto BuildResult = Cache.getOrBuild<errc::invalid>(GetCachedBuildF, BuildF);
@@ -1195,7 +1195,7 @@ FastKernelCacheValPtr ProgramManager::getOrCreateKernel(
11951195
const KernelArgMaskPairT &KernelArgMaskPair = BuildResult->Val;
11961196
auto ret_val = std::make_shared<FastKernelCacheVal>(
11971197
KernelArgMaskPair.first, &(BuildResult->MBuildResultMutex),
1198-
KernelArgMaskPair.second, Program, *ContextImpl.getAdapter().get());
1198+
KernelArgMaskPair.second, Program, *ContextImpl.getAdapter());
11991199
// If caching is enabled, one copy of the kernel handle will be
12001200
// stored in FastKernelCacheVal, and one is in
12011201
// KernelProgramCache::MKernelsPerProgramCache. To cover

sycl/source/detail/sycl_mem_obj_t.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ namespace detail {
3434
class context_impl;
3535
class event_impl;
3636
class Adapter;
37-
using AdapterPtr = std::shared_ptr<Adapter>;
37+
using AdapterPtr = Adapter *;
3838

3939
using EventImplPtr = std::shared_ptr<event_impl>;
4040

sycl/source/detail/ur.cpp

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ bool trace(TraceLevel Level) {
9090
return (TraceLevelMask & Level) == Level;
9191
}
9292

93-
static void initializeAdapters(std::vector<AdapterPtr> &Adapters,
93+
static void initializeAdapters(std::vector<Adapter *> &Adapters,
9494
ur_loader_config_handle_t LoaderConfig);
9595

9696
bool XPTIInitDone = false;
@@ -117,7 +117,7 @@ std::vector<AdapterPtr> &initializeUr(ur_loader_config_handle_t LoaderConfig) {
117117
return GlobalHandler::instance().getAdapters();
118118
}
119119

120-
static void initializeAdapters(std::vector<AdapterPtr> &Adapters,
120+
static void initializeAdapters(std::vector<Adapter *> &Adapters,
121121
ur_loader_config_handle_t LoaderConfig) {
122122
#define CHECK_UR_SUCCESS(Call) \
123123
{ \
@@ -238,7 +238,7 @@ static void initializeAdapters(std::vector<AdapterPtr> &Adapters,
238238
sizeof(adapterBackend), &adapterBackend,
239239
nullptr));
240240
auto syclBackend = UrToSyclBackend(adapterBackend);
241-
Adapters.emplace_back(std::make_shared<Adapter>(UrAdapter, syclBackend));
241+
Adapters.emplace_back(new Adapter(UrAdapter, syclBackend));
242242

243243
const char *env_value = std::getenv("UR_LOG_CALLBACK");
244244
if (env_value == nullptr || std::string(env_value) != "disabled") {
@@ -284,25 +284,25 @@ static void initializeAdapters(std::vector<AdapterPtr> &Adapters,
284284
}
285285

286286
// Get the adapter serving given backend.
287-
template <backend BE> const AdapterPtr &getAdapter() {
288-
static AdapterPtr *Adapter = nullptr;
289-
if (Adapter)
290-
return *Adapter;
287+
template <backend BE> AdapterPtr &getAdapter() {
288+
static AdapterPtr adapterPtr = nullptr;
289+
if (adapterPtr)
290+
return adapterPtr;
291291

292-
std::vector<AdapterPtr> &Adapters = ur::initializeUr();
292+
std::vector<AdapterPtr> Adapters = ur::initializeUr();
293293
for (auto &P : Adapters)
294294
if (P->hasBackend(BE)) {
295-
Adapter = &P;
296-
return *Adapter;
295+
adapterPtr = P;
296+
return adapterPtr;
297297
}
298298

299299
throw exception(errc::runtime, "ur::getAdapter couldn't find adapter");
300300
}
301301

302-
template const AdapterPtr &getAdapter<backend::opencl>();
303-
template const AdapterPtr &getAdapter<backend::ext_oneapi_level_zero>();
304-
template const AdapterPtr &getAdapter<backend::ext_oneapi_cuda>();
305-
template const AdapterPtr &getAdapter<backend::ext_oneapi_hip>();
302+
template AdapterPtr &getAdapter<backend::opencl>();
303+
template AdapterPtr &getAdapter<backend::ext_oneapi_level_zero>();
304+
template AdapterPtr &getAdapter<backend::ext_oneapi_cuda>();
305+
template AdapterPtr &getAdapter<backend::ext_oneapi_hip>();
306306

307307
// Reads an integer value from ELF data.
308308
template <typename ResT>

sycl/source/detail/ur.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ inline namespace _V1 {
2525
enum class backend : char;
2626
namespace detail {
2727
class Adapter;
28-
using AdapterPtr = std::shared_ptr<Adapter>;
28+
using AdapterPtr = Adapter *;
2929

3030
namespace ur {
3131
void *getURLoaderLibrary();
@@ -35,7 +35,7 @@ std::vector<AdapterPtr> &
3535
initializeUr(ur_loader_config_handle_t LoaderConfig = nullptr);
3636

3737
// Get the adapter serving given backend.
38-
template <backend BE> const AdapterPtr &getAdapter();
38+
template <backend BE> AdapterPtr &getAdapter();
3939
} // namespace ur
4040

4141
// Convert from UR backend to SYCL backend enum

0 commit comments

Comments
 (0)