Skip to content

[SYCL] Add support for devices white list #867

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions sycl/source/detail/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,11 @@ void readConfig() {
// Prints configs name with their value
void dumpConfig() {
#define CONFIG(Name, MaxSize, CompileTimeDef) \
const char *Val = SYCLConfig<Name>::get(); \
std::cerr << SYCLConfigBase<Name>::MConfigName << " : " \
<< (Val ? Val : "unset") << std::endl;
{ \
const char *Val = SYCLConfig<Name>::get(); \
std::cerr << SYCLConfigBase<Name>::MConfigName << " : " \
<< (Val ? Val : "unset") << std::endl; \
}
#include "detail/config.def"
#undef CONFIG
}
Expand Down
1 change: 1 addition & 0 deletions sycl/source/detail/config.def
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@
// underscore(__).

CONFIG(SYCL_PRINT_EXECUTION_GRAPH, 32, __SYCL_PRINT_EXECUTION_GRAPH)
CONFIG(SYCL_DEVICE_WHITE_LIST, 1024, __SYCL_DEVICE_WHITE_LIST)

114 changes: 114 additions & 0 deletions sycl/source/detail/platform_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@
#include <CL/sycl/detail/device_impl.hpp>
#include <CL/sycl/detail/platform_impl.hpp>
#include <CL/sycl/device.hpp>
#include <detail/config.hpp>

#include <algorithm>
#include <cstring>
#include <regex>

namespace cl {
namespace sycl {
Expand Down Expand Up @@ -49,6 +52,113 @@ platform_impl_host::get_devices(info::device_type dev_type) const {
return res;
}

struct DevDescT {
const char *devName = nullptr;
int devNameSize = 0;

const char *devDriverVer = nullptr;
int devDriverVerSize = 0;
};

static std::vector<DevDescT> getWhiteListDesc() {
const char *str = SYCLConfig<SYCL_DEVICE_WHITE_LIST>::get();
if (!str)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like to check this condition as early as possible i.e. before calling filterWhiteList to avoid unnecessary overhead.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added additional check before calling filterWhiteList

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like to check this condition as early as possible i.e. before calling filterWhiteList to avoid unnecessary overhead.

This is already the first operation in filterWhiteList, which overhead you are trying to avoid?

Copy link
Contributor

@bader bader Nov 26, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

vector creation/deletion + function calls.

return {};

std::vector<DevDescT> decDescs;
const char devNameStr[] = "DeviceName";
const char driverVerStr[] = "DriverVersion";
decDescs.emplace_back();
while ('\0' != *str) {
const char **valuePtr = nullptr;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It smells badly in a C++ code...

int *size = nullptr;

// -1 to avoid comparing null terminator
if (0 == strncmp(devNameStr, str, sizeof(devNameStr) - 1)) {
valuePtr = &decDescs.back().devName;
size = &decDescs.back().devNameSize;
str += sizeof(devNameStr) - 1;
} else if (0 == strncmp(driverVerStr, str, sizeof(driverVerStr) - 1)) {
valuePtr = &decDescs.back().devDriverVer;
size = &decDescs.back().devDriverVerSize;
str += sizeof(driverVerStr) - 1;
}

if (':' != *str)
throw sycl::runtime_error("Malformed device white list");

// Skip ':'
str += 1;

if ('{' != *str || '{' != *(str + 1))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should check that *str != '\0'.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it is actually not needed here.

throw sycl::runtime_error("Malformed device white list");

// Skip opening sequence "{{"
str += 2;

*valuePtr = str;

// Increment until closing sequence is encountered
while (('\0' != *str) && ('}' != *str || '}' != *(str + 1)))
++str;

if ('\0' == *str)
throw sycl::runtime_error("Malformed device white list");

*size = str - *valuePtr;

// Skip closing sequence "}}"
str += 2;

if ('\0' == *str)
break;

// '|' means that the is another filter
if ('|' == *str)
decDescs.emplace_back();
else if (',' != *str)
throw sycl::runtime_error("Malformed device white list");

++str;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to rewrite this loop in something higher-level so it is more understandable?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will try.


return decDescs;
}

static void filterWhiteList(vector_class<RT::PiDevice> &pi_devices) {
const std::vector<DevDescT> whiteList(getWhiteListDesc());
if (whiteList.empty())
return;

int insertIDx = 0;
for (RT::PiDevice dev : pi_devices) {
const string_class devName =
sycl::detail::get_device_info<string_class, info::device::name>::_(dev);

const string_class devDriverVer =
sycl::detail::get_device_info<string_class,
info::device::driver_version>::_(dev);

for (const DevDescT &desc : whiteList) {
// At least device name is required field to consider the filter so far
if (nullptr == desc.devName ||
!std::regex_match(
devName, std::regex(std::string(desc.devName, desc.devNameSize))))
continue;

if (nullptr != desc.devDriverVer &&
!std::regex_match(devDriverVer,
std::regex(std::string(desc.devDriverVer,
desc.devDriverVerSize))))
continue;

pi_devices[insertIDx++] = dev;
break;
}
}
pi_devices.resize(insertIDx);
}

vector_class<device>
platform_impl_pi::get_devices(info::device_type deviceType) const {
vector_class<device> res;
Expand All @@ -67,6 +177,10 @@ platform_impl_pi::get_devices(info::device_type deviceType) const {
PI_CALL(piDevicesGet)(m_platform, pi::cast<RT::PiDeviceType>(deviceType),
num_devices, pi_devices.data(), nullptr);

// Filter out devices that are not present in the white list
if (SYCLConfig<SYCL_DEVICE_WHITE_LIST>::get())
filterWhiteList(pi_devices);

std::for_each(pi_devices.begin(), pi_devices.end(),
[&res](const RT::PiDevice &a_pi_device) {
device sycl_device = detail::createSyclObjFromImpl<device>(
Expand Down
60 changes: 60 additions & 0 deletions sycl/test/config/white_list.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// REQUIRES: cpu
// RUN: %clangxx -fsycl %s -o %t.out
// RUN: env PRINT_DEVICE_INFO=1 %t.out > %t.conf
// RUN: env TEST_DEVICE_AVAILABLE=1 env SYCL_CONFIG_FILE_NAME=%t.conf %t.out
// RUN: env TEST_DEVICE_IS_NOT_AVAILABLE=1 env SYCL_DEVICE_WHITE_LIST="" %t.out

#include <CL/sycl.hpp>
#include <iostream>
#include <cstdlib>
#include <exception>
#include <string>

using namespace cl;

int main() {

// Expected that white list filter is not set
if (getenv("PRINT_DEVICE_INFO")) {
for (const sycl::platform &Plt : sycl::platform::get_platforms())
if (!Plt.is_host()) {
const sycl::device Dev = Plt.get_devices().at(0);
std::string DevName = Dev.get_info<sycl::info::device::name>();
const std::string DevVer =
Dev.get_info<sycl::info::device::driver_version>();
// As device name string will be used as regexp pattern, we need to
// get rid of symbols that can be treated in a special way.
// Replace common special symbols with '.' which matches to any sybmol
for (char &Sym : DevName) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about using algorithms instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, will replace with:

std::replace_if(DevName.begin(), DevName.end(),
                       [](const char Sym) { return '(' == Sym || ')' == Sym; }, '.');

Speaking honestly, I think loop code is more readable.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking to a simpler:

std::replace(DevName.begin(), DevName.end(), '(', '.');
std::replace(DevName.begin(), DevName.end(), ')', '.');

since everything is in the cache, it should be fast.

Copy link
Contributor Author

@romanovvlad romanovvlad Nov 29, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand that it's a test but I don't think we should sacrifice performance in order to make C++ code more readable...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is very subjective.

for (char &Sym : DevName) {
  if (')' == Sym || '(' == Sym)
    Sym = '.';
}

IMO this code is more readable than:

std::replace(DevName.begin(), DevName.end(), '(', '.');
std::replace(DevName.begin(), DevName.end(), ')', '.');

if (')' == Sym || '(' == Sym)
Sym = '.';
}
std::cout << "SYCL_DEVICE_WHITE_LIST=DeviceName:{{" << DevName
<< "}},DriverVersion:{{" << DevVer << "}}";
return 0;
}
throw std::runtime_error("Non host device is not found");
}

// Expected white list to be set with result from "PRINT_DEVICE_INFO" run
if (getenv("TEST_DEVICE_AVAILABLE")) {
for (const sycl::platform &Plt : sycl::platform::get_platforms())
if (!Plt.is_host()) {
if (Plt.get_devices().size() != 1)
throw std::runtime_error("Expected only one non host device.");

return 0;
}
throw std::runtime_error("Non host device is not found");
}

// Expected white list to be set but empty
if (getenv("TEST_DEVICE_IS_NOT_AVAILABLE")) {
for (const sycl::platform &Plt : sycl::platform::get_platforms())
if (!Plt.is_host())
throw std::runtime_error("Expected no non host device is available");
return 0;
}

throw std::runtime_error("Unhandled situation");
}