16
16
#include < algorithm>
17
17
#include < cstring>
18
18
#include < regex>
19
+ #include < string>
19
20
20
21
__SYCL_INLINE_NAMESPACE (cl) {
21
22
namespace sycl {
@@ -121,112 +122,172 @@ vector_class<platform> platform_impl::get_platforms() {
121
122
}
122
123
123
124
struct DevDescT {
124
- const char *devName = nullptr ;
125
- int devNameSize = 0 ;
126
- const char *devDriverVer = nullptr ;
127
- int devDriverVerSize = 0 ;
128
-
129
- const char *platformName = nullptr ;
130
- int platformNameSize = 0 ;
131
-
132
- const char *platformVer = nullptr ;
133
- int platformVerSize = 0 ;
125
+ std::string devName;
126
+ std::string devDriverVer;
127
+ std::string platName;
128
+ std::string platVer;
134
129
};
135
130
136
131
static std::vector<DevDescT> getAllowListDesc () {
137
- const char *Str = SYCLConfig<SYCL_DEVICE_ALLOWLIST>::get ();
138
- if (!Str )
132
+ std::string allowList ( SYCLConfig<SYCL_DEVICE_ALLOWLIST>::get () );
133
+ if (allowList. empty () )
139
134
return {};
140
135
136
+ std::string deviceName (" DeviceName:" );
137
+ std::string driverVersion (" DriverVersion:" );
138
+ std::string platformName (" PlatformName:" );
139
+ std::string platformVersion (" PlatformVersion:" );
141
140
std::vector<DevDescT> decDescs;
142
- const char devNameStr[] = " DeviceName" ;
143
- const char driverVerStr[] = " DriverVersion" ;
144
- const char platformNameStr[] = " PlatformName" ;
145
- const char platformVerStr[] = " PlatformVersion" ;
146
141
decDescs.emplace_back ();
147
142
148
- std::cout << " Before: " << Str << std::endl;
149
-
150
- // Replace common special symbols with '.' which matches to any character
151
- #if 0 // gail
152
- std::string tmp(Str);
153
- std::replace_if(tmp.begin(), tmp.end(),
154
- [](const char sym) { return '(' == sym || ')' == sym; }, '.');
155
- const char * str = tmp.c_str();
156
- #endif // gail
157
-
158
- std::string tmp (Str);
159
- std::replace (tmp.begin (), tmp.end (), ' (' , ' .' );
160
- std::replace (tmp.begin (), tmp.end (), ' )' , ' .' );
161
- const char * str = tmp.c_str ();
162
- std::cout << " After : " << str << std::endl;
163
-
164
- while (' \0 ' != *str) {
165
- const char **valuePtr = nullptr ;
166
- int *size = nullptr ;
167
-
168
- // -1 to avoid comparing null terminator
169
- if (0 == strncmp (devNameStr, str, sizeof (devNameStr) - 1 )) {
170
- valuePtr = &decDescs.back ().devName ;
171
- size = &decDescs.back ().devNameSize ;
172
- str += sizeof (devNameStr) - 1 ;
173
- } else if (0 ==
174
- strncmp (platformNameStr, str, sizeof (platformNameStr) - 1 )) {
175
- valuePtr = &decDescs.back ().platformName ;
176
- size = &decDescs.back ().platformNameSize ;
177
- str += sizeof (platformNameStr) - 1 ;
178
- } else if (0 == strncmp (platformVerStr, str, sizeof (platformVerStr) - 1 )) {
179
- valuePtr = &decDescs.back ().platformVer ;
180
- size = &decDescs.back ().platformVerSize ;
181
- str += sizeof (platformVerStr) - 1 ;
182
- } else if (0 == strncmp (driverVerStr, str, sizeof (driverVerStr) - 1 )) {
183
- valuePtr = &decDescs.back ().devDriverVer ;
184
- size = &decDescs.back ().devDriverVerSize ;
185
- str += sizeof (driverVerStr) - 1 ;
186
- } else {
187
- throw sycl::runtime_error (" Unrecognized key in device allowlist" ,
188
- PI_INVALID_VALUE);
189
- }
143
+ size_t pos = 0 ;
144
+ size_t prev = pos;
145
+ while (pos < allowList.size ()) {
146
+ if ((allowList.compare (pos, deviceName.size (), deviceName)) == 0 ) {
147
+ prev = pos;
148
+ if ((pos = allowList.find (" {{" , pos)) == std::string::npos) {
149
+ throw sycl::runtime_error (" Malformed syntax in SYCL_DEVICE_ALLOWLIST" ,
150
+ PI_INVALID_VALUE);
151
+ }
152
+ if (pos > prev + deviceName.size ()) {
153
+ throw sycl::runtime_error (" Malformed syntax in SYCL_DEVICE_ALLOWLIST" ,
154
+ PI_INVALID_VALUE);
155
+ }
190
156
191
- if (' :' != *str)
192
- throw sycl::runtime_error (" Malformed device allowlist" , PI_INVALID_VALUE);
157
+ pos = pos + 2 ;
158
+ size_t start = pos;
159
+ if ((pos = allowList.find (" }}" , pos)) == std::string::npos) {
160
+ throw sycl::runtime_error (" Malformed syntax in SYCL_DEVICE_ALLOWLIST" ,
161
+ PI_INVALID_VALUE);
162
+ }
163
+ decDescs.back ().devName = allowList.substr (start, pos - start);
164
+ pos = pos + 2 ;
193
165
194
- // Skip ':'
195
- str += 1 ;
166
+ if (allowList[pos] == ' ,' ) {
167
+ pos++;
168
+ }
169
+ }
196
170
197
- if (' {' != *str || ' {' != *(str + 1 ))
198
- throw sycl::runtime_error (" Malformed device allowlist" , PI_INVALID_VALUE);
171
+ else if ((allowList.compare (pos, driverVersion.size (), driverVersion)) ==
172
+ 0 ) {
173
+ prev = pos;
174
+ if ((pos = allowList.find (" {{" , pos)) == std::string::npos) {
175
+ throw sycl::runtime_error (" Malformed syntax in SYCL_DEVICE_ALLOWLIST" ,
176
+ PI_INVALID_VALUE);
177
+ }
178
+ if (pos > prev + driverVersion.size ()) {
179
+ throw sycl::runtime_error (" Malformed syntax in SYCL_DEVICE_ALLOWLIST" ,
180
+ PI_INVALID_VALUE);
181
+ }
199
182
200
- // Skip opening sequence "{{"
201
- str += 2 ;
183
+ size_t start = pos + 2 ;
184
+ if ((pos = allowList.find (" }}" , pos)) == std::string::npos) {
185
+ throw sycl::runtime_error (" Malformed syntax in SYCL_DEVICE_ALLOWLIST" ,
186
+ PI_INVALID_VALUE);
187
+ }
188
+ decDescs.back ().devDriverVer = allowList.substr (start, pos - start);
189
+ pos = pos + 2 ;
202
190
203
- *valuePtr = str;
191
+ if (allowList[pos] == ' ,' ) {
192
+ pos++;
193
+ }
194
+ }
195
+
196
+ else if ((allowList.compare (pos, platformName.size (), platformName)) == 0 ) {
197
+ prev = pos;
198
+ if ((pos = allowList.find (" {{" , pos)) == std::string::npos) {
199
+ throw sycl::runtime_error (" Malformed syntax in SYCL_DEVICE_ALLOWLIST" ,
200
+ PI_INVALID_VALUE);
201
+ }
202
+ if (pos > prev + platformName.size ()) {
203
+ throw sycl::runtime_error (" Malformed syntax in SYCL_DEVICE_ALLOWLIST" ,
204
+ PI_INVALID_VALUE);
205
+ }
204
206
205
- // Increment until closing sequence is encountered
206
- while ((' \0 ' != *str) && (' }' != *str || ' }' != *(str + 1 )))
207
- ++str;
207
+ size_t start = pos + 2 ;
208
+ if ((pos = allowList.find (" }}" , pos)) == std::string::npos) {
209
+ throw sycl::runtime_error (" Malformed syntax in SYCL_DEVICE_ALLOWLIST" ,
210
+ PI_INVALID_VALUE);
211
+ }
212
+ decDescs.back ().platName = allowList.substr (start, pos - start);
213
+ pos = pos + 2 ;
208
214
209
- if (' \0 ' == *str)
210
- throw sycl::runtime_error (" Malformed device allowlist" , PI_INVALID_VALUE);
215
+ if (allowList[pos] == ' ,' ) {
216
+ pos++;
217
+ }
211
218
212
- *size = str - *valuePtr;
219
+ }
213
220
214
- // Skip closing sequence "}}"
215
- str += 2 ;
221
+ else if ((allowList.compare (pos, platformVersion.size (),
222
+ platformVersion)) == 0 ) {
223
+ prev = pos;
224
+ if ((pos = allowList.find (" {{" , pos)) == std::string::npos) {
225
+ throw sycl::runtime_error (" Malformed syntax in SYCL_DEVICE_ALLOWLIST" ,
226
+ PI_INVALID_VALUE);
227
+ }
228
+ if (pos > prev + platformVersion.size ()) {
229
+ throw sycl::runtime_error (" Malformed syntax in SYCL_DEVICE_ALLOWLIST" ,
230
+ PI_INVALID_VALUE);
231
+ }
216
232
217
- if (' \0 ' == *str)
218
- break ;
233
+ size_t start = pos + 2 ;
234
+ if ((pos = allowList.find (" }}" , pos)) == std::string::npos) {
235
+ throw sycl::runtime_error (" Malformed syntax in SYCL_DEVICE_ALLOWLIST" ,
236
+ PI_INVALID_VALUE);
237
+ }
238
+ decDescs.back ().platVer = allowList.substr (start, pos - start);
239
+ pos = pos + 2 ;
240
+ }
219
241
220
- // '|' means that the is another filter
221
- if (' |' == *str)
242
+ else if (allowList.find (' |' , pos) != std::string::npos) {
243
+ pos = allowList.find (' |' ) + 1 ;
244
+ while (allowList[pos] == ' ' ) {
245
+ pos++;
246
+ }
222
247
decDescs.emplace_back ();
223
- else if (' ,' != *str)
224
- throw sycl::runtime_error (" Malformed device allowlist" , PI_INVALID_VALUE);
248
+ }
225
249
226
- ++str;
250
+ else {
251
+ throw sycl::runtime_error (" Unrecognized key in device allowlist" ,
252
+ PI_INVALID_VALUE);
253
+ }
254
+ } // while (pos <= allowList.size())
255
+ return decDescs;
256
+ }
257
+
258
+ std::vector<int > convertVersionString (std::string version) {
259
+ // version string format is xx.yy.zzzzz
260
+ std::vector<int > values;
261
+ size_t pos = 0 ;
262
+ size_t start = pos;
263
+ if ((pos = version.find (" ." , pos)) == std::string::npos) {
264
+ throw sycl::runtime_error (" Malformed syntax in version string" ,
265
+ PI_INVALID_VALUE);
266
+ }
267
+ values.push_back (std::stoi (version.substr (start, pos)));
268
+ pos++;
269
+ start = pos;
270
+ if ((pos = version.find (" ." , pos)) == std::string::npos) {
271
+ throw sycl::runtime_error (" Malformed syntax in version string" ,
272
+ PI_INVALID_VALUE);
227
273
}
274
+ values.push_back (std::stoi (version.substr (start, pos)));
275
+ pos++;
276
+ values.push_back (std::stoi (version.substr (pos)));
228
277
229
- return decDescs;
278
+ return values;
279
+ }
280
+
281
+ enum MatchState { UNKNOWN, MATCH, NOMATCH };
282
+
283
+ MatchState matchVersions (std::string version1, std::string version2) {
284
+ std::vector<int > v1 = convertVersionString (version1);
285
+ std::vector<int > v2 = convertVersionString (version2);
286
+ if (v1[0 ] >= v2[0 ] && v1[1 ] >= v2[1 ] && v1[2 ] >= v2[2 ]) {
287
+ return MatchState::MATCH;
288
+ } else {
289
+ return MatchState::NOMATCH;
290
+ }
230
291
}
231
292
232
293
static void filterAllowList (vector_class<RT::PiDevice> &PiDevices,
@@ -235,6 +296,11 @@ static void filterAllowList(vector_class<RT::PiDevice> &PiDevices,
235
296
if (AllowList.empty ())
236
297
return ;
237
298
299
+ MatchState devNameState = UNKNOWN;
300
+ MatchState devVerState = UNKNOWN;
301
+ MatchState platNameState = UNKNOWN;
302
+ MatchState platVerState = UNKNOWN;
303
+
238
304
const string_class PlatformName =
239
305
sycl::detail::get_platform_info<string_class, info::platform::name>::get (
240
306
PiPlatform, Plugin);
@@ -254,33 +320,57 @@ static void filterAllowList(vector_class<RT::PiDevice> &PiDevices,
254
320
string_class, info::device::driver_version>::get (Device, Plugin);
255
321
256
322
for (const DevDescT &Desc : AllowList) {
257
- if (nullptr != Desc.platformName &&
258
- !std::regex_match (PlatformName,
259
- std::regex (std::string (Desc.platformName ,
260
- Desc.platformNameSize ))))
261
- continue ;
262
-
263
- if (nullptr != Desc.platformVer &&
264
- !std::regex_match (
265
- PlatformVer,
266
- std::regex (std::string (Desc.platformVer , Desc.platformVerSize ))))
267
- continue ;
268
-
269
- if (nullptr != Desc.devName &&
270
- !std::regex_match (DeviceName, std::regex (std::string (
271
- Desc.devName , Desc.devNameSize ))))
272
- continue ;
273
-
274
- if (nullptr != Desc.devDriverVer &&
275
- !std::regex_match (DeviceDriverVer,
276
- std::regex (std::string (Desc.devDriverVer ,
277
- Desc.devDriverVerSize ))))
278
- continue ;
323
+ if (!Desc.platName .empty ()) {
324
+ if (!std::regex_match (PlatformName, std::regex (Desc.platName ))) {
325
+ platNameState = MatchState::NOMATCH;
326
+ continue ;
327
+ } else {
328
+ platNameState = MatchState::MATCH;
329
+ }
330
+ }
331
+
332
+ if (!Desc.platVer .empty ()) {
333
+ if (!std::regex_match (PlatformVer, std::regex (Desc.platVer ))) {
334
+ platVerState = MatchState::NOMATCH;
335
+ continue ;
336
+ } else {
337
+ platVerState = MatchState::MATCH;
338
+ }
339
+ }
340
+
341
+ if (!Desc.devName .empty ()) {
342
+ if (!std::regex_match (DeviceName, std::regex (Desc.devName ))) {
343
+ devNameState = MatchState::NOMATCH;
344
+ continue ;
345
+ } else {
346
+ devNameState = MatchState::MATCH;
347
+ }
348
+ }
349
+
350
+ if (!Desc.devDriverVer .empty ()) {
351
+ if (!std::regex_match (DeviceDriverVer, std::regex (Desc.devDriverVer ))) {
352
+ devVerState = matchVersions (DeviceDriverVer, Desc.devDriverVer );
353
+ if (devVerState == MatchState::NOMATCH) {
354
+ continue ;
355
+ }
356
+ } else {
357
+ devVerState = MatchState::MATCH;
358
+ }
359
+ }
279
360
280
361
PiDevices[InsertIDx++] = Device;
281
362
break ;
282
363
}
283
364
}
365
+ if (devNameState == MatchState::MATCH && devVerState == MatchState::NOMATCH) {
366
+ throw sycl::runtime_error (" Requested SYCL device not found" ,
367
+ PI_DEVICE_NOT_FOUND);
368
+ }
369
+ if (platNameState == MatchState::MATCH &&
370
+ platVerState == MatchState::NOMATCH) {
371
+ throw sycl::runtime_error (" Requested SYCL platform not found" ,
372
+ PI_DEVICE_NOT_FOUND);
373
+ }
284
374
PiDevices.resize (InsertIDx);
285
375
}
286
376
0 commit comments