Skip to content

[SYCL] Additional support for SYCL_DEVICE_ALLOWLIST #2483

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Oct 1, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sycl/doc/EnvironmentVariables.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ subject to change. Do not rely on these variables in production code.
| SYCL_DISABLE_EXECUTION_GRAPH_CLEANUP | Any(\*) | Disable cleanup of finished command nodes at host-device synchronization points. |
| SYCL_THROW_ON_BLOCK | Any(\*) | Throw an exception on attempt to wait for a blocked command. |
| SYCL_DEVICELIB_INHIBIT_NATIVE | String of device library extensions (separated by a whitespace) | Do not rely on device native support for devicelib extensions listed in this option. |
| SYCL_DEVICE_ALLOWLIST | A list of devices and their minimum driver version following the pattern: DeviceName:{{XXX}},DriverVersion:{{X.Y.Z.W}}. Also may contain PlatformName and PlatformVersion | Filter out devices that do not match the pattern specified. Regular expression can be passed and the DPC++ runtime will select only those devices which satisfy the regex. |
| SYCL_DEVICE_ALLOWLIST | A list of devices and their minimum driver version following the pattern: DeviceName:{{XXX}},DriverVersion:{{X.Y.Z.W}}. Also may contain PlatformName and PlatformVersion | Filter out devices that do not match the pattern specified. Regular expression can be passed and the DPC++ runtime will select only those devices which satisfy the regex. Note that the device name, platform name and their respective versions are regular expression. Special characters, such as parenthesis, must be escaped. |
| SYCL_QUEUE_THREAD_POOL_SIZE | Positive integer | Number of threads in thread pool of queue. |
| SYCL_DEVICELIB_NO_FALLBACK | Any(\*) | Disable loading and linking of device library images |
| SYCL_PI_LEVEL0_MAX_COMMAND_LIST_CACHE | Positive integer | Maximum number of oneAPI Level Zero Command lists that can be allocated with no reuse before throwing an "out of resources" error. Default is 20000, threshold may be increased based on resource availabilty and workload demand. |
Expand Down
285 changes: 196 additions & 89 deletions sycl/source/detail/platform_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <algorithm>
#include <cstring>
#include <regex>
#include <string>

__SYCL_INLINE_NAMESPACE(cl) {
namespace sycl {
Expand Down Expand Up @@ -121,95 +122,172 @@ vector_class<platform> platform_impl::get_platforms() {
}

struct DevDescT {
const char *devName = nullptr;
int devNameSize = 0;
const char *devDriverVer = nullptr;
int devDriverVerSize = 0;

const char *platformName = nullptr;
int platformNameSize = 0;

const char *platformVer = nullptr;
int platformVerSize = 0;
std::string devName;
std::string devDriverVer;
std::string platName;
std::string platVer;
};

static std::vector<DevDescT> getAllowListDesc() {
const char *str = SYCLConfig<SYCL_DEVICE_ALLOWLIST>::get();
if (!str)
std::string allowList(SYCLConfig<SYCL_DEVICE_ALLOWLIST>::get());
if (allowList.empty())
return {};

std::string deviceName("DeviceName:");
std::string driverVersion("DriverVersion:");
std::string platformName("PlatformName:");
std::string platformVersion("PlatformVersion:");
std::vector<DevDescT> decDescs;
const char devNameStr[] = "DeviceName";
const char driverVerStr[] = "DriverVersion";
const char platformNameStr[] = "PlatformName";
const char platformVerStr[] = "PlatformVersion";
decDescs.emplace_back();
while ('\0' != *str) {
const char **valuePtr = nullptr;
int *size = nullptr;

// -1 to avoid comparing null terminator
if (0 == strncmp(devNameStr, str, sizeof(devNameStr) - 1)) {
valuePtr = &decDescs.back().devName;
size = &decDescs.back().devNameSize;
str += sizeof(devNameStr) - 1;
} else if (0 ==
strncmp(platformNameStr, str, sizeof(platformNameStr) - 1)) {
valuePtr = &decDescs.back().platformName;
size = &decDescs.back().platformNameSize;
str += sizeof(platformNameStr) - 1;
} else if (0 == strncmp(platformVerStr, str, sizeof(platformVerStr) - 1)) {
valuePtr = &decDescs.back().platformVer;
size = &decDescs.back().platformVerSize;
str += sizeof(platformVerStr) - 1;
} else if (0 == strncmp(driverVerStr, str, sizeof(driverVerStr) - 1)) {
valuePtr = &decDescs.back().devDriverVer;
size = &decDescs.back().devDriverVerSize;
str += sizeof(driverVerStr) - 1;
} else {
throw sycl::runtime_error("Unrecognized key in device allowlist",
PI_INVALID_VALUE);
}

if (':' != *str)
throw sycl::runtime_error("Malformed device allowlist", PI_INVALID_VALUE);
size_t pos = 0;
size_t prev = pos;
while (pos < allowList.size()) {
if ((allowList.compare(pos, deviceName.size(), deviceName)) == 0) {
prev = pos;
if ((pos = allowList.find("{{", pos)) == std::string::npos) {
throw sycl::runtime_error("Malformed syntax in SYCL_DEVICE_ALLOWLIST",
PI_INVALID_VALUE);
}
if (pos > prev + deviceName.size()) {
throw sycl::runtime_error("Malformed syntax in SYCL_DEVICE_ALLOWLIST",
PI_INVALID_VALUE);
}

// Skip ':'
str += 1;
pos = pos + 2;
size_t start = pos;
if ((pos = allowList.find("}}", pos)) == std::string::npos) {
throw sycl::runtime_error("Malformed syntax in SYCL_DEVICE_ALLOWLIST",
PI_INVALID_VALUE);
}
decDescs.back().devName = allowList.substr(start, pos - start);
pos = pos + 2;

if (allowList[pos] == ',') {
pos++;
}
}

if ('{' != *str || '{' != *(str + 1))
throw sycl::runtime_error("Malformed device allowlist", PI_INVALID_VALUE);
else if ((allowList.compare(pos, driverVersion.size(), driverVersion)) ==
0) {
prev = pos;
if ((pos = allowList.find("{{", pos)) == std::string::npos) {
throw sycl::runtime_error("Malformed syntax in SYCL_DEVICE_ALLOWLIST",
PI_INVALID_VALUE);
}
if (pos > prev + driverVersion.size()) {
throw sycl::runtime_error("Malformed syntax in SYCL_DEVICE_ALLOWLIST",
PI_INVALID_VALUE);
}

// Skip opening sequence "{{"
str += 2;
size_t start = pos + 2;
if ((pos = allowList.find("}}", pos)) == std::string::npos) {
throw sycl::runtime_error("Malformed syntax in SYCL_DEVICE_ALLOWLIST",
PI_INVALID_VALUE);
}
decDescs.back().devDriverVer = allowList.substr(start, pos - start);
pos = pos + 2;

*valuePtr = str;
if (allowList[pos] == ',') {
pos++;
}
}

// Increment until closing sequence is encountered
while (('\0' != *str) && ('}' != *str || '}' != *(str + 1)))
++str;
else if ((allowList.compare(pos, platformName.size(), platformName)) == 0) {
prev = pos;
if ((pos = allowList.find("{{", pos)) == std::string::npos) {
throw sycl::runtime_error("Malformed syntax in SYCL_DEVICE_ALLOWLIST",
PI_INVALID_VALUE);
}
if (pos > prev + platformName.size()) {
throw sycl::runtime_error("Malformed syntax in SYCL_DEVICE_ALLOWLIST",
PI_INVALID_VALUE);
}

if ('\0' == *str)
throw sycl::runtime_error("Malformed device allowlist", PI_INVALID_VALUE);
size_t start = pos + 2;
if ((pos = allowList.find("}}", pos)) == std::string::npos) {
throw sycl::runtime_error("Malformed syntax in SYCL_DEVICE_ALLOWLIST",
PI_INVALID_VALUE);
}
decDescs.back().platName = allowList.substr(start, pos - start);
pos = pos + 2;

*size = str - *valuePtr;
if (allowList[pos] == ',') {
pos++;
}

// Skip closing sequence "}}"
str += 2;
}

if ('\0' == *str)
break;
else if ((allowList.compare(pos, platformVersion.size(),
platformVersion)) == 0) {
prev = pos;
if ((pos = allowList.find("{{", pos)) == std::string::npos) {
throw sycl::runtime_error("Malformed syntax in SYCL_DEVICE_ALLOWLIST",
PI_INVALID_VALUE);
}
if (pos > prev + platformVersion.size()) {
throw sycl::runtime_error("Malformed syntax in SYCL_DEVICE_ALLOWLIST",
PI_INVALID_VALUE);
}

// '|' means that the is another filter
if ('|' == *str)
size_t start = pos + 2;
if ((pos = allowList.find("}}", pos)) == std::string::npos) {
throw sycl::runtime_error("Malformed syntax in SYCL_DEVICE_ALLOWLIST",
PI_INVALID_VALUE);
}
decDescs.back().platVer = allowList.substr(start, pos - start);
pos = pos + 2;
}

else if (allowList.find('|', pos) != std::string::npos) {
pos = allowList.find('|') + 1;
while (allowList[pos] == ' ') {
pos++;
}
decDescs.emplace_back();
else if (',' != *str)
throw sycl::runtime_error("Malformed device allowlist", PI_INVALID_VALUE);
}

++str;
else {
throw sycl::runtime_error("Unrecognized key in device allowlist",
PI_INVALID_VALUE);
}
} // while (pos <= allowList.size())
return decDescs;
}

std::vector<int> convertVersionString(std::string version) {
// version string format is xx.yy.zzzzz
std::vector<int> values;
size_t pos = 0;
size_t start = pos;
if ((pos = version.find(".", pos)) == std::string::npos) {
throw sycl::runtime_error("Malformed syntax in version string",
PI_INVALID_VALUE);
}
values.push_back(std::stoi(version.substr(start, pos)));
pos++;
start = pos;
if ((pos = version.find(".", pos)) == std::string::npos) {
throw sycl::runtime_error("Malformed syntax in version string",
PI_INVALID_VALUE);
}
values.push_back(std::stoi(version.substr(start, pos)));
pos++;
values.push_back(std::stoi(version.substr(pos)));

return decDescs;
return values;
}

enum MatchState { UNKNOWN, MATCH, NOMATCH };

MatchState matchVersions(std::string version1, std::string version2) {
std::vector<int> v1 = convertVersionString(version1);
std::vector<int> v2 = convertVersionString(version2);
if (v1[0] >= v2[0] && v1[1] >= v2[1] && v1[2] >= v2[2]) {
return MatchState::MATCH;
} else {
return MatchState::NOMATCH;
}
}

static void filterAllowList(vector_class<RT::PiDevice> &PiDevices,
Expand All @@ -218,6 +296,11 @@ static void filterAllowList(vector_class<RT::PiDevice> &PiDevices,
if (AllowList.empty())
return;

MatchState devNameState = UNKNOWN;
MatchState devVerState = UNKNOWN;
MatchState platNameState = UNKNOWN;
MatchState platVerState = UNKNOWN;

const string_class PlatformName =
sycl::detail::get_platform_info<string_class, info::platform::name>::get(
PiPlatform, Plugin);
Expand All @@ -237,33 +320,57 @@ static void filterAllowList(vector_class<RT::PiDevice> &PiDevices,
string_class, info::device::driver_version>::get(Device, Plugin);

for (const DevDescT &Desc : AllowList) {
if (nullptr != Desc.platformName &&
!std::regex_match(PlatformName,
std::regex(std::string(Desc.platformName,
Desc.platformNameSize))))
continue;

if (nullptr != Desc.platformVer &&
!std::regex_match(
PlatformVer,
std::regex(std::string(Desc.platformVer, Desc.platformVerSize))))
continue;

if (nullptr != Desc.devName &&
!std::regex_match(DeviceName, std::regex(std::string(
Desc.devName, Desc.devNameSize))))
continue;

if (nullptr != Desc.devDriverVer &&
!std::regex_match(DeviceDriverVer,
std::regex(std::string(Desc.devDriverVer,
Desc.devDriverVerSize))))
continue;
if (!Desc.platName.empty()) {
if (!std::regex_match(PlatformName, std::regex(Desc.platName))) {
platNameState = MatchState::NOMATCH;
continue;
} else {
platNameState = MatchState::MATCH;
}
}

if (!Desc.platVer.empty()) {
if (!std::regex_match(PlatformVer, std::regex(Desc.platVer))) {
platVerState = MatchState::NOMATCH;
continue;
} else {
platVerState = MatchState::MATCH;
}
}

if (!Desc.devName.empty()) {
if (!std::regex_match(DeviceName, std::regex(Desc.devName))) {
devNameState = MatchState::NOMATCH;
continue;
} else {
devNameState = MatchState::MATCH;
}
}

if (!Desc.devDriverVer.empty()) {
if (!std::regex_match(DeviceDriverVer, std::regex(Desc.devDriverVer))) {
devVerState = matchVersions(DeviceDriverVer, Desc.devDriverVer);
if (devVerState == MatchState::NOMATCH) {
continue;
}
} else {
devVerState = MatchState::MATCH;
}
}

PiDevices[InsertIDx++] = Device;
break;
}
}
if (devNameState == MatchState::MATCH && devVerState == MatchState::NOMATCH) {
throw sycl::runtime_error("Requested SYCL device not found",
PI_DEVICE_NOT_FOUND);
}
if (platNameState == MatchState::MATCH &&
platVerState == MatchState::NOMATCH) {
throw sycl::runtime_error("Requested SYCL platform not found",
PI_DEVICE_NOT_FOUND);
}
PiDevices.resize(InsertIDx);
}

Expand Down
Loading