Skip to content

Commit ca0ec76

Browse files
[NFCI][SYCL] Refactor device selection in platform_impl.cpp (#12288)
Mostly "early continue" and use the same idioms for similar things.
1 parent ec7fb7c commit ca0ec76

File tree

1 file changed

+113
-154
lines changed

1 file changed

+113
-154
lines changed

sycl/source/detail/platform_impl.cpp

Lines changed: 113 additions & 154 deletions
Original file line numberDiff line numberDiff line change
@@ -266,59 +266,33 @@ std::vector<int> platform_impl::filterDeviceFilter(
266266
MPlugin->call<PiApiKind::piDeviceGetInfo>(
267267
Device, PI_DEVICE_INFO_TYPE, sizeof(sycl::detail::pi::PiDeviceType),
268268
&PiDevType, nullptr);
269-
// Assumption here is that there is 1-to-1 mapping between PiDevType and
270-
// Sycl device type for GPU, CPU, and ACC.
271-
info::device_type DeviceType = pi::cast<info::device_type>(PiDevType);
272269

273270
for (const FilterT &Filter : FilterList->get()) {
274271
backend FilterBackend = Filter.Backend.value_or(backend::all);
275-
// First, match the backend entry
276-
if (FilterBackend == Backend || FilterBackend == backend::all) {
277-
info::device_type FilterDevType =
278-
Filter.DeviceType.value_or(info::device_type::all);
279-
// Next, match the device_type entry
280-
if (FilterDevType == info::device_type::all) {
281-
// Last, match the device_num entry
282-
if (!Filter.DeviceNum || DeviceNum == Filter.DeviceNum.value()) {
283-
if constexpr (is_ods_target) { // dealing with ODS filters
284-
if (!Blacklist[DeviceNum]) { // ensure it is not blacklisted
285-
if (!Filter.IsNegativeTarget) { // is filter positive?
286-
PiDevices[InsertIDx++] = Device;
287-
original_indices.push_back(DeviceNum);
288-
} else {
289-
// Filter is negative and the device matches the filter so
290-
// blacklist the device.
291-
Blacklist[DeviceNum] = true;
292-
}
293-
}
294-
} else { // dealing with SYCL_DEVICE_FILTER
295-
PiDevices[InsertIDx++] = Device;
296-
original_indices.push_back(DeviceNum);
297-
}
298-
break;
299-
}
300-
301-
} else if (FilterDevType == DeviceType) {
302-
if (!Filter.DeviceNum || DeviceNum == Filter.DeviceNum.value()) {
303-
if constexpr (is_ods_target) {
304-
if (!Blacklist[DeviceNum]) {
305-
if (!Filter.IsNegativeTarget) {
306-
PiDevices[InsertIDx++] = Device;
307-
original_indices.push_back(DeviceNum);
308-
} else {
309-
// Filter is negative and the device matches the filter so
310-
// blacklist the device.
311-
Blacklist[DeviceNum] = true;
312-
}
313-
}
314-
} else {
315-
PiDevices[InsertIDx++] = Device;
316-
original_indices.push_back(DeviceNum);
317-
}
318-
break;
319-
}
272+
// First, match the backend entry.
273+
if (FilterBackend != Backend && FilterBackend != backend::all)
274+
continue;
275+
276+
// Match the device_num entry.
277+
if (Filter.DeviceNum && DeviceNum != Filter.DeviceNum.value())
278+
continue;
279+
280+
if constexpr (is_ods_target) {
281+
// Dealing with ONEAPI_DEVICE_SELECTOR - check for negative filters.
282+
if (Blacklist[DeviceNum]) // already blacklisted.
283+
break;
284+
285+
if (Filter.IsNegativeTarget) {
286+
// Filter is negative and the device matches the filter so
287+
// blacklist the device now.
288+
Blacklist[DeviceNum] = true;
289+
break;
320290
}
321291
}
292+
293+
PiDevices[InsertIDx++] = Device;
294+
original_indices.push_back(DeviceNum);
295+
break;
322296
}
323297
DeviceNum++;
324298
}
@@ -392,116 +366,101 @@ static std::vector<device> amendDeviceAndSubDevices(
392366
bool deviceAdded = false;
393367
for (ods_target target : OdsTargetList->get()) {
394368
backend TargetBackend = target.Backend.value_or(backend::all);
395-
if (PlatformBackend == TargetBackend || TargetBackend == backend::all) {
396-
bool deviceMatch = target.HasDeviceWildCard; // opencl:*
397-
if (target.DeviceType) { // opencl:gpu
398-
deviceMatch = ((target.DeviceType == info::device_type::all) ||
399-
(dev.get_info<info::device::device_type>() ==
400-
target.DeviceType));
401-
402-
} else if (target.DeviceNum) { // opencl:0
403-
deviceMatch = (target.DeviceNum.value() == original_indices[i]);
369+
if (PlatformBackend != TargetBackend && TargetBackend != backend::all)
370+
continue;
371+
372+
bool deviceMatch = target.HasDeviceWildCard; // opencl:*
373+
if (target.DeviceType) { // opencl:gpu
374+
deviceMatch =
375+
((target.DeviceType == info::device_type::all) ||
376+
(dev.get_info<info::device::device_type>() == target.DeviceType));
377+
378+
} else if (target.DeviceNum) { // opencl:0
379+
deviceMatch = (target.DeviceNum.value() == original_indices[i]);
380+
}
381+
382+
if (!deviceMatch)
383+
continue;
384+
385+
// Top level matches. Do we add it, or subdevices, or sub-sub-devices?
386+
bool wantSubDevice = target.SubDeviceNum || target.HasSubDeviceWildCard;
387+
bool supportsSubPartitioning =
388+
(supportsPartitionProperty(dev, partitionProperty) &&
389+
supportsAffinityDomain(dev, partitionProperty, affinityDomain));
390+
bool wantSubSubDevice =
391+
target.SubSubDeviceNum || target.HasSubSubDeviceWildCard;
392+
393+
if (!wantSubDevice) {
394+
// -- Add top level device only.
395+
if (!deviceAdded) {
396+
FinalResult.push_back(dev);
397+
deviceAdded = true;
398+
}
399+
continue;
400+
}
401+
402+
if (!supportsSubPartitioning) {
403+
if (target.DeviceNum ||
404+
(target.DeviceType &&
405+
(target.DeviceType.value() != info::device_type::all))) {
406+
// This device was specifically requested and yet is not
407+
// partitionable.
408+
std::cout << "device is not partitionable: " << target << std::endl;
404409
}
410+
continue;
411+
}
405412

406-
if (deviceMatch) {
407-
// Top level matches. Do we add it, or subdevices, or sub-sub-devices?
408-
bool wantSubDevice =
409-
target.SubDeviceNum || target.HasSubDeviceWildCard;
410-
bool supportsSubPartitioning =
411-
(supportsPartitionProperty(dev, partitionProperty) &&
412-
supportsAffinityDomain(dev, partitionProperty, affinityDomain));
413-
bool wantSubSubDevice =
414-
target.SubSubDeviceNum || target.HasSubSubDeviceWildCard;
415-
416-
// -- Add top level device.
417-
if (!wantSubDevice) {
418-
if (!deviceAdded) {
419-
FinalResult.push_back(dev);
420-
deviceAdded = true;
421-
}
422-
} else {
423-
if (!supportsSubPartitioning) {
424-
if (target.DeviceNum ||
425-
(target.DeviceType &&
426-
(target.DeviceType.value() != info::device_type::all))) {
427-
// This device was specifically requested and yet is not
428-
// partitionable.
429-
std::cout << "device is not partitionable: " << target
430-
<< std::endl;
431-
}
432-
continue;
433-
}
434-
// -- Add sub sub device.
435-
if (wantSubSubDevice) {
436-
437-
auto subDevicesToPartition =
438-
dev.create_sub_devices<partitionProperty>(affinityDomain);
439-
if (target.SubDeviceNum) {
440-
if (subDevicesToPartition.size() >
441-
target.SubDeviceNum.value()) {
442-
subDevicesToPartition[0] =
443-
subDevicesToPartition[target.SubDeviceNum.value()];
444-
subDevicesToPartition.resize(1);
445-
} else {
446-
std::cout << "subdevice index out of bounds: " << target
447-
<< std::endl;
448-
continue;
449-
}
450-
}
451-
for (device subDev : subDevicesToPartition) {
452-
bool supportsSubSubPartitioning =
453-
(supportsPartitionProperty(subDev, partitionProperty) &&
454-
supportsAffinityDomain(subDev, partitionProperty,
455-
affinityDomain));
456-
if (!supportsSubSubPartitioning) {
457-
if (target.SubDeviceNum) {
458-
// Parent subdevice was specifically requested, yet is not
459-
// partitionable.
460-
std::cout << "sub-device is not partitionable: " << target
461-
<< std::endl;
462-
}
463-
continue;
464-
}
465-
// Allright, lets get them sub-sub-devices.
466-
auto subSubDevices =
467-
subDev.create_sub_devices<partitionProperty>(
468-
affinityDomain);
469-
if (target.HasSubSubDeviceWildCard) {
470-
FinalResult.insert(FinalResult.end(), subSubDevices.begin(),
471-
subSubDevices.end());
472-
} else {
473-
if (subSubDevices.size() > target.SubSubDeviceNum.value()) {
474-
FinalResult.push_back(
475-
subSubDevices[target.SubSubDeviceNum.value()]);
476-
} else {
477-
std::cout
478-
<< "sub-sub-device index out of bounds: " << target
479-
<< std::endl;
480-
}
481-
}
482-
}
483-
} else if (wantSubDevice) {
484-
auto subDevices = dev.create_sub_devices<
485-
info::partition_property::partition_by_affinity_domain>(
486-
affinityDomain);
487-
if (target.HasSubDeviceWildCard) {
488-
FinalResult.insert(FinalResult.end(), subDevices.begin(),
489-
subDevices.end());
490-
} else {
491-
if (subDevices.size() > target.SubDeviceNum.value()) {
492-
FinalResult.push_back(
493-
subDevices[target.SubDeviceNum.value()]);
494-
} else {
495-
std::cout << "subdevice index out of bounds: " << target
496-
<< std::endl;
497-
}
498-
}
499-
}
413+
auto subDevices = dev.create_sub_devices<
414+
info::partition_property::partition_by_affinity_domain>(
415+
affinityDomain);
416+
if (target.SubDeviceNum) {
417+
if (subDevices.size() <= target.SubDeviceNum.value()) {
418+
std::cout << "subdevice index out of bounds: " << target << std::endl;
419+
continue;
420+
}
421+
subDevices[0] = subDevices[target.SubDeviceNum.value()];
422+
subDevices.resize(1);
423+
}
424+
425+
if (!wantSubSubDevice) {
426+
// -- Add sub device(s) only.
427+
FinalResult.insert(FinalResult.end(), subDevices.begin(),
428+
subDevices.end());
429+
continue;
430+
}
431+
432+
// -- Add sub sub device(s).
433+
for (device subDev : subDevices) {
434+
bool supportsSubSubPartitioning =
435+
(supportsPartitionProperty(subDev, partitionProperty) &&
436+
supportsAffinityDomain(subDev, partitionProperty, affinityDomain));
437+
if (!supportsSubSubPartitioning) {
438+
if (target.SubDeviceNum) {
439+
// Parent subdevice was specifically requested, yet is not
440+
// partitionable.
441+
std::cout << "sub-device is not partitionable: " << target
442+
<< std::endl;
500443
}
501-
} // /if deviceMatch
444+
continue;
445+
}
446+
447+
// Allright, lets get them sub-sub-devices.
448+
auto subSubDevices =
449+
subDev.create_sub_devices<partitionProperty>(affinityDomain);
450+
if (target.SubSubDeviceNum) {
451+
if (subSubDevices.size() <= target.SubSubDeviceNum.value()) {
452+
std::cout << "sub-sub-device index out of bounds: " << target
453+
<< std::endl;
454+
continue;
455+
}
456+
subSubDevices[0] = subSubDevices[target.SubSubDeviceNum.value()];
457+
subSubDevices.resize(1);
458+
}
459+
FinalResult.insert(FinalResult.end(), subSubDevices.begin(),
460+
subSubDevices.end());
502461
}
503-
} // /for
504-
} // /for
462+
}
463+
}
505464
return FinalResult;
506465
}
507466

0 commit comments

Comments
 (0)