Skip to content

Commit 4ac0638

Browse files
[NFC][SYCL] Move platform::khr_get_default_context impl to platform_impl.cpp (#19157)
Noticed this code as part of the ongoing refactoring to prefer raw ptrs/refs for SYCL RT `*_impl` objects. The usage of this API in `queue_impl` doesn't need to go through user-visible SYCL objects, so move the implementation and use the internal interfrace in `queue_impl` and delegate `platform::khr_get_default_context` to `platform_impl`.
1 parent 0c6ce91 commit 4ac0638

File tree

4 files changed

+33
-27
lines changed

4 files changed

+33
-27
lines changed

sycl/source/detail/platform_impl.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,23 @@ platform_impl::getPlatformFromUrDevice(ur_device_handle_t UrDevice,
7171
return getOrMakePlatformImpl(Plt, Adapter);
7272
}
7373

74+
context_impl &platform_impl::khr_get_default_context() {
75+
GlobalHandler &GH = GlobalHandler::instance();
76+
// Keeping the default context for platforms in the global cache to avoid
77+
// shared_ptr based circular dependency between platform and context classes
78+
std::unordered_map<platform_impl *, std::shared_ptr<context_impl>>
79+
&PlatformToDefaultContextCache = GH.getPlatformToDefaultContextCache();
80+
81+
std::lock_guard<std::mutex> Lock{GH.getPlatformToDefaultContextCacheMutex()};
82+
83+
auto It = PlatformToDefaultContextCache.find(this);
84+
if (PlatformToDefaultContextCache.end() == It)
85+
std::tie(It, std::ignore) = PlatformToDefaultContextCache.insert(
86+
{this, detail::getSyclObjImpl(context{get_devices()})});
87+
88+
return *It->second;
89+
}
90+
7491
static bool IsBannedPlatform(platform Platform) {
7592
// The NVIDIA OpenCL platform is currently not compatible with DPC++
7693
// since it is only 1.2 but gets selected by default in many systems

sycl/source/detail/platform_impl.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,8 @@ class platform_impl : public std::enable_shared_from_this<platform_impl> {
203203
static platform_impl &getPlatformFromUrDevice(ur_device_handle_t UrDevice,
204204
const AdapterPtr &Adapter);
205205

206+
context_impl &khr_get_default_context();
207+
206208
// when getting sub-devices for ONEAPI_DEVICE_SELECTOR we may temporarily
207209
// ensure every device is a root one.
208210
bool MAlwaysRootDevice = false;

sycl/source/detail/queue_impl.hpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -83,17 +83,17 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
8383
public:
8484
// \return a default context for the platform if it includes the device
8585
// passed and default contexts are enabled, a new context otherwise.
86-
static ContextImplPtr getDefaultOrNew(device_impl &Device) {
87-
if (!SYCLConfig<SYCL_ENABLE_DEFAULT_CONTEXTS>::get())
88-
return detail::getSyclObjImpl(
89-
context{createSyclObjFromImpl<device>(Device), {}, {}});
90-
91-
ContextImplPtr DefaultContext =
92-
detail::getSyclObjImpl(Device.get_platform().khr_get_default_context());
93-
if (DefaultContext->isDeviceValid(Device))
94-
return DefaultContext;
95-
return detail::getSyclObjImpl(
96-
context{createSyclObjFromImpl<device>(Device), {}, {}});
86+
static std::shared_ptr<context_impl> getDefaultOrNew(device_impl &Device) {
87+
if (SYCLConfig<SYCL_ENABLE_DEFAULT_CONTEXTS>::get()) {
88+
context_impl &CtxImpl =
89+
Device.getPlatformImpl().khr_get_default_context();
90+
if (CtxImpl.isDeviceValid(Device))
91+
return CtxImpl.shared_from_this();
92+
}
93+
94+
return context_impl::create(
95+
std::vector<device>{createSyclObjFromImpl<device>(Device)},
96+
async_handler{}, property_list{});
9797
}
9898
/// Constructs a SYCL queue from a device using an async_handler and
9999
/// property_list provided.

sycl/source/platform.cpp

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <detail/backend_impl.hpp>
1010
#include <detail/config.hpp>
11+
#include <detail/context_impl.hpp>
1112
#include <detail/global_handler.hpp>
1213
#include <detail/platform_impl.hpp>
1314
#include <detail/ur.hpp>
@@ -89,22 +90,8 @@ platform::get_backend_info() const {
8990
#undef __SYCL_PARAM_TRAITS_SPEC
9091

9192
context platform::khr_get_default_context() const {
92-
// Keeping the default context for platforms in the global cache to avoid
93-
// shared_ptr based circular dependency between platform and context classes
94-
std::unordered_map<detail::platform_impl *, detail::ContextImplPtr>
95-
&PlatformToDefaultContextCache =
96-
detail::GlobalHandler::instance().getPlatformToDefaultContextCache();
97-
98-
std::lock_guard<std::mutex> Lock{
99-
detail::GlobalHandler::instance()
100-
.getPlatformToDefaultContextCacheMutex()};
101-
102-
auto It = PlatformToDefaultContextCache.find(impl.get());
103-
if (PlatformToDefaultContextCache.end() == It)
104-
std::tie(It, std::ignore) = PlatformToDefaultContextCache.insert(
105-
{impl.get(), detail::getSyclObjImpl(context{get_devices()})});
106-
107-
return detail::createSyclObjFromImpl<context>(It->second);
93+
return detail::createSyclObjFromImpl<context>(
94+
impl->khr_get_default_context());
10895
}
10996

11097
context platform::ext_oneapi_get_default_context() const {

0 commit comments

Comments
 (0)