Skip to content

Commit dc5c005

Browse files
committed
[SYCL] Additional support for SYCL_DEVICE_ALLOWLIST
Added support for the case where multiple devices or platforms are listed in SYCL_DEVICE_ALLOWLIST. Also fixed a memory issue. Created a test case which tests both legal and illegal uses. Updated the documentation. Signed-off-by: Gail Lyons <[email protected]>
1 parent 5c30ab7 commit dc5c005

File tree

3 files changed

+944
-192
lines changed

3 files changed

+944
-192
lines changed

sycl/doc/EnvironmentVariables.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ subject to change. Do not rely on these variables in production code.
2222
| SYCL_DISABLE_EXECUTION_GRAPH_CLEANUP | Any(\*) | Disable cleanup of finished command nodes at host-device synchronization points. |
2323
| SYCL_THROW_ON_BLOCK | Any(\*) | Throw an exception on attempt to wait for a blocked command. |
2424
| SYCL_DEVICELIB_INHIBIT_NATIVE | String of device library extensions (separated by a whitespace) | Do not rely on device native support for devicelib extensions listed in this option. |
25-
| SYCL_DEVICE_ALLOWLIST | A list of devices and their minimum driver version following the pattern: DeviceName:{{XXX}},DriverVersion:{{X.Y.Z.W}}. Also may contain PlatformName and PlatformVersion | Filter out devices that do not match the pattern specified. Regular expression can be passed and the DPC++ runtime will select only those devices which satisfy the regex. |
25+
| SYCL_DEVICE_ALLOWLIST | A list of devices and their minimum driver version following the pattern: DeviceName:{{XXX}},DriverVersion:{{X.Y.Z.W}}. Also may contain PlatformName and PlatformVersion | Filter out devices that do not match the pattern specified. Regular expression can be passed and the DPC++ runtime will select only those devices which satisfy the regex. Note that the device name, platform name and their respective versions are regular expression. Special characters, such as parenthesis, must be escaped. |
2626
| SYCL_QUEUE_THREAD_POOL_SIZE | Positive integer | Number of threads in thread pool of queue. |
2727
| SYCL_DEVICELIB_NO_FALLBACK | Any(\*) | Disable loading and linking of device library images |
2828
| SYCL_PI_LEVEL0_MAX_COMMAND_LIST_CACHE | Positive integer | Maximum number of oneAPI Level Zero Command lists that can be allocated with no reuse before throwing an "out of resources" error. Default is 20000, threshold may be increased based on resource availabilty and workload demand. |

sycl/source/detail/platform_impl.cpp

Lines changed: 195 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <algorithm>
1717
#include <cstring>
1818
#include <regex>
19+
#include <string>
1920

2021
__SYCL_INLINE_NAMESPACE(cl) {
2122
namespace sycl {
@@ -121,112 +122,172 @@ vector_class<platform> platform_impl::get_platforms() {
121122
}
122123

123124
struct DevDescT {
124-
const char *devName = nullptr;
125-
int devNameSize = 0;
126-
const char *devDriverVer = nullptr;
127-
int devDriverVerSize = 0;
128-
129-
const char *platformName = nullptr;
130-
int platformNameSize = 0;
131-
132-
const char *platformVer = nullptr;
133-
int platformVerSize = 0;
125+
std::string devName;
126+
std::string devDriverVer;
127+
std::string platName;
128+
std::string platVer;
134129
};
135130

136131
static std::vector<DevDescT> getAllowListDesc() {
137-
const char *Str = SYCLConfig<SYCL_DEVICE_ALLOWLIST>::get();
138-
if (!Str)
132+
std::string allowList(SYCLConfig<SYCL_DEVICE_ALLOWLIST>::get());
133+
if (allowList.empty())
139134
return {};
140135

136+
std::string deviceName("DeviceName:");
137+
std::string driverVersion("DriverVersion:");
138+
std::string platformName("PlatformName:");
139+
std::string platformVersion("PlatformVersion:");
141140
std::vector<DevDescT> decDescs;
142-
const char devNameStr[] = "DeviceName";
143-
const char driverVerStr[] = "DriverVersion";
144-
const char platformNameStr[] = "PlatformName";
145-
const char platformVerStr[] = "PlatformVersion";
146141
decDescs.emplace_back();
147142

148-
std::cout << "Before: " << Str << std::endl;
149-
150-
// Replace common special symbols with '.' which matches to any character
151-
#if 0 // gail
152-
std::string tmp(Str);
153-
std::replace_if(tmp.begin(), tmp.end(),
154-
[](const char sym) { return '(' == sym || ')' == sym; }, '.');
155-
const char * str = tmp.c_str();
156-
#endif //gail
157-
158-
std::string tmp(Str);
159-
std::replace(tmp.begin(), tmp.end(), '(', '.');
160-
std::replace(tmp.begin(), tmp.end(), ')', '.');
161-
const char * str = tmp.c_str();
162-
std::cout << "After : " << str << std::endl;
163-
164-
while ('\0' != *str) {
165-
const char **valuePtr = nullptr;
166-
int *size = nullptr;
167-
168-
// -1 to avoid comparing null terminator
169-
if (0 == strncmp(devNameStr, str, sizeof(devNameStr) - 1)) {
170-
valuePtr = &decDescs.back().devName;
171-
size = &decDescs.back().devNameSize;
172-
str += sizeof(devNameStr) - 1;
173-
} else if (0 ==
174-
strncmp(platformNameStr, str, sizeof(platformNameStr) - 1)) {
175-
valuePtr = &decDescs.back().platformName;
176-
size = &decDescs.back().platformNameSize;
177-
str += sizeof(platformNameStr) - 1;
178-
} else if (0 == strncmp(platformVerStr, str, sizeof(platformVerStr) - 1)) {
179-
valuePtr = &decDescs.back().platformVer;
180-
size = &decDescs.back().platformVerSize;
181-
str += sizeof(platformVerStr) - 1;
182-
} else if (0 == strncmp(driverVerStr, str, sizeof(driverVerStr) - 1)) {
183-
valuePtr = &decDescs.back().devDriverVer;
184-
size = &decDescs.back().devDriverVerSize;
185-
str += sizeof(driverVerStr) - 1;
186-
} else {
187-
throw sycl::runtime_error("Unrecognized key in device allowlist",
188-
PI_INVALID_VALUE);
189-
}
143+
size_t pos = 0;
144+
size_t prev = pos;
145+
while (pos < allowList.size()) {
146+
if ((allowList.compare(pos, deviceName.size(), deviceName)) == 0) {
147+
prev = pos;
148+
if ((pos = allowList.find("{{", pos)) == std::string::npos) {
149+
throw sycl::runtime_error("Malformed syntax in SYCL_DEVICE_ALLOWLIST",
150+
PI_INVALID_VALUE);
151+
}
152+
if (pos > prev + deviceName.size()) {
153+
throw sycl::runtime_error("Malformed syntax in SYCL_DEVICE_ALLOWLIST",
154+
PI_INVALID_VALUE);
155+
}
190156

191-
if (':' != *str)
192-
throw sycl::runtime_error("Malformed device allowlist", PI_INVALID_VALUE);
157+
pos = pos + 2;
158+
size_t start = pos;
159+
if ((pos = allowList.find("}}", pos)) == std::string::npos) {
160+
throw sycl::runtime_error("Malformed syntax in SYCL_DEVICE_ALLOWLIST",
161+
PI_INVALID_VALUE);
162+
}
163+
decDescs.back().devName = allowList.substr(start, pos - start);
164+
pos = pos + 2;
193165

194-
// Skip ':'
195-
str += 1;
166+
if (allowList[pos] == ',') {
167+
pos++;
168+
}
169+
}
196170

197-
if ('{' != *str || '{' != *(str + 1))
198-
throw sycl::runtime_error("Malformed device allowlist", PI_INVALID_VALUE);
171+
else if ((allowList.compare(pos, driverVersion.size(), driverVersion)) ==
172+
0) {
173+
prev = pos;
174+
if ((pos = allowList.find("{{", pos)) == std::string::npos) {
175+
throw sycl::runtime_error("Malformed syntax in SYCL_DEVICE_ALLOWLIST",
176+
PI_INVALID_VALUE);
177+
}
178+
if (pos > prev + driverVersion.size()) {
179+
throw sycl::runtime_error("Malformed syntax in SYCL_DEVICE_ALLOWLIST",
180+
PI_INVALID_VALUE);
181+
}
199182

200-
// Skip opening sequence "{{"
201-
str += 2;
183+
size_t start = pos + 2;
184+
if ((pos = allowList.find("}}", pos)) == std::string::npos) {
185+
throw sycl::runtime_error("Malformed syntax in SYCL_DEVICE_ALLOWLIST",
186+
PI_INVALID_VALUE);
187+
}
188+
decDescs.back().devDriverVer = allowList.substr(start, pos - start);
189+
pos = pos + 2;
202190

203-
*valuePtr = str;
191+
if (allowList[pos] == ',') {
192+
pos++;
193+
}
194+
}
195+
196+
else if ((allowList.compare(pos, platformName.size(), platformName)) == 0) {
197+
prev = pos;
198+
if ((pos = allowList.find("{{", pos)) == std::string::npos) {
199+
throw sycl::runtime_error("Malformed syntax in SYCL_DEVICE_ALLOWLIST",
200+
PI_INVALID_VALUE);
201+
}
202+
if (pos > prev + platformName.size()) {
203+
throw sycl::runtime_error("Malformed syntax in SYCL_DEVICE_ALLOWLIST",
204+
PI_INVALID_VALUE);
205+
}
204206

205-
// Increment until closing sequence is encountered
206-
while (('\0' != *str) && ('}' != *str || '}' != *(str + 1)))
207-
++str;
207+
size_t start = pos + 2;
208+
if ((pos = allowList.find("}}", pos)) == std::string::npos) {
209+
throw sycl::runtime_error("Malformed syntax in SYCL_DEVICE_ALLOWLIST",
210+
PI_INVALID_VALUE);
211+
}
212+
decDescs.back().platName = allowList.substr(start, pos - start);
213+
pos = pos + 2;
208214

209-
if ('\0' == *str)
210-
throw sycl::runtime_error("Malformed device allowlist", PI_INVALID_VALUE);
215+
if (allowList[pos] == ',') {
216+
pos++;
217+
}
211218

212-
*size = str - *valuePtr;
219+
}
213220

214-
// Skip closing sequence "}}"
215-
str += 2;
221+
else if ((allowList.compare(pos, platformVersion.size(),
222+
platformVersion)) == 0) {
223+
prev = pos;
224+
if ((pos = allowList.find("{{", pos)) == std::string::npos) {
225+
throw sycl::runtime_error("Malformed syntax in SYCL_DEVICE_ALLOWLIST",
226+
PI_INVALID_VALUE);
227+
}
228+
if (pos > prev + platformVersion.size()) {
229+
throw sycl::runtime_error("Malformed syntax in SYCL_DEVICE_ALLOWLIST",
230+
PI_INVALID_VALUE);
231+
}
216232

217-
if ('\0' == *str)
218-
break;
233+
size_t start = pos + 2;
234+
if ((pos = allowList.find("}}", pos)) == std::string::npos) {
235+
throw sycl::runtime_error("Malformed syntax in SYCL_DEVICE_ALLOWLIST",
236+
PI_INVALID_VALUE);
237+
}
238+
decDescs.back().platVer = allowList.substr(start, pos - start);
239+
pos = pos + 2;
240+
}
219241

220-
// '|' means that the is another filter
221-
if ('|' == *str)
242+
else if (allowList.find('|', pos) != std::string::npos) {
243+
pos = allowList.find('|') + 1;
244+
while (allowList[pos] == ' ') {
245+
pos++;
246+
}
222247
decDescs.emplace_back();
223-
else if (',' != *str)
224-
throw sycl::runtime_error("Malformed device allowlist", PI_INVALID_VALUE);
248+
}
225249

226-
++str;
250+
else {
251+
throw sycl::runtime_error("Unrecognized key in device allowlist",
252+
PI_INVALID_VALUE);
253+
}
254+
} // while (pos <= allowList.size())
255+
return decDescs;
256+
}
257+
258+
std::vector<int> convertVersionString(std::string version) {
259+
// version string format is xx.yy.zzzzz
260+
std::vector<int> values;
261+
size_t pos = 0;
262+
size_t start = pos;
263+
if ((pos = version.find(".", pos)) == std::string::npos) {
264+
throw sycl::runtime_error("Malformed syntax in version string",
265+
PI_INVALID_VALUE);
266+
}
267+
values.push_back(std::stoi(version.substr(start, pos)));
268+
pos++;
269+
start = pos;
270+
if ((pos = version.find(".", pos)) == std::string::npos) {
271+
throw sycl::runtime_error("Malformed syntax in version string",
272+
PI_INVALID_VALUE);
227273
}
274+
values.push_back(std::stoi(version.substr(start, pos)));
275+
pos++;
276+
values.push_back(std::stoi(version.substr(pos)));
228277

229-
return decDescs;
278+
return values;
279+
}
280+
281+
enum MatchState { UNKNOWN, MATCH, NOMATCH };
282+
283+
MatchState matchVersions(std::string version1, std::string version2) {
284+
std::vector<int> v1 = convertVersionString(version1);
285+
std::vector<int> v2 = convertVersionString(version2);
286+
if (v1[0] >= v2[0] && v1[1] >= v2[1] && v1[2] >= v2[2]) {
287+
return MatchState::MATCH;
288+
} else {
289+
return MatchState::NOMATCH;
290+
}
230291
}
231292

232293
static void filterAllowList(vector_class<RT::PiDevice> &PiDevices,
@@ -235,6 +296,11 @@ static void filterAllowList(vector_class<RT::PiDevice> &PiDevices,
235296
if (AllowList.empty())
236297
return;
237298

299+
MatchState devNameState = UNKNOWN;
300+
MatchState devVerState = UNKNOWN;
301+
MatchState platNameState = UNKNOWN;
302+
MatchState platVerState = UNKNOWN;
303+
238304
const string_class PlatformName =
239305
sycl::detail::get_platform_info<string_class, info::platform::name>::get(
240306
PiPlatform, Plugin);
@@ -254,33 +320,57 @@ static void filterAllowList(vector_class<RT::PiDevice> &PiDevices,
254320
string_class, info::device::driver_version>::get(Device, Plugin);
255321

256322
for (const DevDescT &Desc : AllowList) {
257-
if (nullptr != Desc.platformName &&
258-
!std::regex_match(PlatformName,
259-
std::regex(std::string(Desc.platformName,
260-
Desc.platformNameSize))))
261-
continue;
262-
263-
if (nullptr != Desc.platformVer &&
264-
!std::regex_match(
265-
PlatformVer,
266-
std::regex(std::string(Desc.platformVer, Desc.platformVerSize))))
267-
continue;
268-
269-
if (nullptr != Desc.devName &&
270-
!std::regex_match(DeviceName, std::regex(std::string(
271-
Desc.devName, Desc.devNameSize))))
272-
continue;
273-
274-
if (nullptr != Desc.devDriverVer &&
275-
!std::regex_match(DeviceDriverVer,
276-
std::regex(std::string(Desc.devDriverVer,
277-
Desc.devDriverVerSize))))
278-
continue;
323+
if (!Desc.platName.empty()) {
324+
if (!std::regex_match(PlatformName, std::regex(Desc.platName))) {
325+
platNameState = MatchState::NOMATCH;
326+
continue;
327+
} else {
328+
platNameState = MatchState::MATCH;
329+
}
330+
}
331+
332+
if (!Desc.platVer.empty()) {
333+
if (!std::regex_match(PlatformVer, std::regex(Desc.platVer))) {
334+
platVerState = MatchState::NOMATCH;
335+
continue;
336+
} else {
337+
platVerState = MatchState::MATCH;
338+
}
339+
}
340+
341+
if (!Desc.devName.empty()) {
342+
if (!std::regex_match(DeviceName, std::regex(Desc.devName))) {
343+
devNameState = MatchState::NOMATCH;
344+
continue;
345+
} else {
346+
devNameState = MatchState::MATCH;
347+
}
348+
}
349+
350+
if (!Desc.devDriverVer.empty()) {
351+
if (!std::regex_match(DeviceDriverVer, std::regex(Desc.devDriverVer))) {
352+
devVerState = matchVersions(DeviceDriverVer, Desc.devDriverVer);
353+
if (devVerState == MatchState::NOMATCH) {
354+
continue;
355+
}
356+
} else {
357+
devVerState = MatchState::MATCH;
358+
}
359+
}
279360

280361
PiDevices[InsertIDx++] = Device;
281362
break;
282363
}
283364
}
365+
if (devNameState == MatchState::MATCH && devVerState == MatchState::NOMATCH) {
366+
throw sycl::runtime_error("Requested SYCL device not found",
367+
PI_DEVICE_NOT_FOUND);
368+
}
369+
if (platNameState == MatchState::MATCH &&
370+
platVerState == MatchState::NOMATCH) {
371+
throw sycl::runtime_error("Requested SYCL platform not found",
372+
PI_DEVICE_NOT_FOUND);
373+
}
284374
PiDevices.resize(InsertIDx);
285375
}
286376

0 commit comments

Comments
 (0)