Skip to content

Commit a26c4fe

Browse files
[SYCL] Less shared_ptr for platform_impl
1 parent 19668be commit a26c4fe

23 files changed

+101
-111
lines changed

sycl/gdb/libsycl.so-gdb.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ class SYCLDevice(SYCLValue):
376376

377377
IMPL_OFFSET_TO_DEVICE_TYPE = 0x8
378378
IMPL_OFFSET_TO_PLATFORM = 0x18
379-
PLATFORM_OFFSET_TO_BACKEND = 0x10
379+
PLATFORM_OFFSET_TO_BACKEND = 0x20
380380

381381
def __init__(self, gdb_value):
382382
super().__init__(gdb_value)

sycl/include/sycl/detail/impl_utils.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,14 @@ T createSyclObjFromImpl(
5050
return T(ImplObj);
5151
}
5252

53+
template <class T>
54+
T createSyclObjFromImpl(
55+
std::add_lvalue_reference_t<typename std::remove_reference_t<
56+
decltype(getSyclObjImpl(std::declval<T>()))>::element_type>
57+
ImplRef) {
58+
return createSyclObjFromImpl<T>(ImplRef.shared_from_this());
59+
}
60+
5361
} // namespace detail
5462
} // namespace _V1
5563
} // namespace sycl

sycl/source/backend.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,9 @@ __SYCL_EXPORT device make_device(ur_native_handle_t NativeHandle,
8989
NativeHandle, Adapter->getUrAdapter(), nullptr, &UrDevice);
9090

9191
// Construct the SYCL device from UR device.
92-
auto Platform = platform_impl::getPlatformFromUrDevice(UrDevice, Adapter);
9392
return detail::createSyclObjFromImpl<device>(
94-
Platform->getOrMakeDeviceImpl(UrDevice, Platform));
93+
platform_impl::getPlatformFromUrDevice(UrDevice, Adapter)
94+
.getOrMakeDeviceImpl(UrDevice));
9595
}
9696

9797
__SYCL_EXPORT context make_context(ur_native_handle_t NativeHandle,
@@ -288,10 +288,9 @@ make_kernel_bundle(ur_native_handle_t NativeHandle,
288288
std::transform(
289289
ProgramDevices.begin(), ProgramDevices.end(), std::back_inserter(Devices),
290290
[&Adapter](const auto &Dev) {
291-
auto Platform =
292-
detail::platform_impl::getPlatformFromUrDevice(Dev, Adapter);
293-
auto DeviceImpl = Platform->getOrMakeDeviceImpl(Dev, Platform);
294-
return createSyclObjFromImpl<device>(DeviceImpl);
291+
return createSyclObjFromImpl<device>(
292+
detail::platform_impl::getPlatformFromUrDevice(Dev, Adapter)
293+
.getOrMakeDeviceImpl(Dev));
295294
});
296295

297296
// Unlike SYCL, other backends, like OpenCL or Level Zero, may not support

sycl/source/backend/level_zero.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,13 @@ using namespace sycl::detail;
2020
__SYCL_EXPORT device make_device(const platform &Platform,
2121
ur_native_handle_t NativeHandle) {
2222
const auto &Adapter = ur::getAdapter<backend::ext_oneapi_level_zero>();
23-
const auto &PlatformImpl = getSyclObjImpl(Platform);
2423
// Create UR device first.
2524
ur_device_handle_t UrDevice;
2625
Adapter->call<UrApiKind::urDeviceCreateWithNativeHandle>(
2726
NativeHandle, Adapter->getUrAdapter(), nullptr, &UrDevice);
2827

2928
return detail::createSyclObjFromImpl<device>(
30-
PlatformImpl->getOrMakeDeviceImpl(UrDevice, PlatformImpl));
29+
getSyclObjImpl(Platform)->getOrMakeDeviceImpl(UrDevice));
3130
}
3231

3332
} // namespace ext::oneapi::level_zero::detail

sycl/source/detail/allowlist.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -375,8 +375,9 @@ void applyAllowList(std::vector<ur_device_handle_t> &UrDevices,
375375

376376
// Get platform's backend and put it to DeviceDesc
377377
DeviceDescT DeviceDesc;
378-
auto PlatformImpl = platform_impl::getOrMakePlatformImpl(UrPlatform, Adapter);
379-
backend Backend = PlatformImpl->getBackend();
378+
platform_impl &PlatformImpl =
379+
platform_impl::getOrMakePlatformImpl(UrPlatform, Adapter);
380+
backend Backend = PlatformImpl.getBackend();
380381

381382
for (const auto &SyclBe : getSyclBeMap()) {
382383
if (SyclBe.second == Backend) {
@@ -395,7 +396,7 @@ void applyAllowList(std::vector<ur_device_handle_t> &UrDevices,
395396

396397
int InsertIDx = 0;
397398
for (ur_device_handle_t Device : UrDevices) {
398-
auto DeviceImpl = PlatformImpl->getOrMakeDeviceImpl(Device, PlatformImpl);
399+
auto DeviceImpl = PlatformImpl.getOrMakeDeviceImpl(Device);
399400
// get DeviceType value and put it to DeviceDesc
400401
ur_device_type_t UrDevType = UR_DEVICE_TYPE_ALL;
401402
Adapter->call<UrApiKind::urDeviceGetInfo>(

sycl/source/detail/buffer_impl.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,11 @@ buffer_impl::getNativeVector(backend BackendName) const {
7979
// doesn't have context and platform
8080
if (!Ctx)
8181
continue;
82-
const PlatformImplPtr &Platform = Ctx->getPlatformImpl();
83-
assert(Platform && "Platform must be present for device context");
84-
if (Platform->getBackend() != BackendName)
82+
const platform_impl &Platform = Ctx->getPlatformImpl();
83+
if (Platform.getBackend() != BackendName)
8584
continue;
8685

87-
auto Adapter = Platform->getAdapter();
86+
auto Adapter = Platform.getAdapter();
8887

8988
ur_native_handle_t Handle = 0;
9089
// When doing buffer interop we don't know what device the memory should be
@@ -94,7 +93,7 @@ buffer_impl::getNativeVector(backend BackendName) const {
9493
&Handle);
9594
Handles.push_back(Handle);
9695

97-
if (Platform->getBackend() == backend::opencl) {
96+
if (Platform.getBackend() == backend::opencl) {
9897
__SYCL_OCL_CALL(clRetainMemObject, ur::cast<cl_mem>(Handle));
9998
}
10099
}

sycl/source/detail/context_impl.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,10 @@ context_impl::context_impl(const std::vector<sycl::device> Devices,
4141
async_handler AsyncHandler,
4242
const property_list &PropList)
4343
: MOwnedByRuntime(true), MAsyncHandler(AsyncHandler), MDevices(Devices),
44-
MContext(nullptr), MPlatform(), MPropList(PropList),
45-
MSupportBufferLocationByDevices(NotChecked) {
44+
MContext(nullptr),
45+
MPlatform(detail::getSyclObjImpl(MDevices[0].get_platform())),
46+
MPropList(PropList), MSupportBufferLocationByDevices(NotChecked) {
4647
verifyProps(PropList);
47-
MPlatform = detail::getSyclObjImpl(MDevices[0].get_platform());
4848
std::vector<ur_device_handle_t> DeviceIds;
4949
for (const auto &D : MDevices) {
5050
if (D.has(aspect::ext_oneapi_is_composite)) {
@@ -96,13 +96,13 @@ context_impl::context_impl(ur_context_handle_t UrContext,
9696
make_error_code(errc::invalid),
9797
"No devices in the provided device list and native context.");
9898

99-
std::shared_ptr<detail::platform_impl> Platform =
99+
platform_impl &Platform =
100100
platform_impl::getPlatformFromUrDevice(DeviceIds[0], Adapter);
101101
for (ur_device_handle_t Dev : DeviceIds) {
102-
MDevices.emplace_back(createSyclObjFromImpl<device>(
103-
Platform->getOrMakeDeviceImpl(Dev, Platform)));
102+
MDevices.emplace_back(
103+
createSyclObjFromImpl<device>(Platform.getOrMakeDeviceImpl(Dev)));
104104
}
105-
MPlatform = Platform;
105+
MPlatform = Platform.shared_from_this();
106106
}
107107
// TODO catch an exception and put it to list of asynchronous exceptions
108108
// getAdapter() will be the same as the Adapter passed. This should be taken
@@ -158,7 +158,7 @@ uint32_t context_impl::get_info<info::context::reference_count>() const {
158158
this->getAdapter());
159159
}
160160
template <> platform context_impl::get_info<info::context::platform>() const {
161-
return createSyclObjFromImpl<platform>(MPlatform);
161+
return createSyclObjFromImpl<platform>(*MPlatform);
162162
}
163163
template <>
164164
std::vector<sycl::device>

sycl/source/detail/context_impl.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ inline namespace _V1 {
2929
// Forward declaration
3030
class device;
3131
namespace detail {
32-
using PlatformImplPtr = std::shared_ptr<detail::platform_impl>;
3332
class context_impl {
3433
public:
3534
/// Constructs a context_impl using a single SYCL devices.
@@ -90,7 +89,7 @@ class context_impl {
9089
const AdapterPtr &getAdapter() const { return MPlatform->getAdapter(); }
9190

9291
/// \return the PlatformImpl associated with this context.
93-
const PlatformImplPtr &getPlatformImpl() const { return MPlatform; }
92+
platform_impl &getPlatformImpl() const { return *MPlatform; }
9493

9594
/// Queries this context for information.
9695
///
@@ -257,7 +256,8 @@ class context_impl {
257256
async_handler MAsyncHandler;
258257
std::vector<device> MDevices;
259258
ur_context_handle_t MContext;
260-
PlatformImplPtr MPlatform;
259+
// TODO: Make it a reference instead, but that needs a bit more refactoring:
260+
std::shared_ptr<platform_impl> MPlatform;
261261
property_list MPropList;
262262
CachedLibProgramsT MCachedLibPrograms;
263263
std::mutex MCachedLibProgramsMutex;

sycl/source/detail/device_impl.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,17 @@ namespace detail {
2121

2222
/// Constructs a SYCL device instance using the provided
2323
/// UR device instance.
24-
device_impl::device_impl(ur_device_handle_t Device, PlatformImplPtr Platform)
25-
: MDevice(Device), MPlatform(Platform),
24+
device_impl::device_impl(ur_device_handle_t Device, platform_impl &Platform)
25+
: MDevice(Device), MPlatform(Platform.shared_from_this()),
2626
MDeviceHostBaseTime(std::make_pair(0, 0)) {
27-
const AdapterPtr &Adapter = Platform->getAdapter();
27+
const AdapterPtr &Adapter = Platform.getAdapter();
2828

2929
// TODO catch an exception and put it to list of asynchronous exceptions
3030
Adapter->call<UrApiKind::urDeviceGetInfo>(
3131
MDevice, UR_DEVICE_INFO_TYPE, sizeof(ur_device_type_t), &MType, nullptr);
3232

3333
// No need to set MRootDevice when MAlwaysRootDevice is true
34-
if (!Platform->MAlwaysRootDevice) {
34+
if (!Platform.MAlwaysRootDevice) {
3535
// TODO catch an exception and put it to list of asynchronous exceptions
3636
Adapter->call<UrApiKind::urDeviceGetInfo>(
3737
MDevice, UR_DEVICE_INFO_PARENT_DEVICE, sizeof(ur_device_handle_t),
@@ -177,7 +177,7 @@ std::vector<device> device_impl::create_sub_devices(
177177
std::for_each(SubDevices.begin(), SubDevices.end(),
178178
[&res, this](const ur_device_handle_t &a_ur_device) {
179179
device sycl_device = detail::createSyclObjFromImpl<device>(
180-
MPlatform->getOrMakeDeviceImpl(a_ur_device, MPlatform));
180+
MPlatform->getOrMakeDeviceImpl(a_ur_device));
181181
res.push_back(sycl_device);
182182
});
183183
return res;

sycl/source/detail/device_impl.hpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,13 @@ namespace detail {
3030

3131
// Forward declaration
3232
class platform_impl;
33-
using PlatformImplPtr = std::shared_ptr<platform_impl>;
3433

3534
// TODO: Make code thread-safe
3635
class device_impl {
3736
public:
3837
/// Constructs a SYCL device instance using the provided
3938
/// UR device instance.
40-
explicit device_impl(ur_device_handle_t Device, PlatformImplPtr Platform);
39+
explicit device_impl(ur_device_handle_t Device, platform_impl &Platform);
4140

4241
~device_impl();
4342

@@ -279,8 +278,7 @@ class device_impl {
279278
backend getBackend() const { return MPlatform->getBackend(); }
280279

281280
/// @brief Get the platform impl serving this device
282-
/// @return PlatformImplPtr
283-
const PlatformImplPtr &getPlatformImpl() const { return MPlatform; }
281+
platform_impl &getPlatformImpl() const { return *MPlatform; }
284282

285283
/// Get device info string
286284
std::string get_device_info_string(ur_device_info_t InfoCode) const;
@@ -292,7 +290,7 @@ class device_impl {
292290
ur_device_handle_t MDevice = 0;
293291
ur_device_type_t MType;
294292
ur_device_handle_t MRootDevice = nullptr;
295-
PlatformImplPtr MPlatform;
293+
std::shared_ptr<platform_impl> MPlatform;
296294
bool MUseNativeAssert = false;
297295
mutable std::string MDeviceName;
298296
mutable std::once_flag MDeviceNameFlag;

sycl/source/detail/device_info.hpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
namespace sycl {
3636
inline namespace _V1 {
3737
namespace detail {
38-
3938
inline std::vector<memory_order>
4039
readMemoryOrderBitfield(ur_memory_order_capability_flags_t bits) {
4140
std::vector<memory_order> result;
@@ -1171,9 +1170,8 @@ template <> struct get_device_info_impl<device, info::device::parent_device> {
11711170
throw exception(make_error_code(errc::invalid),
11721171
"No parent for device because it is not a subdevice");
11731172

1174-
const auto &Platform = Dev.getPlatformImpl();
11751173
return createSyclObjFromImpl<device>(
1176-
Platform->getOrMakeDeviceImpl(result, Platform));
1174+
Dev.getPlatformImpl().getOrMakeDeviceImpl(result));
11771175
}
11781176
};
11791177

@@ -1337,10 +1335,10 @@ struct get_device_info_impl<
13371335
ext::oneapi::experimental::info::device::component_devices>::value,
13381336
ResultSize, Devs.data(), nullptr);
13391337
std::vector<sycl::device> Result;
1340-
const auto &Platform = Dev.getPlatformImpl();
1338+
platform_impl &Platform = Dev.getPlatformImpl();
13411339
for (const auto &d : Devs)
1342-
Result.push_back(createSyclObjFromImpl<device>(
1343-
Platform->getOrMakeDeviceImpl(d, Platform)));
1340+
Result.push_back(
1341+
createSyclObjFromImpl<device>(Platform.getOrMakeDeviceImpl(d)));
13441342

13451343
return Result;
13461344
}
@@ -1363,9 +1361,8 @@ struct get_device_info_impl<
13631361
sizeof(Result), &Result, nullptr);
13641362

13651363
if (Result) {
1366-
const auto &Platform = Dev.getPlatformImpl();
13671364
return createSyclObjFromImpl<device>(
1368-
Platform->getOrMakeDeviceImpl(Result, Platform));
1365+
Dev.getPlatformImpl().getOrMakeDeviceImpl(Result));
13691366
}
13701367
throw sycl::exception(make_error_code(errc::invalid),
13711368
"A component with aspect::ext_oneapi_is_component "

sycl/source/detail/global_handler.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ ProgramManager &GlobalHandler::getProgramManager() {
184184
return PM;
185185
}
186186

187-
std::unordered_map<PlatformImplPtr, ContextImplPtr> &
187+
std::unordered_map<platform_impl *, ContextImplPtr> &
188188
GlobalHandler::getPlatformToDefaultContextCache() {
189189
// The optimization with static reference is not done because
190190
// there are public methods of the GlobalHandler
@@ -205,8 +205,8 @@ Sync &GlobalHandler::getSync() {
205205
return sync;
206206
}
207207

208-
std::vector<PlatformImplPtr> &GlobalHandler::getPlatformCache() {
209-
static std::vector<PlatformImplPtr> &PlatformCache =
208+
std::vector<std::shared_ptr<platform_impl>> &GlobalHandler::getPlatformCache() {
209+
static std::vector<std::shared_ptr<platform_impl>> &PlatformCache =
210210
getOrCreate(MPlatformCache);
211211
return PlatformCache;
212212
}

sycl/source/detail/global_handler.hpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ class ods_target_list;
2727
class XPTIRegistry;
2828
class ThreadPool;
2929

30-
using PlatformImplPtr = std::shared_ptr<platform_impl>;
3130
using ContextImplPtr = std::shared_ptr<context_impl>;
3231
using AdapterPtr = std::shared_ptr<Adapter>;
3332

@@ -60,9 +59,9 @@ class GlobalHandler {
6059
bool isSchedulerAlive() const;
6160
ProgramManager &getProgramManager();
6261
Sync &getSync();
63-
std::vector<PlatformImplPtr> &getPlatformCache();
62+
std::vector<std::shared_ptr<platform_impl>> &getPlatformCache();
6463

65-
std::unordered_map<PlatformImplPtr, ContextImplPtr> &
64+
std::unordered_map<platform_impl *, ContextImplPtr> &
6665
getPlatformToDefaultContextCache();
6766

6867
std::mutex &getPlatformToDefaultContextCacheMutex();
@@ -117,8 +116,8 @@ class GlobalHandler {
117116
InstWithLock<Scheduler> MScheduler;
118117
InstWithLock<ProgramManager> MProgramManager;
119118
InstWithLock<Sync> MSync;
120-
InstWithLock<std::vector<PlatformImplPtr>> MPlatformCache;
121-
InstWithLock<std::unordered_map<PlatformImplPtr, ContextImplPtr>>
119+
InstWithLock<std::vector<std::shared_ptr<platform_impl>>> MPlatformCache;
120+
InstWithLock<std::unordered_map<platform_impl *, ContextImplPtr>>
122121
MPlatformToDefaultContextCache;
123122
InstWithLock<std::mutex> MPlatformToDefaultContextCacheMutex;
124123
InstWithLock<std::mutex> MPlatformMapMutex;

sycl/source/detail/kernel_impl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ void kernel_impl::checkIfValidForNumArgsInfoQuery() const {
126126
}
127127

128128
void kernel_impl::enableUSMIndirectAccess() const {
129-
if (!MContext->getPlatformImpl()->supports_usm())
129+
if (!MContext->getPlatformImpl().supports_usm())
130130
return;
131131

132132
// Some UR Adapters (like OpenCL) require this call to enable USM

0 commit comments

Comments
 (0)