Skip to content

Commit a0c8c50

Browse files
[SYCL] Relax restrictions on non-OpenCL devices during queue creation (#5882)
Co-authored-by: Sergey Semenov <[email protected]>
1 parent f196d0b commit a0c8c50

File tree

3 files changed

+205
-6
lines changed

3 files changed

+205
-6
lines changed

sycl/source/detail/queue_impl.hpp

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,8 @@ class queue_impl {
6262

6363
ContextImplPtr DefaultContext = detail::getSyclObjImpl(
6464
Device->get_platform().ext_oneapi_get_default_context());
65-
66-
if (DefaultContext->hasDevice(Device))
65+
if (isValidDevice(DefaultContext, Device))
6766
return DefaultContext;
68-
6967
return detail::getSyclObjImpl(
7068
context{createSyclObjFromImpl<device>(Device), {}, {}});
7169
}
@@ -109,11 +107,20 @@ class queue_impl {
109107
"Queue cannot be constructed with both of "
110108
"discard_events and enable_profiling.");
111109
}
112-
if (!Context->hasDevice(Device))
113-
throw cl::sycl::invalid_object_error(
110+
if (!isValidDevice(Context, Device)) {
111+
if (!Context->is_host() &&
112+
Context->getPlugin().getBackend() == backend::opencl)
113+
throw sycl::invalid_object_error(
114+
"Queue cannot be constructed with the given context and device "
115+
"since the device is not a member of the context (descendants of "
116+
"devices from the context are not supported on OpenCL yet).",
117+
PI_ERROR_INVALID_DEVICE);
118+
throw sycl::invalid_object_error(
114119
"Queue cannot be constructed with the given context and device "
115-
"as the context does not contain the given device.",
120+
"since the device is neither a member of the context nor a "
121+
"descendant of its member.",
116122
PI_ERROR_INVALID_DEVICE);
123+
}
117124
if (!MHostQueue) {
118125
const QueueOrder QOrder =
119126
MPropList.has_property<property::queue::in_order>()
@@ -476,6 +483,27 @@ class queue_impl {
476483
}
477484

478485
private:
486+
/// Helper function for checking whether a device is either a member of a
487+
/// context or a descendnant of its member.
488+
/// \return True iff the device or its parent is a member of the context.
489+
static bool isValidDevice(const ContextImplPtr &Context,
490+
DeviceImplPtr Device) {
491+
// OpenCL does not support creating a queue with a descendant of a device
492+
// from the given context yet.
493+
// TODO remove once this limitation is lifted
494+
if (!Context->is_host() &&
495+
Context->getPlugin().getBackend() == backend::opencl)
496+
return Context->hasDevice(Device);
497+
498+
while (!Context->hasDevice(Device)) {
499+
if (Device->isRootDevice())
500+
return false;
501+
Device = detail::getSyclObjImpl(
502+
Device->get_info<info::device::parent_device>());
503+
}
504+
return true;
505+
}
506+
479507
/// Performs command group submission to the queue.
480508
///
481509
/// \param CGF is a function object containing command group.

sycl/unittests/queue/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_sycl_unittest(QueueTests OBJECT
2+
DeviceCheck.cpp
23
EventClear.cpp
34
USM.cpp
45
Wait.cpp

sycl/unittests/queue/DeviceCheck.cpp

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
//==----------------- DeviceCheck.cpp --- queue unit tests -----------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include <CL/sycl.hpp>
10+
#include <detail/config.hpp>
11+
#include <detail/device_impl.hpp>
12+
#include <gtest/gtest.h>
13+
#include <helpers/PiMock.hpp>
14+
#include <helpers/ScopedEnvVar.hpp>
15+
16+
using namespace sycl;
17+
18+
namespace {
19+
20+
inline constexpr auto EnableDefaultContextsName =
21+
"SYCL_ENABLE_DEFAULT_CONTEXTS";
22+
23+
pi_result redefinedContextCreate(const pi_context_properties *properties,
24+
pi_uint32 num_devices,
25+
const pi_device *devices,
26+
void (*pfn_notify)(const char *errinfo,
27+
const void *private_info,
28+
size_t cb, void *user_data),
29+
void *user_data, pi_context *ret_context) {
30+
*ret_context = reinterpret_cast<pi_context>(1);
31+
return PI_SUCCESS;
32+
}
33+
34+
pi_result redefinedContextRelease(pi_context context) { return PI_SUCCESS; }
35+
36+
pi_device ParentDevice = nullptr;
37+
pi_platform PiPlatform = nullptr;
38+
39+
pi_result redefinedDeviceGetInfo(pi_device device, pi_device_info param_name,
40+
size_t param_value_size, void *param_value,
41+
size_t *param_value_size_ret) {
42+
if (param_name == PI_DEVICE_INFO_PARTITION_PROPERTIES) {
43+
if (param_value) {
44+
auto *Result =
45+
reinterpret_cast<pi_device_partition_property *>(param_value);
46+
*Result = PI_DEVICE_PARTITION_EQUALLY;
47+
}
48+
if (param_value_size_ret)
49+
*param_value_size_ret = sizeof(pi_device_partition_property);
50+
} else if (param_name == PI_DEVICE_INFO_MAX_COMPUTE_UNITS) {
51+
auto *Result = reinterpret_cast<pi_uint32 *>(param_value);
52+
*Result = 2;
53+
} else if (param_name == PI_DEVICE_INFO_PARENT_DEVICE) {
54+
auto *Result = reinterpret_cast<pi_device *>(param_value);
55+
*Result = (device == ParentDevice) ? nullptr : ParentDevice;
56+
} else if (param_name == PI_DEVICE_INFO_PLATFORM) {
57+
auto *Result = reinterpret_cast<pi_platform *>(param_value);
58+
*Result = PiPlatform;
59+
} else if (param_name == PI_DEVICE_INFO_EXTENSIONS) {
60+
if (param_value_size_ret) {
61+
*param_value_size_ret = 0;
62+
}
63+
}
64+
return PI_SUCCESS;
65+
}
66+
67+
pi_result redefinedDevicePartition(
68+
pi_device device, const pi_device_partition_property *properties,
69+
pi_uint32 num_devices, pi_device *out_devices, pi_uint32 *out_num_devices) {
70+
if (out_devices) {
71+
for (pi_uint32 I = 0; I < num_devices; ++I) {
72+
out_devices[I] = reinterpret_cast<pi_device>(1);
73+
}
74+
}
75+
if (out_num_devices)
76+
*out_num_devices = num_devices;
77+
return PI_SUCCESS;
78+
}
79+
80+
pi_result redefinedDeviceRetain(pi_device device) { return PI_SUCCESS; }
81+
82+
pi_result redefinedDeviceRelease(pi_device device) { return PI_SUCCESS; }
83+
84+
pi_result redefinedQueueCreate(pi_context context, pi_device device,
85+
pi_queue_properties properties,
86+
pi_queue *queue) {
87+
return PI_SUCCESS;
88+
}
89+
90+
pi_result redefinedQueueRelease(pi_queue queue) { return PI_SUCCESS; }
91+
92+
// Check that the device is verified to be either a member of the context or a
93+
// descendant of its member.
94+
TEST(QueueDeviceCheck, CheckDeviceRestriction) {
95+
unittest::ScopedEnvVar EnableDefaultContexts(
96+
EnableDefaultContextsName, "1",
97+
detail::SYCLConfig<detail::SYCL_ENABLE_DEFAULT_CONTEXTS>::reset);
98+
99+
platform Plt{default_selector()};
100+
if (Plt.is_host()) {
101+
std::cout << "The test is not supported on host, skipping" << std::endl;
102+
GTEST_SKIP();
103+
}
104+
PiPlatform = detail::getSyclObjImpl(Plt)->getHandleRef();
105+
// Create default context normally to avoid issues during its release, which
106+
// takes plase after Mock is destroyed.
107+
context DefaultCtx = Plt.ext_oneapi_get_default_context();
108+
device Dev = DefaultCtx.get_devices()[0];
109+
110+
unittest::PiMock Mock{Plt};
111+
Mock.redefine<detail::PiApiKind::piContextCreate>(redefinedContextCreate);
112+
Mock.redefine<detail::PiApiKind::piContextRelease>(redefinedContextRelease);
113+
Mock.redefine<detail::PiApiKind::piDeviceGetInfo>(redefinedDeviceGetInfo);
114+
Mock.redefine<detail::PiApiKind::piDevicePartition>(redefinedDevicePartition);
115+
Mock.redefine<detail::PiApiKind::piDeviceRelease>(redefinedDeviceRelease);
116+
Mock.redefine<detail::PiApiKind::piDeviceRetain>(redefinedDeviceRetain);
117+
Mock.redefine<detail::PiApiKind::piQueueCreate>(redefinedQueueCreate);
118+
Mock.redefine<detail::PiApiKind::piQueueRelease>(redefinedQueueRelease);
119+
120+
// Device is a member of the context.
121+
{
122+
queue Q{Dev};
123+
EXPECT_EQ(Q.get_context().get_platform(), Plt);
124+
EXPECT_EQ(Q.get_context(), DefaultCtx);
125+
queue Q2{DefaultCtx, Dev};
126+
}
127+
// Device is a descendant of a member of the context.
128+
{
129+
ParentDevice = detail::getSyclObjImpl(Dev)->getHandleRef();
130+
std::vector<device> Subdevices =
131+
Dev.create_sub_devices<info::partition_property::partition_equally>(2);
132+
queue Q{Subdevices[0]};
133+
// OpenCL backend does not support using a descendant here yet.
134+
EXPECT_EQ(Q.get_context() == DefaultCtx,
135+
Q.get_backend() != backend::opencl);
136+
try {
137+
queue Q2{DefaultCtx, Subdevices[0]};
138+
EXPECT_NE(Q.get_backend(), backend::opencl);
139+
} catch (sycl::invalid_object_error &e) {
140+
EXPECT_EQ(Q.get_backend(), backend::opencl);
141+
EXPECT_EQ(std::strcmp(
142+
e.what(),
143+
"Queue cannot be constructed with the given context and "
144+
"device since the device is not a member of the context "
145+
"(descendants of devices from the context are not "
146+
"supported on OpenCL yet). -33 (PI_ERROR_INVALID_DEVICE)"),
147+
0);
148+
}
149+
}
150+
// Device is neither of the two.
151+
{
152+
ParentDevice = nullptr;
153+
device Device = detail::createSyclObjFromImpl<device>(
154+
std::make_shared<detail::device_impl>(reinterpret_cast<pi_device>(0x01),
155+
detail::getSyclObjImpl(Plt)));
156+
queue Q{Device};
157+
EXPECT_NE(Q.get_context(), DefaultCtx);
158+
try {
159+
queue Q2{DefaultCtx, Device};
160+
EXPECT_TRUE(false);
161+
} catch (sycl::invalid_object_error &e) {
162+
EXPECT_NE(
163+
std::strstr(e.what(),
164+
"Queue cannot be constructed with the given context and "
165+
"device"),
166+
nullptr);
167+
}
168+
}
169+
}
170+
} // anonymous namespace

0 commit comments

Comments
 (0)