@@ -58,6 +58,12 @@ struct DevDescT {
58
58
59
59
const char *devDriverVer = nullptr ;
60
60
int devDriverVerSize = 0 ;
61
+
62
+ const char *platformName = nullptr ;
63
+ int platformNameSize = 0 ;
64
+
65
+ const char *platformVer = nullptr ;
66
+ int platformVerSize = 0 ;
61
67
};
62
68
63
69
static std::vector<DevDescT> getWhiteListDesc () {
@@ -68,6 +74,8 @@ static std::vector<DevDescT> getWhiteListDesc() {
68
74
std::vector<DevDescT> decDescs;
69
75
const char devNameStr[] = " DeviceName" ;
70
76
const char driverVerStr[] = " DriverVersion" ;
77
+ const char platformNameStr[] = " PlatformName" ;
78
+ const char platformVerStr[] = " PlatformVersion" ;
71
79
decDescs.emplace_back ();
72
80
while (' \0 ' != *str) {
73
81
const char **valuePtr = nullptr ;
@@ -78,6 +86,15 @@ static std::vector<DevDescT> getWhiteListDesc() {
78
86
valuePtr = &decDescs.back ().devName ;
79
87
size = &decDescs.back ().devNameSize ;
80
88
str += sizeof (devNameStr) - 1 ;
89
+ } else if (0 ==
90
+ strncmp (platformNameStr, str, sizeof (platformNameStr) - 1 )) {
91
+ valuePtr = &decDescs.back ().platformName ;
92
+ size = &decDescs.back ().platformNameSize ;
93
+ str += sizeof (platformNameStr) - 1 ;
94
+ } else if (0 == strncmp (platformVerStr, str, sizeof (platformVerStr) - 1 )) {
95
+ valuePtr = &decDescs.back ().platformVer ;
96
+ size = &decDescs.back ().platformVerSize ;
97
+ str += sizeof (platformVerStr) - 1 ;
81
98
} else if (0 == strncmp (driverVerStr, str, sizeof (driverVerStr) - 1 )) {
82
99
valuePtr = &decDescs.back ().devDriverVer ;
83
100
size = &decDescs.back ().devDriverVerSize ;
@@ -125,23 +142,43 @@ static std::vector<DevDescT> getWhiteListDesc() {
125
142
return decDescs;
126
143
}
127
144
128
- static void filterWhiteList (vector_class<RT::PiDevice> &pi_devices) {
145
+ static void filterWhiteList (vector_class<RT::PiDevice> &pi_devices,
146
+ RT::PiPlatform pi_platform) {
129
147
const std::vector<DevDescT> whiteList (getWhiteListDesc ());
130
148
if (whiteList.empty ())
131
149
return ;
132
150
151
+ const string_class platformName =
152
+ sycl::detail::get_platform_info<string_class, info::platform::name>::get (
153
+ pi_platform);
154
+
155
+ const string_class platformVer = sycl::detail::get_platform_info<
156
+ string_class, info::platform::version>::get (pi_platform);
157
+
133
158
int insertIDx = 0 ;
134
159
for (RT::PiDevice dev : pi_devices) {
135
160
const string_class devName =
136
- sycl::detail::get_device_info<string_class, info::device::name>::get (dev);
161
+ sycl::detail::get_device_info<string_class, info::device::name>::get (
162
+ dev);
137
163
138
164
const string_class devDriverVer =
139
165
sycl::detail::get_device_info<string_class,
140
166
info::device::driver_version>::get (dev);
141
167
142
168
for (const DevDescT &desc : whiteList) {
143
- // At least device name is required field to consider the filter so far
144
- if (nullptr == desc.devName ||
169
+ if (nullptr != desc.platformName &&
170
+ !std::regex_match (platformName,
171
+ std::regex (std::string (desc.platformName ,
172
+ desc.platformNameSize ))))
173
+ continue ;
174
+
175
+ if (nullptr != desc.platformVer &&
176
+ !std::regex_match (
177
+ platformVer,
178
+ std::regex (std::string (desc.platformVer , desc.platformVerSize ))))
179
+ continue ;
180
+
181
+ if (nullptr != desc.devName &&
145
182
!std::regex_match (
146
183
devName, std::regex (std::string (desc.devName , desc.devNameSize ))))
147
184
continue ;
@@ -179,7 +216,7 @@ platform_impl_pi::get_devices(info::device_type deviceType) const {
179
216
180
217
// Filter out devices that are not present in the white list
181
218
if (SYCLConfig<SYCL_DEVICE_WHITE_LIST>::get ())
182
- filterWhiteList (pi_devices);
219
+ filterWhiteList (pi_devices, m_platform );
183
220
184
221
std::for_each (pi_devices.begin (), pi_devices.end (),
185
222
[&res](const RT::PiDevice &a_pi_device) {
0 commit comments