Skip to content

Commit 4ad5263

Browse files
authored
[SYCL] Add support for devices white list (#867)
Add support for SYCL_DEVICE_WHITE_LIST config which allows to specify the white list of devices and their minimum driver version. This is done by setting configuration in the following format: DeviceName:{{XXX}},DriverVersion:{{X.Y.Z.W}}|DeviceName:{{XXX}},DriverVersion:{{X.Y.Z.W}} Where values in {{ }} can be specified in ECMAScript regular expressions pattern syntax Devices that do not satisfy the pattern from the white list are ignored. Signed-off-by: Vlad Romanov <[email protected]>
1 parent edb985a commit 4ad5263

File tree

4 files changed

+180
-3
lines changed

4 files changed

+180
-3
lines changed

sycl/source/detail/config.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,11 @@ void readConfig() {
102102
// Prints configs name with their value
103103
void dumpConfig() {
104104
#define CONFIG(Name, MaxSize, CompileTimeDef) \
105-
const char *Val = SYCLConfig<Name>::get(); \
106-
std::cerr << SYCLConfigBase<Name>::MConfigName << " : " \
107-
<< (Val ? Val : "unset") << std::endl;
105+
{ \
106+
const char *Val = SYCLConfig<Name>::get(); \
107+
std::cerr << SYCLConfigBase<Name>::MConfigName << " : " \
108+
<< (Val ? Val : "unset") << std::endl; \
109+
}
108110
#include "detail/config.def"
109111
#undef CONFIG
110112
}

sycl/source/detail/config.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@
1111
// underscore(__).
1212

1313
CONFIG(SYCL_PRINT_EXECUTION_GRAPH, 32, __SYCL_PRINT_EXECUTION_GRAPH)
14+
CONFIG(SYCL_DEVICE_WHITE_LIST, 1024, __SYCL_DEVICE_WHITE_LIST)
1415

sycl/source/detail/platform_impl.cpp

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,11 @@
99
#include <CL/sycl/detail/device_impl.hpp>
1010
#include <CL/sycl/detail/platform_impl.hpp>
1111
#include <CL/sycl/device.hpp>
12+
#include <detail/config.hpp>
1213

1314
#include <algorithm>
15+
#include <cstring>
16+
#include <regex>
1417

1518
namespace cl {
1619
namespace sycl {
@@ -49,6 +52,113 @@ platform_impl_host::get_devices(info::device_type dev_type) const {
4952
return res;
5053
}
5154

55+
struct DevDescT {
56+
const char *devName = nullptr;
57+
int devNameSize = 0;
58+
59+
const char *devDriverVer = nullptr;
60+
int devDriverVerSize = 0;
61+
};
62+
63+
static std::vector<DevDescT> getWhiteListDesc() {
64+
const char *str = SYCLConfig<SYCL_DEVICE_WHITE_LIST>::get();
65+
if (!str)
66+
return {};
67+
68+
std::vector<DevDescT> decDescs;
69+
const char devNameStr[] = "DeviceName";
70+
const char driverVerStr[] = "DriverVersion";
71+
decDescs.emplace_back();
72+
while ('\0' != *str) {
73+
const char **valuePtr = nullptr;
74+
int *size = nullptr;
75+
76+
// -1 to avoid comparing null terminator
77+
if (0 == strncmp(devNameStr, str, sizeof(devNameStr) - 1)) {
78+
valuePtr = &decDescs.back().devName;
79+
size = &decDescs.back().devNameSize;
80+
str += sizeof(devNameStr) - 1;
81+
} else if (0 == strncmp(driverVerStr, str, sizeof(driverVerStr) - 1)) {
82+
valuePtr = &decDescs.back().devDriverVer;
83+
size = &decDescs.back().devDriverVerSize;
84+
str += sizeof(driverVerStr) - 1;
85+
}
86+
87+
if (':' != *str)
88+
throw sycl::runtime_error("Malformed device white list");
89+
90+
// Skip ':'
91+
str += 1;
92+
93+
if ('{' != *str || '{' != *(str + 1))
94+
throw sycl::runtime_error("Malformed device white list");
95+
96+
// Skip opening sequence "{{"
97+
str += 2;
98+
99+
*valuePtr = str;
100+
101+
// Increment until closing sequence is encountered
102+
while (('\0' != *str) && ('}' != *str || '}' != *(str + 1)))
103+
++str;
104+
105+
if ('\0' == *str)
106+
throw sycl::runtime_error("Malformed device white list");
107+
108+
*size = str - *valuePtr;
109+
110+
// Skip closing sequence "}}"
111+
str += 2;
112+
113+
if ('\0' == *str)
114+
break;
115+
116+
// '|' means that the is another filter
117+
if ('|' == *str)
118+
decDescs.emplace_back();
119+
else if (',' != *str)
120+
throw sycl::runtime_error("Malformed device white list");
121+
122+
++str;
123+
}
124+
125+
return decDescs;
126+
}
127+
128+
static void filterWhiteList(vector_class<RT::PiDevice> &pi_devices) {
129+
const std::vector<DevDescT> whiteList(getWhiteListDesc());
130+
if (whiteList.empty())
131+
return;
132+
133+
int insertIDx = 0;
134+
for (RT::PiDevice dev : pi_devices) {
135+
const string_class devName =
136+
sycl::detail::get_device_info<string_class, info::device::name>::_(dev);
137+
138+
const string_class devDriverVer =
139+
sycl::detail::get_device_info<string_class,
140+
info::device::driver_version>::_(dev);
141+
142+
for (const DevDescT &desc : whiteList) {
143+
// At least device name is required field to consider the filter so far
144+
if (nullptr == desc.devName ||
145+
!std::regex_match(
146+
devName, std::regex(std::string(desc.devName, desc.devNameSize))))
147+
continue;
148+
149+
if (nullptr != desc.devDriverVer &&
150+
!std::regex_match(devDriverVer,
151+
std::regex(std::string(desc.devDriverVer,
152+
desc.devDriverVerSize))))
153+
continue;
154+
155+
pi_devices[insertIDx++] = dev;
156+
break;
157+
}
158+
}
159+
pi_devices.resize(insertIDx);
160+
}
161+
52162
vector_class<device>
53163
platform_impl_pi::get_devices(info::device_type deviceType) const {
54164
vector_class<device> res;
@@ -67,6 +177,10 @@ platform_impl_pi::get_devices(info::device_type deviceType) const {
67177
PI_CALL(piDevicesGet)(m_platform, pi::cast<RT::PiDeviceType>(deviceType),
68178
num_devices, pi_devices.data(), nullptr);
69179

180+
// Filter out devices that are not present in the white list
181+
if (SYCLConfig<SYCL_DEVICE_WHITE_LIST>::get())
182+
filterWhiteList(pi_devices);
183+
70184
std::for_each(pi_devices.begin(), pi_devices.end(),
71185
[&res](const RT::PiDevice &a_pi_device) {
72186
device sycl_device = detail::createSyclObjFromImpl<device>(

sycl/test/config/white_list.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
// REQUIRES: cpu
2+
// RUN: %clangxx -fsycl %s -o %t.out
3+
// RUN: env PRINT_DEVICE_INFO=1 %t.out > %t.conf
4+
// RUN: env TEST_DEVICE_AVAILABLE=1 env SYCL_CONFIG_FILE_NAME=%t.conf %t.out
5+
// RUN: env TEST_DEVICE_IS_NOT_AVAILABLE=1 env SYCL_DEVICE_WHITE_LIST="" %t.out
6+
7+
#include <CL/sycl.hpp>
8+
#include <iostream>
9+
#include <cstdlib>
10+
#include <exception>
11+
#include <string>
12+
13+
using namespace cl;
14+
15+
int main() {
16+
17+
// Expected that white list filter is not set
18+
if (getenv("PRINT_DEVICE_INFO")) {
19+
for (const sycl::platform &Plt : sycl::platform::get_platforms())
20+
if (!Plt.is_host()) {
21+
const sycl::device Dev = Plt.get_devices().at(0);
22+
std::string DevName = Dev.get_info<sycl::info::device::name>();
23+
const std::string DevVer =
24+
Dev.get_info<sycl::info::device::driver_version>();
25+
// As device name string will be used as regexp pattern, we need to
26+
// get rid of symbols that can be treated in a special way.
27+
// Replace common special symbols with '.' which matches to any sybmol
28+
for (char &Sym : DevName) {
29+
if (')' == Sym || '(' == Sym)
30+
Sym = '.';
31+
}
32+
std::cout << "SYCL_DEVICE_WHITE_LIST=DeviceName:{{" << DevName
33+
<< "}},DriverVersion:{{" << DevVer << "}}";
34+
return 0;
35+
}
36+
throw std::runtime_error("Non host device is not found");
37+
}
38+
39+
// Expected white list to be set with result from "PRINT_DEVICE_INFO" run
40+
if (getenv("TEST_DEVICE_AVAILABLE")) {
41+
for (const sycl::platform &Plt : sycl::platform::get_platforms())
42+
if (!Plt.is_host()) {
43+
if (Plt.get_devices().size() != 1)
44+
throw std::runtime_error("Expected only one non host device.");
45+
46+
return 0;
47+
}
48+
throw std::runtime_error("Non host device is not found");
49+
}
50+
51+
// Expected white list to be set but empty
52+
if (getenv("TEST_DEVICE_IS_NOT_AVAILABLE")) {
53+
for (const sycl::platform &Plt : sycl::platform::get_platforms())
54+
if (!Plt.is_host())
55+
throw std::runtime_error("Expected no non host device is available");
56+
return 0;
57+
}
58+
59+
throw std::runtime_error("Unhandled situation");
60+
}

0 commit comments

Comments
 (0)