Skip to content

Commit b21e74e

Browse files
[SYCL] sub-devices selected by ONEAPI_DEVICE_SELECTOR as root devices (#7167)
All devices exposed through ONEAPI_DEVICE_SELECTOR are root devices, even those obtained via the sub-device selection syntax (e.g. `ONEAPI_DEVICE_SELECTOR=level_zero:*.*`). This PR ensures that such devices pretend to be root devices. Tests for this change are in intel/llvm-test-suite#1346.
1 parent 5dc011f commit b21e74e

File tree

6 files changed

+117
-12
lines changed

6 files changed

+117
-12
lines changed

sycl/include/sycl/detail/util.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,18 @@ class Sync {
3232
std::mutex GlobalLock;
3333
};
3434

35+
// TempAssignGuard is a scope guard that assigns a temporary value to some
// OTHER variable and puts the previous value back when the guard itself goes
// out of scope (normal exit, early exit, or exception unwinding).
//
// The guard stores a reference to the guarded variable, so that variable must
// outlive the guard. The guard is intentionally non-copyable: a copy would
// refer to the same variable and restore it a second time when destroyed,
// possibly clobbering a value assigned after the first restore.
template <typename T> struct TempAssignGuard {
  T &field;       // Variable that is temporarily overridden.
  T restoreValue; // Value `field` held when the guard was constructed.

  // Remembers the current value of `fld`, then assigns `tempVal` to it.
  TempAssignGuard(T &fld, T tempVal) : field(fld), restoreValue(fld) {
    field = tempVal;
  }

  // Exactly one guard may be responsible for a given restore.
  TempAssignGuard(const TempAssignGuard &) = delete;
  TempAssignGuard &operator=(const TempAssignGuard &) = delete;

  // Puts back the value captured by the constructor.
  ~TempAssignGuard() { field = restoreValue; }
};
46+
3547
// const char* key hash for STL maps
3648
struct HashCStr {
3749
size_t operator()(const char *S) const {

sycl/source/detail/device_impl.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,13 @@ device_impl::device_impl(pi_native_handle InteropDeviceHandle,
5353
Plugin.call<PiApiKind::piDeviceGetInfo>(
5454
MDevice, PI_DEVICE_INFO_TYPE, sizeof(RT::PiDeviceType), &MType, nullptr);
5555

56-
// TODO catch an exception and put it to list of asynchronous exceptions
57-
Plugin.call<PiApiKind::piDeviceGetInfo>(MDevice, PI_DEVICE_INFO_PARENT_DEVICE,
58-
sizeof(RT::PiDevice), &MRootDevice,
59-
nullptr);
56+
// No need to set MRootDevice when MAlwaysRootDevice is true
57+
if ((Platform == nullptr) || !Platform->MAlwaysRootDevice) {
58+
// TODO catch an exception and put it to list of asynchronous exceptions
59+
Plugin.call<PiApiKind::piDeviceGetInfo>(
60+
MDevice, PI_DEVICE_INFO_PARENT_DEVICE, sizeof(RT::PiDevice),
61+
&MRootDevice, nullptr);
62+
}
6063

6164
if (!InteroperabilityConstructor) {
6265
// TODO catch an exception and put it to list of asynchronous exceptions

sycl/source/detail/platform_impl.cpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <detail/platform_impl.hpp>
1515
#include <detail/platform_info.hpp>
1616
#include <sycl/detail/iostream_proxy.hpp>
17+
#include <sycl/detail/util.hpp>
1718
#include <sycl/device.hpp>
1819

1920
#include <algorithm>
@@ -253,13 +254,17 @@ static bool supportsPartitionProperty(const device &dev,
253254

254255
static std::vector<device> amendDeviceAndSubDevices(
255256
backend PlatformBackend, std::vector<device> &DeviceList,
256-
ods_target_list *OdsTargetList, int PlatformDeviceIndex) {
257+
ods_target_list *OdsTargetList, int PlatformDeviceIndex,
258+
PlatformImplPtr PlatformImpl) {
257259
constexpr info::partition_property partitionProperty =
258260
info::partition_property::partition_by_affinity_domain;
259261
constexpr info::partition_affinity_domain affinityDomain =
260262
info::partition_affinity_domain::next_partitionable;
261263

262264
std::vector<device> FinalResult;
265+
// (Only) when amending sub-devices for ONEAPI_DEVICE_SELECTOR, all
266+
// sub-devices are treated as root.
267+
TempAssignGuard<bool> TAG(PlatformImpl->MAlwaysRootDevice, true);
263268

264269
for (unsigned i = 0; i < DeviceList.size(); i++) {
265270
// device has already been screened. The question is whether it should be a
@@ -311,9 +316,8 @@ static std::vector<device> amendDeviceAndSubDevices(
311316
// -- Add sub sub device.
312317
if (wantSubSubDevice) {
313318

314-
auto subDevicesToPartition = dev.create_sub_devices<
315-
info::partition_property::partition_by_affinity_domain>(
316-
affinityDomain);
319+
auto subDevicesToPartition =
320+
dev.create_sub_devices<partitionProperty>(affinityDomain);
317321
if (target.SubDeviceNum) {
318322
if (subDevicesToPartition.size() >
319323
target.SubDeviceNum.value()) {
@@ -341,9 +345,9 @@ static std::vector<device> amendDeviceAndSubDevices(
341345
continue;
342346
}
343347
// All right, let's get the sub-sub-devices.
344-
auto subSubDevices = subDev.create_sub_devices<
345-
info::partition_property::partition_by_affinity_domain>(
346-
affinityDomain);
348+
auto subSubDevices =
349+
subDev.create_sub_devices<partitionProperty>(
350+
affinityDomain);
347351
if (target.HasSubSubDeviceWildCard) {
348352
FinalResult.insert(FinalResult.end(), subSubDevices.begin(),
349353
subSubDevices.end());
@@ -476,7 +480,7 @@ platform_impl::get_devices(info::device_type DeviceType) const {
476480
// Otherwise, our last step is to revisit the devices, possibly replacing
477481
// them with subdevices (which have been ignored until now)
478482
return amendDeviceAndSubDevices(Backend, Res, OdsTargetList,
479-
PlatformDeviceIndex);
483+
PlatformDeviceIndex, PlatformImpl);
480484
}
481485

482486
bool platform_impl::has_extension(const std::string &ExtensionName) const {

sycl/source/detail/platform_impl.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,10 @@ class platform_impl {
188188
static std::shared_ptr<platform_impl>
189189
getPlatformFromPiDevice(RT::PiDevice PiDevice, const plugin &Plugin);
190190

191+
// when getting sub-devices for ONEAPI_DEVICE_SELECTOR we may temporarily
192+
// ensure every device is a root one.
193+
bool MAlwaysRootDevice = false;
194+
191195
private:
192196
std::shared_ptr<device_impl> getDeviceImplHelper(RT::PiDevice PiDevice);
193197

sycl/source/device.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,19 @@ device::get_info() const {
142142
return impl->template get_info<Param>();
143143
}
144144

145+
template <> device device::get_info<info::device::parent_device>() const {
  // Under ONEAPI_DEVICE_SELECTOR the impl's MRootDevice is preset and may be
  // overridden (i.e. it may be nullptr on a sub-device). The underlying PI
  // sub-devices do have parents, but those must not be reported: such devices
  // have to behave as parentless root devices.
  if (!impl->isRootDevice())
    return impl->template get_info<info::device::parent_device>();

  // Anything the impl considers a root device has no parent to hand back.
  throw invalid_object_error(
      "No parent for device because it is not a subdevice",
      PI_ERROR_INVALID_DEVICE);
}
157+
145158
#define __SYCL_PARAM_TRAITS_SPEC(DescType, Desc, ReturnT, PiCode) \
146159
template __SYCL_EXPORT ReturnT device::get_info<info::device::Desc>() const;
147160

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
// RUN: %clangxx -fsycl %s -o %t.out
2+
// RUN: %t.out
3+
4+
#include <sycl/detail/util.hpp>
5+
#include <sycl/sycl.hpp>
6+
using namespace sycl;
7+
8+
// Simple aggregate used as a guard target in the test below: one int field
// and one bool field. The default member initializers give a
// default-constructed instance well-defined contents.
struct someStruct {
  int firstValue = 0;
  bool secondValue = false;
};
12+
13+
int main() {
14+
someStruct myStruct;
15+
myStruct.firstValue = 2;
16+
myStruct.secondValue = false;
17+
someStruct moarStruct;
18+
moarStruct.firstValue = 3;
19+
moarStruct.secondValue = false;
20+
someStruct *moarPtr = &moarStruct;
21+
int anotherValue = 4;
22+
23+
{ // Scope to limit lifetime of TempAssignGuards.
24+
25+
sycl::detail::TempAssignGuard myTAG_1(myStruct.firstValue, -20);
26+
sycl::detail::TempAssignGuard myTAG_2(myStruct.secondValue, true);
27+
sycl::detail::TempAssignGuard moarTAG_1(moarPtr->firstValue, -30);
28+
sycl::detail::TempAssignGuard moarTAG_2(moarPtr->secondValue, true);
29+
sycl::detail::TempAssignGuard anotherTAG(anotherValue, -40);
30+
31+
// Ensure values have been temporarily assigned.
32+
assert(myStruct.firstValue == -20);
33+
assert(myStruct.secondValue == true);
34+
assert(moarStruct.firstValue == -30);
35+
assert(moarStruct.secondValue == true);
36+
assert(anotherValue == -40);
37+
}
38+
39+
// Ensure values have been restored.
40+
assert(myStruct.firstValue == 2);
41+
assert(myStruct.secondValue == false);
42+
assert(moarStruct.firstValue == 3);
43+
assert(moarStruct.secondValue == false);
44+
assert(anotherValue == 4);
45+
46+
// Test exceptions
47+
int exceptionalValue = 5;
48+
try {
49+
sycl::detail::TempAssignGuard exceptionalTAG(exceptionalValue, -50);
50+
assert(exceptionalValue == -50);
51+
throw 7; // Baby needs a new pair of shoes.
52+
} catch (...) {
53+
assert(exceptionalValue == 5);
54+
}
55+
assert(exceptionalValue == 5);
56+
57+
// Test premature exit
58+
int prematureValue = 6;
59+
{
60+
sycl::detail::TempAssignGuard prematureTAG(prematureValue, -60);
61+
assert(prematureValue == -60);
62+
goto dragons;
63+
assert(true == false);
64+
}
65+
dragons:
66+
assert(prematureValue == 6);
67+
68+
return 0;
69+
}

0 commit comments

Comments
 (0)