9
9
#include < CL/sycl/detail/device_impl.hpp>
10
10
#include < CL/sycl/detail/platform_impl.hpp>
11
11
#include < CL/sycl/device.hpp>
12
+ #include < detail/config.hpp>
12
13
13
14
#include < algorithm>
15
+ #include < cstring>
16
+ #include < regex>
14
17
15
18
namespace cl {
16
19
namespace sycl {
@@ -49,6 +52,113 @@ platform_impl_host::get_devices(info::device_type dev_type) const {
49
52
return res;
50
53
}
51
54
55
+ struct DevDescT {
56
+ const char *devName = nullptr ;
57
+ int devNameSize = 0 ;
58
+
59
+ const char *devDriverVer = nullptr ;
60
+ int devDriverVerSize = 0 ;
61
+ };
62
+
63
+ static std::vector<DevDescT> getWhiteListDesc () {
64
+ const char *str = SYCLConfig<SYCL_DEVICE_WHITE_LIST>::get ();
65
+ if (!str)
66
+ return {};
67
+
68
+ std::vector<DevDescT> decDescs;
69
+ const char devNameStr[] = " DeviceName" ;
70
+ const char driverVerStr[] = " DriverVersion" ;
71
+ decDescs.emplace_back ();
72
+ while (' \0 ' != *str) {
73
+ const char **valuePtr = nullptr ;
74
+ int *size = nullptr ;
75
+
76
+ // -1 to avoid comparing null terminator
77
+ if (0 == strncmp (devNameStr, str, sizeof (devNameStr) - 1 )) {
78
+ valuePtr = &decDescs.back ().devName ;
79
+ size = &decDescs.back ().devNameSize ;
80
+ str += sizeof (devNameStr) - 1 ;
81
+ } else if (0 == strncmp (driverVerStr, str, sizeof (driverVerStr) - 1 )) {
82
+ valuePtr = &decDescs.back ().devDriverVer ;
83
+ size = &decDescs.back ().devDriverVerSize ;
84
+ str += sizeof (driverVerStr) - 1 ;
85
+ }
86
+
87
+ if (' :' != *str)
88
+ throw sycl::runtime_error (" Malformed device white list" );
89
+
90
+ // Skip ':'
91
+ str += 1 ;
92
+
93
+ if (' {' != *str || ' {' != *(str + 1 ))
94
+ throw sycl::runtime_error (" Malformed device white list" );
95
+
96
+ // Skip opening sequence "{{"
97
+ str += 2 ;
98
+
99
+ *valuePtr = str;
100
+
101
+ // Increment until closing sequence is encountered
102
+ while ((' \0 ' != *str) && (' }' != *str || ' }' != *(str + 1 )))
103
+ ++str;
104
+
105
+ if (' \0 ' == *str)
106
+ throw sycl::runtime_error (" Malformed device white list" );
107
+
108
+ *size = str - *valuePtr;
109
+
110
+ // Skip closing sequence "}}"
111
+ str += 2 ;
112
+
113
+ if (' \0 ' == *str)
114
+ break ;
115
+
116
+ // '|' means that the is another filter
117
+ if (' |' == *str)
118
+ decDescs.emplace_back ();
119
+ else if (' ,' != *str)
120
+ throw sycl::runtime_error (" Malformed device white list" );
121
+
122
+ ++str;
123
+ }
124
+
125
+ return decDescs;
126
+ }
127
+
128
+ static void filterWhiteList (vector_class<RT::PiDevice> &pi_devices) {
129
+ const std::vector<DevDescT> whiteList (getWhiteListDesc ());
130
+ if (whiteList.empty ())
131
+ return ;
132
+
133
+ int insertIDx = 0 ;
134
+ for (RT::PiDevice dev : pi_devices) {
135
+ const string_class devName =
136
+ sycl::detail::get_device_info<string_class, info::device::name>::_ (dev);
137
+
138
+ const string_class devDriverVer =
139
+ sycl::detail::get_device_info<string_class,
140
+ info::device::driver_version>::_ (dev);
141
+
142
+ for (const DevDescT &desc : whiteList) {
143
+ // At least device name is required field to consider the filter so far
144
+ if (nullptr == desc.devName ||
145
+ !std::regex_match (
146
+ devName, std::regex (std::string (desc.devName , desc.devNameSize ))))
147
+ continue ;
148
+
149
+ if (nullptr != desc.devDriverVer &&
150
+ !std::regex_match (devDriverVer,
151
+ std::regex (std::string (desc.devDriverVer ,
152
+ desc.devDriverVerSize ))))
153
+ continue ;
154
+
155
+ pi_devices[insertIDx++] = dev;
156
+ break ;
157
+ }
158
+ }
159
+ pi_devices.resize (insertIDx);
160
+ }
161
+
52
162
vector_class<device>
53
163
platform_impl_pi::get_devices (info::device_type deviceType) const {
54
164
vector_class<device> res;
@@ -67,6 +177,10 @@ platform_impl_pi::get_devices(info::device_type deviceType) const {
67
177
PI_CALL (piDevicesGet)(m_platform, pi::cast<RT::PiDeviceType>(deviceType),
68
178
num_devices, pi_devices.data (), nullptr );
69
179
180
+ // Filter out devices that are not present in the white list
181
+ if (SYCLConfig<SYCL_DEVICE_WHITE_LIST>::get ())
182
+ filterWhiteList (pi_devices);
183
+
70
184
std::for_each (pi_devices.begin (), pi_devices.end (),
71
185
[&res](const RT::PiDevice &a_pi_device) {
72
186
device sycl_device = detail::createSyclObjFromImpl<device>(
0 commit comments