Skip to content

Commit aa2c876

Browse files
[SYCL][NFCI] Ensure device_impl is only created via platform_impl::getOrMakeDeviceImpl (#18227)
Once that is guaranteed, it can enable further refactoring in subsequent PRs.
1 parent e7c85ed commit aa2c876

File tree

5 files changed

+38
-12
lines changed

5 files changed

+38
-12
lines changed

sycl/source/detail/device_impl.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ namespace detail {
2121

2222
/// Constructs a SYCL device instance using the provided
2323
/// UR device instance.
24-
device_impl::device_impl(ur_device_handle_t Device, platform_impl &Platform)
24+
device_impl::device_impl(ur_device_handle_t Device, platform_impl &Platform,
25+
device_impl::private_tag)
2526
: MDevice(Device), MPlatform(Platform.shared_from_this()),
2627
MDeviceHostBaseTime(std::make_pair(0, 0)) {
2728
const AdapterPtr &Adapter = Platform.getAdapter();

sycl/source/detail/device_impl.hpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,19 @@ class platform_impl;
3333

3434
// TODO: Make code thread-safe
3535
class device_impl {
36+
struct private_tag {
37+
explicit private_tag() = default;
38+
};
39+
friend class platform_impl;
40+
3641
public:
3742
/// Constructs a SYCL device instance using the provided
3843
/// UR device instance.
39-
explicit device_impl(ur_device_handle_t Device, platform_impl &Platform);
44+
//
45+
// Must be called through `platform_impl::getOrMakeDeviceImpl` only.
46+
// `private_tag` ensures that is true.
47+
explicit device_impl(ur_device_handle_t Device, platform_impl &Platform,
48+
private_tag);
4049

4150
~device_impl();
4251

sycl/source/detail/platform_impl.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,8 @@ platform_impl::getOrMakeDeviceImpl(ur_device_handle_t UrDevice) {
304304
return Result;
305305

306306
// Otherwise make the impl
307-
Result = std::make_shared<device_impl>(UrDevice, *this);
307+
Result = std::make_shared<device_impl>(UrDevice, *this,
308+
device_impl::private_tag{});
308309
MDeviceCache.emplace_back(Result);
309310

310311
return Result;

sycl/unittests/program_manager/SubDevices.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,8 @@ TEST(SubDevices, DISABLED_BuildProgramForSubdevices) {
106106
rootDevice = sycl::detail::getSyclObjImpl(device)->getHandleRef();
107107
// Initialize sub-devices
108108
sycl::detail::platform_impl &PltImpl = *sycl::detail::getSyclObjImpl(Plt);
109-
auto subDev1 =
110-
std::make_shared<sycl::detail::device_impl>(urSubDev1, PltImpl);
111-
auto subDev2 =
112-
std::make_shared<sycl::detail::device_impl>(urSubDev2, PltImpl);
109+
auto subDev1 = PltImpl.getOrMakeDeviceImpl(urSubDev1);
110+
auto subDev2 = PltImpl.getOrMakeDeviceImpl(urSubDev2);
113111
sycl::context Ctx{
114112
{device, sycl::detail::createSyclObjFromImpl<sycl::device>(subDev1),
115113
sycl::detail::createSyclObjFromImpl<sycl::device>(subDev2)}};

sycl/unittests/queue/DeviceCheck.cpp

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,18 @@ ur_result_t redefinedDevicePartitionAfter(void *pParams) {
6262
**params.ppNumDevicesRet = *params.pNumDevices;
6363
return UR_RESULT_SUCCESS;
6464
}
65+
ur_result_t redefinedPlatformGet(void *pParams) {
66+
auto params = reinterpret_cast<ur_platform_get_params_t *>(pParams);
67+
if (*params->ppNumPlatforms)
68+
**params->ppNumPlatforms = 2;
69+
70+
if (*params->pphPlatforms && *params->pNumEntries > 0) {
71+
(*params->pphPlatforms)[0] = reinterpret_cast<ur_platform_handle_t>(1);
72+
(*params->pphPlatforms)[1] = reinterpret_cast<ur_platform_handle_t>(2);
73+
}
74+
75+
return UR_RESULT_SUCCESS;
76+
}
6577

6678
// Check that the device is verified to be either a member of the context or a
6779
// descendant of its member.
@@ -71,6 +83,8 @@ TEST(QueueDeviceCheck, CheckDeviceRestriction) {
7183
detail::SYCLConfig<detail::SYCL_ENABLE_DEFAULT_CONTEXTS>::reset);
7284

7385
sycl::unittest::UrMock<> Mock;
86+
mock::getCallbacks().set_replace_callback("urPlatformGet",
87+
&redefinedPlatformGet);
7488
sycl::platform Plt = sycl::platform();
7589

7690
UrPlatform = detail::getSyclObjImpl(Plt)->getHandleRef();
@@ -116,12 +130,15 @@ TEST(QueueDeviceCheck, CheckDeviceRestriction) {
116130
// Device is neither of the two.
117131
{
118132
ParentDevice = nullptr;
119-
device Device = detail::createSyclObjFromImpl<device>(
120-
std::make_shared<detail::device_impl>(
121-
reinterpret_cast<ur_device_handle_t>(0x01),
122-
*detail::getSyclObjImpl(Plt)));
133+
134+
auto Plts = sycl::platform::get_platforms();
135+
EXPECT_TRUE(Plts.size() == 2);
136+
sycl::platform OtherPlt = Plts[1];
137+
138+
device Device = OtherPlt.get_devices()[0];
123139
queue Q{Device};
124-
EXPECT_NE(Q.get_context(), DefaultCtx);
140+
auto Ctx = Q.get_context();
141+
EXPECT_NE(Ctx, DefaultCtx);
125142
try {
126143
queue Q2{DefaultCtx, Device};
127144
EXPECT_TRUE(false);

0 commit comments

Comments
 (0)