Skip to content

Commit 7f2f668

Browse files
authored
[SYCL] Add support for platform name and platform version in device whitelist (#890)
PlatformName and PlatformVersion are now supported keys. Also all fields are optional now, so SYCL_DEVICE_WHITE_LIST="" matches any device Signed-off-by: Vlad Romanov <[email protected]>
1 parent 07a3616 commit 7f2f668

File tree

2 files changed

+95
-28
lines changed

2 files changed

+95
-28
lines changed

sycl/source/detail/platform_impl.cpp

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,12 @@ struct DevDescT {
5858

5959
const char *devDriverVer = nullptr;
6060
int devDriverVerSize = 0;
61+
62+
const char *platformName = nullptr;
63+
int platformNameSize = 0;
64+
65+
const char *platformVer = nullptr;
66+
int platformVerSize = 0;
6167
};
6268

6369
static std::vector<DevDescT> getWhiteListDesc() {
@@ -68,6 +74,8 @@ static std::vector<DevDescT> getWhiteListDesc() {
6874
std::vector<DevDescT> decDescs;
6975
const char devNameStr[] = "DeviceName";
7076
const char driverVerStr[] = "DriverVersion";
77+
const char platformNameStr[] = "PlatformName";
78+
const char platformVerStr[] = "PlatformVersion";
7179
decDescs.emplace_back();
7280
while ('\0' != *str) {
7381
const char **valuePtr = nullptr;
@@ -78,6 +86,15 @@ static std::vector<DevDescT> getWhiteListDesc() {
7886
valuePtr = &decDescs.back().devName;
7987
size = &decDescs.back().devNameSize;
8088
str += sizeof(devNameStr) - 1;
89+
} else if (0 ==
90+
strncmp(platformNameStr, str, sizeof(platformNameStr) - 1)) {
91+
valuePtr = &decDescs.back().platformName;
92+
size = &decDescs.back().platformNameSize;
93+
str += sizeof(platformNameStr) - 1;
94+
} else if (0 == strncmp(platformVerStr, str, sizeof(platformVerStr) - 1)) {
95+
valuePtr = &decDescs.back().platformVer;
96+
size = &decDescs.back().platformVerSize;
97+
str += sizeof(platformVerStr) - 1;
8198
} else if (0 == strncmp(driverVerStr, str, sizeof(driverVerStr) - 1)) {
8299
valuePtr = &decDescs.back().devDriverVer;
83100
size = &decDescs.back().devDriverVerSize;
@@ -125,23 +142,43 @@ static std::vector<DevDescT> getWhiteListDesc() {
125142
return decDescs;
126143
}
127144

128-
static void filterWhiteList(vector_class<RT::PiDevice> &pi_devices) {
145+
static void filterWhiteList(vector_class<RT::PiDevice> &pi_devices,
146+
RT::PiPlatform pi_platform) {
129147
const std::vector<DevDescT> whiteList(getWhiteListDesc());
130148
if (whiteList.empty())
131149
return;
132150

151+
const string_class platformName =
152+
sycl::detail::get_platform_info<string_class, info::platform::name>::get(
153+
pi_platform);
154+
155+
const string_class platformVer = sycl::detail::get_platform_info<
156+
string_class, info::platform::version>::get(pi_platform);
157+
133158
int insertIDx = 0;
134159
for (RT::PiDevice dev : pi_devices) {
135160
const string_class devName =
136-
sycl::detail::get_device_info<string_class, info::device::name>::get(dev);
161+
sycl::detail::get_device_info<string_class, info::device::name>::get(
162+
dev);
137163

138164
const string_class devDriverVer =
139165
sycl::detail::get_device_info<string_class,
140166
info::device::driver_version>::get(dev);
141167

142168
for (const DevDescT &desc : whiteList) {
143-
// At least device name is required field to consider the filter so far
144-
if (nullptr == desc.devName ||
169+
if (nullptr != desc.platformName &&
170+
!std::regex_match(platformName,
171+
std::regex(std::string(desc.platformName,
172+
desc.platformNameSize))))
173+
continue;
174+
175+
if (nullptr != desc.platformVer &&
176+
!std::regex_match(
177+
platformVer,
178+
std::regex(std::string(desc.platformVer, desc.platformVerSize))))
179+
continue;
180+
181+
if (nullptr != desc.devName &&
145182
!std::regex_match(
146183
devName, std::regex(std::string(desc.devName, desc.devNameSize))))
147184
continue;
@@ -179,7 +216,7 @@ platform_impl_pi::get_devices(info::device_type deviceType) const {
179216

180217
// Filter out devices that are not present in the white list
181218
if (SYCLConfig<SYCL_DEVICE_WHITE_LIST>::get())
182-
filterWhiteList(pi_devices);
219+
filterWhiteList(pi_devices, m_platform);
183220

184221
std::for_each(pi_devices.begin(), pi_devices.end(),
185222
[&res](const RT::PiDevice &a_pi_device) {

sycl/test/config/white_list.cpp

Lines changed: 53 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
// REQUIRES: cpu
22
// RUN: %clangxx -fsycl %s -o %t.out
3-
// RUN: env PRINT_DEVICE_INFO=1 %t.out > %t.conf
4-
// RUN: env TEST_DEVICE_AVAILABLE=1 env SYCL_CONFIG_FILE_NAME=%t.conf %t.out
5-
// RUN: env TEST_DEVICE_IS_NOT_AVAILABLE=1 env SYCL_DEVICE_WHITE_LIST="" %t.out
3+
//
4+
// RUN: env PRINT_DEVICE_INFO=1 %t.out > %t1.conf
5+
// RUN: env TEST_DEVICE_AVAILABLE=1 env SYCL_CONFIG_FILE_NAME=%t1.conf %t.out
6+
//
7+
// RUN: env PRINT_PLATFORM_INFO=1 %t.out > %t2.conf
8+
// RUN: env TEST_DEVICE_AVAILABLE=1 env SYCL_CONFIG_FILE_NAME=%t2.conf %t.out
9+
//
10+
// RUN: env TEST_DEVICE_IS_NOT_AVAILABLE=1 env SYCL_DEVICE_WHITE_LIST="PlatformName:{{SUCH NAME DOESN'T EXIST}}" %t.out
611

712
#include <CL/sycl.hpp>
813
#include <iostream>
@@ -12,35 +17,60 @@
1217

1318
using namespace cl;
1419

20+
static void replaceSpecialCharacters(std::string &Str) {
21+
// Replace common special symbols with '.' which matches to any character
22+
std::replace_if(Str.begin(), Str.end(),
23+
[](const char Sym) { return '(' == Sym || ')' == Sym; }, '.');
24+
}
25+
1526
int main() {
1627

28+
// Expected that white list filter is not set
29+
if (getenv("PRINT_PLATFORM_INFO")) {
30+
for (const sycl::platform &Platform : sycl::platform::get_platforms())
31+
if (!Platform.is_host()) {
32+
33+
std::string Name = Platform.get_info<sycl::info::platform::name>();
34+
std::string Ver = Platform.get_info<sycl::info::platform::version>();
35+
// As a string will be used as regexp pattern, we need to get rid of
36+
// symbols that can be treated in a special way.
37+
replaceSpecialCharacters(Name);
38+
replaceSpecialCharacters(Ver);
39+
40+
std::cout << "SYCL_DEVICE_WHITE_LIST=PlatformName:{{" << Name
41+
<< "}},PlatformVersion:{{" << Ver << "}}";
42+
43+
return 0;
44+
}
45+
throw std::runtime_error("Non host device is not found");
46+
}
47+
1748
// Expected that white list filter is not set
1849
if (getenv("PRINT_DEVICE_INFO")) {
19-
for (const sycl::platform &Plt : sycl::platform::get_platforms())
20-
if (!Plt.is_host()) {
21-
const sycl::device Dev = Plt.get_devices().at(0);
22-
std::string DevName = Dev.get_info<sycl::info::device::name>();
23-
const std::string DevVer =
24-
Dev.get_info<sycl::info::device::driver_version>();
25-
// As device name string will be used as regexp pattern, we need to
26-
// get rid of symbols that can be treated in a special way.
27-
// Replace common special symbols with '.' which matches to any sybmol
28-
for (char &Sym : DevName) {
29-
if (')' == Sym || '(' == Sym)
30-
Sym = '.';
31-
}
32-
std::cout << "SYCL_DEVICE_WHITE_LIST=DeviceName:{{" << DevName
33-
<< "}},DriverVersion:{{" << DevVer << "}}";
50+
for (const sycl::platform &Platform : sycl::platform::get_platforms())
51+
if (!Platform.is_host()) {
52+
const sycl::device Dev = Platform.get_devices().at(0);
53+
std::string Name = Dev.get_info<sycl::info::device::name>();
54+
std::string Ver = Dev.get_info<sycl::info::device::driver_version>();
55+
56+
// As a string will be used as regexp pattern, we need to get rid of
57+
// symbols that can be treated in a special way.
58+
replaceSpecialCharacters(Name);
59+
replaceSpecialCharacters(Ver);
60+
61+
std::cout << "SYCL_DEVICE_WHITE_LIST=DeviceName:{{" << Name
62+
<< "}},DriverVersion:{{" << Ver << "}}";
63+
3464
return 0;
3565
}
3666
throw std::runtime_error("Non host device is not found");
3767
}
3868

3969
// Expected white list to be set with result from "PRINT_DEVICE_INFO" run
4070
if (getenv("TEST_DEVICE_AVAILABLE")) {
41-
for (const sycl::platform &Plt : sycl::platform::get_platforms())
42-
if (!Plt.is_host()) {
43-
if (Plt.get_devices().size() != 1)
71+
for (const sycl::platform &Platform : sycl::platform::get_platforms())
72+
if (!Platform.is_host()) {
73+
if (Platform.get_devices().size() != 1)
4474
throw std::runtime_error("Expected only one non host device.");
4575

4676
return 0;
@@ -50,8 +80,8 @@ int main() {
5080

5181
// Expected white list to be set but empty
5282
if (getenv("TEST_DEVICE_IS_NOT_AVAILABLE")) {
53-
for (const sycl::platform &Plt : sycl::platform::get_platforms())
54-
if (!Plt.is_host())
83+
for (const sycl::platform &Platform : sycl::platform::get_platforms())
84+
if (!Platform.is_host())
5585
throw std::runtime_error("Expected no non host device is available");
5686
return 0;
5787
}

0 commit comments

Comments
 (0)