Skip to content

Commit 7aa5be0

Browse files
bso-intelbaderromanovvlad
authored
[SYCL] Make device ids unique per backend (#4247)
* [SYCL] Make device ids unique per backend We decided to make device id numbers unique per backend. Also, by adding the device_type into each device prefix listing in sycl-ls, the user can easily set SYCL_DEVICE_FILTER correctly. Future work: refactor devices and platforms cache to optimize the device retrieval. Signed-off-by: Byoungro So <[email protected]> Co-authored-by: Alexey Bader <[email protected]> Co-authored-by: Romanov Vlad <[email protected]>
1 parent 987427a commit 7aa5be0

File tree

7 files changed

+200
-101
lines changed

7 files changed

+200
-101
lines changed

sycl/doc/EnvironmentVariables.md

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -57,25 +57,25 @@ subject to change. Do not rely on these variables in production code.
5757

5858
This environment variable limits the SYCL RT to use only a subset of the system's devices. Setting this environment variable affects all of the device query functions (`platform::get_devices()` and `platform::get_platforms()`) and all of the device selectors.
5959

60-
The value of this environment variable is a comma separated list of filters, where each filter is a triple of the form "`backend:device_type:device_num`" (without the quotes). Each element of the triple is optional, but each filter must have at least one value. Possible values of "backend" are:
61-
- host
60+
The value of this environment variable is a comma separated list of filters, where each filter is a triple of the form "`backend`:`device_type`:`device_num`" (without the quotes). Each element of the triple is optional, but each filter must have at least one value. Possible values of `backend` are:
61+
- `host`
6262
- `level_zero`
63-
- opencl
64-
- cuda
65-
- \*
63+
- `opencl`
64+
- `cuda`
65+
- `*`
6666

67-
Possible values of "`device_type`" are:
68-
- host
69-
- cpu
70-
- gpu
71-
- acc
72-
- \*
67+
Possible values of `device_type` are:
68+
- `host`
69+
- `cpu`
70+
- `gpu`
71+
- `acc`
72+
- `*`
7373

74-
`Device_num` is an integer that indexes the enumeration of devices from the sycl-ls utility tool, where the first device in that enumeration has index zero in each backend. For example, `SYCL_DEVICE_FILTER`=2 will return all devices with index '2' from all different backends. If multiple devices satisfy this device number (e.g., GPU and CPU devices can be assigned device number '2'), then default_selector will choose the device with the highest heuristic point.
74+
`device_num` is an integer that indexes the enumeration of devices from the sycl-ls utility tool, where the first device in that enumeration has index zero in each backend. For example, `SYCL_DEVICE_FILTER=2` will return all devices with index '2' from all different backends. If multiple devices satisfy this device number (e.g., GPU and CPU devices can be assigned device number '2'), then default_selector will choose the device with the highest heuristic point. When `SYCL_DEVICE_ALLOWLIST` is set, it is applied before enumerating devices and affects `device_num` values.
7575

76-
Assuming a filter has all three elements of the triple, it selects only those devices that come from the given backend, have the specified device type, AND have the given device index. If more than one filter is specified, the RT is restricted to the union of devices selected by all filters. The RT does not include the "host" backend and the host device automatically unless one of the filters explicitly specifies the "host" device type. Therefore, `SYCL_DEVICE_FILTER`=host should be set to enforce SYCL to use the host device only.
76+
Assuming a filter has all three elements of the triple, it selects only those devices that come from the given backend, have the specified device type, AND have the given device index. If more than one filter is specified, the RT is restricted to the union of devices selected by all filters. The RT does not include the `host` backend and the `host` device automatically unless one of the filters explicitly specifies the `host` device type. Therefore, `SYCL_DEVICE_FILTER=host` should be set to enforce SYCL to use the `host` device only.
7777

78-
Note that all device selectors will throw an exception if the filtered list of devices does not include a device that satisfies the selector. For instance, `SYCL_DEVICE_FILTER`=cpu,level_zero will cause host_selector() to throw an exception. `SYCL_DEVICE_FILTER` also limits loading only specified plugins into the SYCL RT. In particular, `SYCL_DEVICE_FILTER`=level_zero will cause the cpu_selector to throw an exception since SYCL RT will only load the level_zero backend which does not support any CPU devices at this time. When multiple devices satisfy the filter (e..g, `SYCL_DEVICE_FILTER`=gpu), only one of them will be selected.
78+
Note that all device selectors will throw an exception if the filtered list of devices does not include a device that satisfies the selector. For instance, `SYCL_DEVICE_FILTER=cpu,level_zero` will cause `host_selector()` to throw an exception. `SYCL_DEVICE_FILTER` also limits loading only specified plugins into the SYCL RT. In particular, `SYCL_DEVICE_FILTER=level_zero` will cause the `cpu_selector` to throw an exception since SYCL RT will only load the `level_zero` backend which does not support any CPU devices at this time. When multiple devices satisfy the filter (e..g, `SYCL_DEVICE_FILTER=gpu`), only one of them will be selected.
7979

8080
### `SYCL_PRINT_EXECUTION_GRAPH` Options
8181

sycl/include/CL/sycl/detail/pi.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ template <class To, class From> To cast(From value);
154154
extern std::shared_ptr<plugin> GlobalPlugin;
155155

156156
// Performs PI one-time initialization.
157-
const std::vector<plugin> &initialize();
157+
std::vector<plugin> &initialize();
158158

159159
// Get the plugin serving given backend.
160160
template <backend BE> __SYCL_EXPORT const plugin &getPlugin();

sycl/source/detail/device_filter.cpp

Lines changed: 38 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,65 +12,80 @@
1212
#include <detail/device_impl.hpp>
1313

1414
#include <cstring>
15+
#include <string_view>
1516

1617
__SYCL_INLINE_NAMESPACE(cl) {
1718
namespace sycl {
1819
namespace detail {
1920

21+
std::vector<std::string_view> tokenize(const std::string &Filter,
22+
const std::string &Delim) {
23+
std::vector<std::string_view> Tokens;
24+
size_t Pos = 0;
25+
size_t LastPos = 0;
26+
27+
while ((Pos = Filter.find(Delim, LastPos)) != std::string::npos) {
28+
std::string_view Tok(Filter.data() + LastPos, (Pos - LastPos));
29+
30+
if (!Tok.empty()) {
31+
Tokens.push_back(Tok);
32+
}
33+
// move the search starting index
34+
LastPos = Pos + 1;
35+
}
36+
37+
// Add remainder if any
38+
if (LastPos < Filter.size()) {
39+
std::string_view Tok(Filter.data() + LastPos, Filter.size() - LastPos);
40+
Tokens.push_back(Tok);
41+
}
42+
return Tokens;
43+
}
44+
2045
device_filter::device_filter(const std::string &FilterString) {
21-
size_t Cursor = 0;
22-
size_t ColonPos = 0;
23-
auto findElement = [&](auto Element) {
24-
size_t Found = FilterString.find(Element.first, Cursor);
25-
if (Found == std::string::npos)
26-
return false;
27-
Cursor = Found;
28-
return true;
46+
std::vector<std::string_view> Tokens = tokenize(FilterString, ":");
47+
size_t TripleValueID = 0;
48+
49+
auto FindElement = [&](auto Element) {
50+
return std::string::npos != Tokens[TripleValueID].find(Element.first);
2951
};
3052

3153
// Handle the optional 1st field of the filter, backend
3254
// Check if the first entry matches with a known backend type
3355
auto It = std::find_if(std::begin(getSyclBeMap()), std::end(getSyclBeMap()),
34-
findElement);
56+
FindElement);
3557
// If no match is found, set the backend type backend::all
3658
// which actually means 'any backend' will be a match.
3759
if (It == getSyclBeMap().end())
3860
Backend = backend::all;
3961
else {
4062
Backend = It->second;
41-
ColonPos = FilterString.find(":", Cursor);
42-
if (ColonPos != std::string::npos)
43-
Cursor = ColonPos + 1;
44-
else
45-
Cursor = Cursor + It->first.size();
63+
TripleValueID++;
4664
}
65+
4766
// Handle the optional 2nd field of the filter - device type.
4867
// Check if the 2nd entry matches with any known device type.
49-
if (Cursor >= FilterString.size()) {
68+
if (TripleValueID >= Tokens.size()) {
5069
DeviceType = info::device_type::all;
5170
} else {
5271
auto Iter = std::find_if(std::begin(getSyclDeviceTypeMap()),
53-
std::end(getSyclDeviceTypeMap()), findElement);
72+
std::end(getSyclDeviceTypeMap()), FindElement);
5473
// If no match is found, set device_type 'all',
5574
// which actually means 'any device_type' will be a match.
5675
if (Iter == getSyclDeviceTypeMap().end())
5776
DeviceType = info::device_type::all;
5877
else {
5978
DeviceType = Iter->second;
60-
ColonPos = FilterString.find(":", Cursor);
61-
if (ColonPos != std::string::npos)
62-
Cursor = ColonPos + 1;
63-
else
64-
Cursor = Cursor + Iter->first.size();
79+
TripleValueID++;
6580
}
6681
}
6782

6883
// Handle the optional 3rd field of the filter, device number
6984
// Try to convert the remaining string to an integer.
7085
// If succeessful, the converted integer is the desired device num.
71-
if (Cursor < FilterString.size()) {
86+
if (TripleValueID < Tokens.size()) {
7287
try {
73-
DeviceNum = stoi(FilterString.substr(Cursor));
88+
DeviceNum = std::stoi(Tokens[TripleValueID].data());
7489
HasDeviceNum = true;
7590
} catch (...) {
7691
std::string Message =

sycl/source/detail/pi.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ getPluginOpaqueData<cl::sycl::backend::esimd_cpu>(void *);
7878

7979
namespace pi {
8080

81-
static void initializePlugins(std::vector<plugin> *Plugins);
81+
static void initializePlugins(std::vector<plugin> &Plugins);
8282

8383
bool XPTIInitDone = false;
8484

@@ -369,17 +369,17 @@ bool trace(TraceLevel Level) {
369369
}
370370

371371
// Initializes all available Plugins.
372-
const std::vector<plugin> &initialize() {
372+
std::vector<plugin> &initialize() {
373373
static std::once_flag PluginsInitDone;
374-
375-
std::call_once(PluginsInitDone, []() {
376-
initializePlugins(&GlobalHandler::instance().getPlugins());
374+
// std::call_once is blocking all other threads if a thread is already
375+
// creating a vector of plugins. So, no additional lock is needed.
376+
std::call_once(PluginsInitDone, [&]() {
377+
initializePlugins(GlobalHandler::instance().getPlugins());
377378
});
378-
379379
return GlobalHandler::instance().getPlugins();
380380
}
381381

382-
static void initializePlugins(std::vector<plugin> *Plugins) {
382+
static void initializePlugins(std::vector<plugin> &Plugins) {
383383
std::vector<std::pair<std::string, backend>> PluginNames = findPlugins();
384384

385385
if (PluginNames.empty() && trace(PI_TRACE_ALL))
@@ -438,7 +438,7 @@ static void initializePlugins(std::vector<plugin> *Plugins) {
438438
GlobalPlugin = std::make_shared<plugin>(PluginInformation,
439439
backend::level_zero, Library);
440440
}
441-
Plugins->emplace_back(
441+
Plugins.emplace_back(
442442
plugin(PluginInformation, PluginNames[I].second, Library));
443443
if (trace(TraceLevel::PI_TRACE_BASIC))
444444
std::cerr << "SYCL_PI_TRACE[basic]: "

sycl/source/detail/platform_impl.cpp

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -95,27 +95,30 @@ static bool IsBannedPlatform(platform Platform) {
9595

9696
std::vector<platform> platform_impl::get_platforms() {
9797
std::vector<platform> Platforms;
98-
const std::vector<plugin> &Plugins = RT::initialize();
99-
98+
std::vector<plugin> &Plugins = RT::initialize();
10099
info::device_type ForcedType = detail::get_forced_type();
101-
for (unsigned int i = 0; i < Plugins.size(); i++) {
102-
100+
for (plugin &Plugin : Plugins) {
103101
pi_uint32 NumPlatforms = 0;
104102
// Move to the next plugin if the plugin fails to initialize.
105103
// This way platforms from other plugins get a chance to be discovered.
106-
if (Plugins[i].call_nocheck<PiApiKind::piPlatformsGet>(
104+
if (Plugin.call_nocheck<PiApiKind::piPlatformsGet>(
107105
0, nullptr, &NumPlatforms) != PI_SUCCESS)
108106
continue;
109107

110108
if (NumPlatforms) {
111109
std::vector<RT::PiPlatform> PiPlatforms(NumPlatforms);
112-
if (Plugins[i].call_nocheck<PiApiKind::piPlatformsGet>(
110+
if (Plugin.call_nocheck<PiApiKind::piPlatformsGet>(
113111
NumPlatforms, PiPlatforms.data(), nullptr) != PI_SUCCESS)
114112
return Platforms;
115113

116114
for (const auto &PiPlatform : PiPlatforms) {
117115
platform Platform = detail::createSyclObjFromImpl<platform>(
118-
getOrMakePlatformImpl(PiPlatform, Plugins[i]));
116+
getOrMakePlatformImpl(PiPlatform, Plugin));
117+
{
118+
std::lock_guard<std::mutex> Guard(*Plugin.getPluginMutex());
119+
// insert PiPlatform into the Plugin
120+
Plugin.getPlatformId(PiPlatform);
121+
}
119122
// Skip platforms which do not contain requested device types
120123
if (!Platform.get_devices(ForcedType).empty() &&
121124
!IsBannedPlatform(Platform))
@@ -141,14 +144,26 @@ std::vector<platform> platform_impl::get_platforms() {
141144
// This function matches devices in the order of backend, device_type, and
142145
// device_num.
143146
static void filterDeviceFilter(std::vector<RT::PiDevice> &PiDevices,
144-
const plugin &Plugin) {
147+
RT::PiPlatform Platform) {
145148
device_filter_list *FilterList = SYCLConfig<SYCL_DEVICE_FILTER>::get();
146149
if (!FilterList)
147150
return;
148151

152+
std::vector<plugin> &Plugins = RT::initialize();
153+
auto It =
154+
std::find_if(Plugins.begin(), Plugins.end(), [Platform](plugin &Plugin) {
155+
return Plugin.containsPiPlatform(Platform);
156+
});
157+
if (It == Plugins.end())
158+
return;
159+
160+
plugin &Plugin = *It;
149161
backend Backend = Plugin.getBackend();
150162
int InsertIDx = 0;
151-
int DeviceNum = 0;
163+
// DeviceIds should be given consecutive numbers across platforms in the same
164+
// backend
165+
std::lock_guard<std::mutex> Guard(*Plugin.getPluginMutex());
166+
int DeviceNum = Plugin.getStartingDeviceId(Platform);
152167
for (RT::PiDevice Device : PiDevices) {
153168
RT::PiDeviceType PiDevType;
154169
Plugin.call<PiApiKind::piDeviceGetInfo>(Device, PI_DEVICE_INFO_TYPE,
@@ -181,6 +196,10 @@ static void filterDeviceFilter(std::vector<RT::PiDevice> &PiDevices,
181196
DeviceNum++;
182197
}
183198
PiDevices.resize(InsertIDx);
199+
// remember the last backend that has gone through this filter function
200+
// to assign a unique device id number across platforms that belong to
201+
// the same backend. For example, opencl:cpu:0, opencl:acc:1, opencl:gpu:2
202+
Plugin.setLastDeviceId(Platform, DeviceNum);
184203
}
185204

186205
std::shared_ptr<device_impl> platform_impl::getOrMakeDeviceImpl(
@@ -237,12 +256,12 @@ platform_impl::get_devices(info::device_type DeviceType) const {
237256

238257
// Filter out devices that are not present in the SYCL_DEVICE_ALLOWLIST
239258
if (SYCLConfig<SYCL_DEVICE_ALLOWLIST>::get())
240-
applyAllowList(PiDevices, MPlatform, this->getPlugin());
259+
applyAllowList(PiDevices, MPlatform, Plugin);
241260

242261
// Filter out devices that are not compatible with SYCL_DEVICE_FILTER
243-
filterDeviceFilter(PiDevices, Plugin);
262+
filterDeviceFilter(PiDevices, MPlatform);
244263

245-
PlatformImplPtr PlatformImpl = getOrMakePlatformImpl(MPlatform, *MPlugin);
264+
PlatformImplPtr PlatformImpl = getOrMakePlatformImpl(MPlatform, Plugin);
246265
std::transform(
247266
PiDevices.begin(), PiDevices.end(), std::back_inserter(Res),
248267
[PlatformImpl](const RT::PiDevice &PiDevice) -> device {

sycl/source/detail/plugin.hpp

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,10 @@ auto packCallArguments(ArgsT &&... Args) {
8989
class plugin {
9090
public:
9191
plugin() = delete;
92-
9392
plugin(RT::PiPlugin Plugin, backend UseBackend, void *LibraryHandle)
9493
: MPlugin(Plugin), MBackend(UseBackend), MLibraryHandle(LibraryHandle),
95-
TracingMutex(std::make_shared<std::mutex>()) {}
94+
TracingMutex(std::make_shared<std::mutex>()),
95+
MPluginMutex(std::make_shared<std::mutex>()) {}
9696

9797
plugin &operator=(const plugin &) = default;
9898
plugin(const plugin &) = default;
@@ -184,11 +184,59 @@ class plugin {
184184
void *getLibraryHandle() { return MLibraryHandle; }
185185
int unload() { return RT::unloadPlugin(MLibraryHandle); }
186186

187+
// return the index of PiPlatforms.
188+
// If not found, add it and return its index.
189+
// The function is expected to be called in a thread safe manner.
190+
int getPlatformId(RT::PiPlatform Platform) {
191+
auto It = std::find(PiPlatforms.begin(), PiPlatforms.end(), Platform);
192+
if (It != PiPlatforms.end())
193+
return It - PiPlatforms.begin();
194+
195+
PiPlatforms.push_back(Platform);
196+
LastDeviceIds.push_back(0);
197+
return PiPlatforms.size() - 1;
198+
}
199+
200+
// Device ids are consecutive across platforms within a plugin.
201+
// We need to return the same starting index for the given platform.
202+
// So, instead of returing the last device id of the given platform,
203+
// return the last device id of the predecessor platform.
204+
// The function is expected to be called in a thread safe manner.
205+
int getStartingDeviceId(RT::PiPlatform Platform) {
206+
int PlatformId = getPlatformId(Platform);
207+
if (PlatformId == 0)
208+
return 0;
209+
return LastDeviceIds[PlatformId - 1];
210+
}
211+
212+
// set the id of the last device for the given platform
213+
// The function is expected to be called in a thread safe manner.
214+
void setLastDeviceId(RT::PiPlatform Platform, int Id) {
215+
int PlatformId = getPlatformId(Platform);
216+
LastDeviceIds[PlatformId] = Id;
217+
}
218+
219+
bool containsPiPlatform(RT::PiPlatform Platform) {
220+
auto It = std::find(PiPlatforms.begin(), PiPlatforms.end(), Platform);
221+
return It != PiPlatforms.end();
222+
}
223+
224+
std::shared_ptr<std::mutex> getPluginMutex() { return MPluginMutex; }
225+
187226
private:
188227
RT::PiPlugin MPlugin;
189228
backend MBackend;
190229
void *MLibraryHandle; // the handle returned from dlopen
191230
std::shared_ptr<std::mutex> TracingMutex;
231+
// Mutex to guard PiPlatforms and LastDeviceIds.
232+
// Note that this is a temporary solution until we implement the global
233+
// Device/Platform cache later.
234+
std::shared_ptr<std::mutex> MPluginMutex;
235+
// vector of PiPlatforms that belong to this plugin
236+
std::vector<RT::PiPlatform> PiPlatforms;
237+
// represents the unique ids of the last device of each platform
238+
// index of this vector corresponds to the index in PiPlatforms vector.
239+
std::vector<int> LastDeviceIds;
192240
}; // class plugin
193241
} // namespace detail
194242
} // namespace sycl

0 commit comments

Comments
 (0)