Skip to content

[SYCL] Add support for platform name and platform version in device w… #890

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 42 additions & 5 deletions sycl/source/detail/platform_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ struct DevDescT {

const char *devDriverVer = nullptr;
int devDriverVerSize = 0;

const char *platformName = nullptr;
int platformNameSize = 0;

const char *platformVer = nullptr;
int platformVerSize = 0;
};

static std::vector<DevDescT> getWhiteListDesc() {
Expand All @@ -68,6 +74,8 @@ static std::vector<DevDescT> getWhiteListDesc() {
std::vector<DevDescT> decDescs;
const char devNameStr[] = "DeviceName";
const char driverVerStr[] = "DriverVersion";
const char platformNameStr[] = "PlatformName";
const char platformVerStr[] = "PlatformVersion";
decDescs.emplace_back();
while ('\0' != *str) {
const char **valuePtr = nullptr;
Expand All @@ -78,6 +86,15 @@ static std::vector<DevDescT> getWhiteListDesc() {
valuePtr = &decDescs.back().devName;
size = &decDescs.back().devNameSize;
str += sizeof(devNameStr) - 1;
} else if (0 ==
strncmp(platformNameStr, str, sizeof(platformNameStr) - 1)) {
valuePtr = &decDescs.back().platformName;
size = &decDescs.back().platformNameSize;
str += sizeof(platformNameStr) - 1;
} else if (0 == strncmp(platformVerStr, str, sizeof(platformVerStr) - 1)) {
valuePtr = &decDescs.back().platformVer;
size = &decDescs.back().platformVerSize;
str += sizeof(platformVerStr) - 1;
} else if (0 == strncmp(driverVerStr, str, sizeof(driverVerStr) - 1)) {
valuePtr = &decDescs.back().devDriverVer;
size = &decDescs.back().devDriverVerSize;
Expand Down Expand Up @@ -125,23 +142,43 @@ static std::vector<DevDescT> getWhiteListDesc() {
return decDescs;
}

static void filterWhiteList(vector_class<RT::PiDevice> &pi_devices) {
static void filterWhiteList(vector_class<RT::PiDevice> &pi_devices,
RT::PiPlatform pi_platform) {
const std::vector<DevDescT> whiteList(getWhiteListDesc());
if (whiteList.empty())
return;

const string_class platformName =
sycl::detail::get_platform_info<string_class, info::platform::name>::get(
pi_platform);

const string_class platformVer = sycl::detail::get_platform_info<
string_class, info::platform::version>::get(pi_platform);

int insertIDx = 0;
for (RT::PiDevice dev : pi_devices) {
const string_class devName =
sycl::detail::get_device_info<string_class, info::device::name>::get(dev);
sycl::detail::get_device_info<string_class, info::device::name>::get(
dev);

const string_class devDriverVer =
sycl::detail::get_device_info<string_class,
info::device::driver_version>::get(dev);

for (const DevDescT &desc : whiteList) {
// At least device name is required field to consider the filter so far
if (nullptr == desc.devName ||
if (nullptr != desc.platformName &&
!std::regex_match(platformName,
std::regex(std::string(desc.platformName,
desc.platformNameSize))))
continue;

if (nullptr != desc.platformVer &&
!std::regex_match(
platformVer,
std::regex(std::string(desc.platformVer, desc.platformVerSize))))
continue;

if (nullptr != desc.devName &&
!std::regex_match(
devName, std::regex(std::string(desc.devName, desc.devNameSize))))
continue;
Expand Down Expand Up @@ -179,7 +216,7 @@ platform_impl_pi::get_devices(info::device_type deviceType) const {

// Filter out devices that are not present in the white list
if (SYCLConfig<SYCL_DEVICE_WHITE_LIST>::get())
filterWhiteList(pi_devices);
filterWhiteList(pi_devices, m_platform);

std::for_each(pi_devices.begin(), pi_devices.end(),
[&res](const RT::PiDevice &a_pi_device) {
Expand Down
76 changes: 53 additions & 23 deletions sycl/test/config/white_list.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
// REQUIRES: cpu
// RUN: %clangxx -fsycl %s -o %t.out
// RUN: env PRINT_DEVICE_INFO=1 %t.out > %t.conf
// RUN: env TEST_DEVICE_AVAILABLE=1 env SYCL_CONFIG_FILE_NAME=%t.conf %t.out
// RUN: env TEST_DEVICE_IS_NOT_AVAILABLE=1 env SYCL_DEVICE_WHITE_LIST="" %t.out
//
// RUN: env PRINT_DEVICE_INFO=1 %t.out > %t1.conf
// RUN: env TEST_DEVICE_AVAILABLE=1 env SYCL_CONFIG_FILE_NAME=%t1.conf %t.out
//
// RUN: env PRINT_PLATFORM_INFO=1 %t.out > %t2.conf
// RUN: env TEST_DEVICE_AVAILABLE=1 env SYCL_CONFIG_FILE_NAME=%t2.conf %t.out
//
// RUN: env TEST_DEVICE_IS_NOT_AVAILABLE=1 env SYCL_DEVICE_WHITE_LIST="PlatformName:{{SUCH NAME DOESN'T EXIST}}" %t.out

#include <CL/sycl.hpp>
#include <iostream>
Expand All @@ -12,35 +17,60 @@

using namespace cl;

static void replaceSpecialCharacters(std::string &Str) {
// Replace common special symbols with '.' which matches to any character
std::replace_if(Str.begin(), Str.end(),
[](const char Sym) { return '(' == Sym || ')' == Sym; }, '.');
}

int main() {

// Expected that white list filter is not set
if (getenv("PRINT_PLATFORM_INFO")) {
for (const sycl::platform &Platform : sycl::platform::get_platforms())
if (!Platform.is_host()) {

std::string Name = Platform.get_info<sycl::info::platform::name>();
std::string Ver = Platform.get_info<sycl::info::platform::version>();
// As a string will be used as regexp pattern, we need to get rid of
// symbols that can be treated in a special way.
replaceSpecialCharacters(Name);
replaceSpecialCharacters(Ver);

std::cout << "SYCL_DEVICE_WHITE_LIST=PlatformName:{{" << Name
<< "}},PlatformVersion:{{" << Ver << "}}";

return 0;
}
throw std::runtime_error("Non host device is not found");
}

// Expected that white list filter is not set
if (getenv("PRINT_DEVICE_INFO")) {
for (const sycl::platform &Plt : sycl::platform::get_platforms())
if (!Plt.is_host()) {
const sycl::device Dev = Plt.get_devices().at(0);
std::string DevName = Dev.get_info<sycl::info::device::name>();
const std::string DevVer =
Dev.get_info<sycl::info::device::driver_version>();
// As device name string will be used as regexp pattern, we need to
// get rid of symbols that can be treated in a special way.
// Replace common special symbols with '.' which matches to any sybmol
for (char &Sym : DevName) {
if (')' == Sym || '(' == Sym)
Sym = '.';
}
std::cout << "SYCL_DEVICE_WHITE_LIST=DeviceName:{{" << DevName
<< "}},DriverVersion:{{" << DevVer << "}}";
for (const sycl::platform &Platform : sycl::platform::get_platforms())
if (!Platform.is_host()) {
const sycl::device Dev = Platform.get_devices().at(0);
std::string Name = Dev.get_info<sycl::info::device::name>();
std::string Ver = Dev.get_info<sycl::info::device::driver_version>();

// As a string will be used as regexp pattern, we need to get rid of
// symbols that can be treated in a special way.
replaceSpecialCharacters(Name);
replaceSpecialCharacters(Ver);

std::cout << "SYCL_DEVICE_WHITE_LIST=DeviceName:{{" << Name
<< "}},DriverVersion:{{" << Ver << "}}";

return 0;
}
throw std::runtime_error("Non host device is not found");
}

// Expected white list to be set with result from "PRINT_DEVICE_INFO" run
if (getenv("TEST_DEVICE_AVAILABLE")) {
for (const sycl::platform &Plt : sycl::platform::get_platforms())
if (!Plt.is_host()) {
if (Plt.get_devices().size() != 1)
for (const sycl::platform &Platform : sycl::platform::get_platforms())
if (!Platform.is_host()) {
if (Platform.get_devices().size() != 1)
throw std::runtime_error("Expected only one non host device.");

return 0;
Expand All @@ -50,8 +80,8 @@ int main() {

// Expected white list to be set but empty
if (getenv("TEST_DEVICE_IS_NOT_AVAILABLE")) {
for (const sycl::platform &Plt : sycl::platform::get_platforms())
if (!Plt.is_host())
for (const sycl::platform &Platform : sycl::platform::get_platforms())
if (!Platform.is_host())
throw std::runtime_error("Expected no non host device is available");
return 0;
}
Expand Down