Skip to content

Commit 5c30ab7

Browse files
committed
SYCL_DEVICE_ALLOWLIST fix. WIP
Signed-off-by: Gail Lyons <[email protected]>
1 parent 129ee44 commit 5c30ab7

File tree

2 files changed

+248
-2
lines changed

2 files changed

+248
-2
lines changed

sycl/source/detail/platform_impl.cpp

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,8 @@ struct DevDescT {
134134
};
135135

136136
static std::vector<DevDescT> getAllowListDesc() {
137-
const char *str = SYCLConfig<SYCL_DEVICE_ALLOWLIST>::get();
138-
if (!str)
137+
const char *Str = SYCLConfig<SYCL_DEVICE_ALLOWLIST>::get();
138+
if (!Str)
139139
return {};
140140

141141
std::vector<DevDescT> decDescs;
@@ -144,6 +144,23 @@ static std::vector<DevDescT> getAllowListDesc() {
144144
const char platformNameStr[] = "PlatformName";
145145
const char platformVerStr[] = "PlatformVersion";
146146
decDescs.emplace_back();
147+
148+
std::cout << "Before: " << Str << std::endl;
149+
150+
// Replace common special symbols with '.' which matches to any character
151+
#if 0 // gail
152+
std::string tmp(Str);
153+
std::replace_if(tmp.begin(), tmp.end(),
154+
[](const char sym) { return '(' == sym || ')' == sym; }, '.');
155+
const char * str = tmp.c_str();
156+
#endif //gail
157+
158+
std::string tmp(Str);
159+
std::replace(tmp.begin(), tmp.end(), '(', '.');
160+
std::replace(tmp.begin(), tmp.end(), ')', '.');
161+
const char * str = tmp.c_str();
162+
std::cout << "After : " << str << std::endl;
163+
147164
while ('\0' != *str) {
148165
const char **valuePtr = nullptr;
149166
int *size = nullptr;

sycl/test/config/select_device.cpp

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
// RUN: %clangxx -fsycl %s -o %t.out
2+
//
3+
// RUN: env WRITE_DEVICE_INFO=1 %t.out
4+
// RUN: env READ_DEVICE_INFO=1 %t.out
5+
//
6+
// RUN: env WRITE_PLATFORM_INFO=1 %t.out
7+
// RUN: env READ_PLATFORM_INFO=1 %t.out
8+
9+
//==------------ select_device.cpp - SYCL_DEVICE_ALLOWLIST test ------------==//
10+
//
11+
// This test is unusual because it occurs in two phases. The first phase
12+
// will find the GPU platforms, and write them to a file. The second phase
13+
// will read the file, set SYCL_DEVICE_ALLOWLIST, and then find the correct
14+
// platform. SYCL_DEVICE_ALLOWLIST is only evaluated once, the first time
15+
// get_platforms() is called. Setting it later in the application has no
16+
// effect.
17+
//
18+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
19+
// See https://llvm.org/LICENSE.txt for license information.
20+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
21+
//
22+
//===----------------------------------------------------------------------===//
23+
24+
#include <CL/sycl.hpp>
25+
#include <iostream>
26+
#include <fstream>
27+
#include <string>
28+
29+
using namespace cl::sycl;
30+
31+
#ifdef _WIN32
32+
#define setenv(name, value, overwrite) _putenv_s (name, value)
33+
#endif
34+
35+
struct DevDescT {
36+
std::string devName;
37+
std::string devDriverVer;
38+
std::string platName;
39+
std::string platVer;
40+
};
41+
42+
static std::vector<DevDescT> getAllowListDesc(std::string allowList) {
43+
if (allowList.empty())
44+
return {};
45+
46+
std::string deviceName("DeviceName:");
47+
std::string driverVersion("DriverVersion:");
48+
std::string platformName("PlatformName:");
49+
std::string platformVersion("PlatformVersion:");
50+
std::vector<DevDescT> decDescs;
51+
52+
size_t pos = 0;
53+
while ( pos <= allowList.size()) {
54+
decDescs.emplace_back();
55+
56+
if ((allowList.compare(pos, deviceName.size(), deviceName)) == 0) {
57+
if ((pos = allowList.find("{{", pos)) == std::string::npos) {
58+
throw std::runtime_error("Malformed device allowlist");
59+
}
60+
size_t start = pos+2;
61+
if ((pos = allowList.find("}},", pos)) == std::string::npos) {
62+
throw std::runtime_error("Malformed device allowlist");
63+
}
64+
decDescs.back().devName = allowList.substr(start, pos-start);
65+
pos = pos+3;
66+
if ((allowList.compare(pos, driverVersion.size(), driverVersion)) == 0) {
67+
if ((pos = allowList.find("{{", pos)) == std::string::npos) {
68+
throw std::runtime_error("Malformed device allowlist");
69+
}
70+
start = pos+2;
71+
if ((pos = allowList.find("}}", pos)) == std::string::npos) {
72+
throw std::runtime_error("Malformed device allowlist");
73+
}
74+
decDescs.back().devDriverVer = allowList.substr(start, pos-start);
75+
pos = pos+3;
76+
} else {
77+
throw std::runtime_error("Malformed device allowlist");
78+
}
79+
}
80+
else if ((allowList.compare(pos, platformName.size(), platformName)) == 0) {
81+
if ((pos = allowList.find("{{", pos)) == std::string::npos) {
82+
throw std::runtime_error("Malformed platform allowlist");
83+
}
84+
size_t start = pos+2;
85+
if ((pos = allowList.find("}},", pos)) == std::string::npos) {
86+
throw std::runtime_error("Malformed platform allowlist");
87+
}
88+
decDescs.back().platName = allowList.substr(start, pos-start);
89+
pos = pos+3;
90+
if ((allowList.compare(pos, platformVersion.size(), platformVersion)) == 0) {
91+
if ((pos = allowList.find("{{", pos)) == std::string::npos) {
92+
throw std::runtime_error("Malformed platform allowlist");
93+
}
94+
start = pos+2;
95+
if ((pos = allowList.find("}}", pos)) == std::string::npos) {
96+
throw std::runtime_error("Malformed platform allowlist");
97+
}
98+
decDescs.back().platVer = allowList.substr(start, pos-start);
99+
pos = pos+3;
100+
} else {
101+
throw std::runtime_error("Malformed platform allowlist");
102+
}
103+
}
104+
else if (allowList.find('|', pos) != std::string::npos) {
105+
pos = allowList.find('|')+1;
106+
while (allowList[pos] == ' ') {
107+
pos++;
108+
}
109+
}
110+
else {
111+
throw std::runtime_error("Malformed platform allowlist");
112+
}
113+
} // while (pos <= allowList.size())
114+
return decDescs;
115+
}
116+
117+
118+
int main() {
119+
bool passed = false;
120+
121+
// Find the GPU devices on this system
122+
if (getenv("WRITE_DEVICE_INFO")) {
123+
std::ofstream fs;
124+
fs.open("select_device_config.txt");
125+
if (fs.is_open()) {
126+
for (const auto &plt : platform::get_platforms()) {
127+
if (!plt.has(aspect::host)){
128+
for (const auto &dev : plt.get_devices()) {
129+
if (dev.has(aspect::gpu)) {
130+
std::string name = dev.get_info<info::device::name>();
131+
std::string ver = dev.get_info<info::device::driver_version>();
132+
fs << "DeviceName:{{" << name
133+
<< "}},DriverVersion:{{" << ver << "}}" << std::endl;
134+
passed=true;
135+
break;
136+
}
137+
}
138+
}
139+
}
140+
fs.close();
141+
}
142+
}
143+
else if (getenv("READ_DEVICE_INFO")) {
144+
std::ifstream fs;
145+
fs.open("select_device_config.txt");
146+
if (fs.is_open()) {
147+
std::string allowlist;
148+
std::getline(fs, allowlist);
149+
if (! allowlist.empty()) {
150+
setenv("SYCL_DEVICE_ALLOWLIST", allowlist.c_str(), 0);
151+
std::vector<DevDescT> components(getAllowListDesc(allowlist));
152+
153+
cl::sycl::queue deviceQueue(gpu_selector{});
154+
device dev = deviceQueue.get_device();
155+
for (const DevDescT &desc : components) {
156+
if ((dev.get_info<info::device::name>() == desc.devName) &&
157+
(dev.get_info<info::device::driver_version>() ==
158+
desc.devDriverVer)) {
159+
passed = true;
160+
}
161+
std::cout << "SYCL_DEVICE_ALLOWLIST=" << allowlist << std::endl;
162+
std::cout << "Device: " << dev.get_info<info::device::name>()
163+
<< std::endl;
164+
std::cout << "DriverVersion: "
165+
<< dev.get_info<info::device::driver_version>()
166+
<< std::endl;
167+
}
168+
}
169+
fs.close();
170+
}
171+
}
172+
// Find the platforms on this system.
173+
if (getenv("WRITE_PLATFORM_INFO")) {
174+
std::ofstream fs;
175+
fs.open("select_device_config.txt");
176+
if (fs.is_open()) {
177+
for (const auto &plt : platform::get_platforms()) {
178+
if (plt.has(aspect::gpu)){
179+
std::string pname = plt.get_info<info::platform::name>();
180+
std::string pver = plt.get_info<info::platform::version>();
181+
fs << "PlatformName:{{" << pname
182+
<< "}},PlatformVersion:{{" << pver << "}}" << std::endl;
183+
passed=true;
184+
break;
185+
}
186+
}
187+
}
188+
fs.close();
189+
}
190+
else if (getenv("READ_PLATFORM_INFO")) {
191+
std::ifstream fs;
192+
fs.open("select_device_config.txt", std::fstream::in);
193+
if (fs.is_open()) {
194+
std::string allowlist;
195+
std::getline(fs, allowlist);
196+
if (! allowlist.empty()) {
197+
setenv("SYCL_DEVICE_ALLOWLIST", allowlist.c_str(), 0);
198+
std::vector<DevDescT> components(getAllowListDesc(allowlist));
199+
200+
for (const auto &plt : platform::get_platforms()) {
201+
if (!plt.has(aspect::host)){
202+
for (const DevDescT &desc : components) {
203+
if ((plt.get_info<info::platform::name>() == desc.platName) &&
204+
(plt.get_info<info::platform::version>() ==
205+
desc.platVer)) {
206+
passed = true;
207+
}
208+
std::cout << "SYCL_DEVICE_ALLOWLIST=" << allowlist << std::endl;
209+
std::cout << "Platform: " << plt.get_info<info::platform::name>()
210+
<< std::endl;
211+
std::cout << "Platform Version: "
212+
<< plt.get_info<info::platform::version>()
213+
<< std::endl;
214+
}
215+
}
216+
}
217+
}
218+
fs.close();
219+
}
220+
}
221+
222+
if (passed) {
223+
std::cout << "Passed." << std::endl;
224+
return 0;
225+
} else {
226+
std:: cout << "Failed." << std::endl;
227+
return 1;
228+
}
229+
}

0 commit comments

Comments
 (0)