Skip to content

Commit 38e588d

Browse files
authored
[SYCL] Select device image based on compile_target device image property (#14909)
We allow multiple so-called "special" targets to be passed to `-fsycl-targets`, but without extra information the SYCL runtime would not be able to select the right AOT-compiled device image. #14757 introduced a device image property that specifies the exact target of a device image, and this patch makes the runtime honor that property when selecting a device image.
1 parent 2535062 commit 38e588d

File tree

4 files changed

+379
-14
lines changed

4 files changed

+379
-14
lines changed

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 65 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1254,6 +1254,26 @@ void CheckJITCompilationForImage(const RTDeviceBinaryImage *const &Image,
12541254
}
12551255
}
12561256

1257+
// Returns the name of Device's architecture as a C string.
//
// The architecture is queried through the
// sycl::ext::oneapi::experimental::info::device::architecture descriptor.
// The switch below is generated from architectures.def: every
// __SYCL_ARCHITECTURE(ARCH, VAL) entry expands to a `case` that returns the
// stringified enumerator name (#ARCH), while __SYCL_ARCHITECTURE_ALIAS
// entries expand to nothing and therefore add no cases.
// If no generated case matches, the literal "unknown" is returned.
const char *getArchName(const device &Device) {
1258+
namespace syclex = sycl::ext::oneapi::experimental;
1259+
auto Arch = Device.get_info<syclex::info::device::architecture>();
1260+
switch (Arch) {
1261+
#define __SYCL_ARCHITECTURE(ARCH, VAL) \
1262+
case syclex::architecture::ARCH: \
1263+
return #ARCH;
1264+
#define __SYCL_ARCHITECTURE_ALIAS(ARCH, VAL)
1265+
#include <sycl/ext/oneapi/experimental/architectures.def>
1266+
#undef __SYCL_ARCHITECTURE
1267+
#undef __SYCL_ARCHITECTURE_ALIAS
1268+
}
1269+
return "unknown";
1270+
}
1271+
1272+
// Returns Img's raw data as a sycl_device_binary handle.
//
// Takes the address of the object returned by Img->getRawData() and casts
// away its constness. Img is dereferenced unconditionally, so it must not be
// null.
// NOTE(review): the outer reinterpret_cast targets the same type the inner
// const_cast already produces, so it looks redundant; it mirrors the
// expression previously inlined in getBinImageFromMultiMap.
sycl_device_binary getRawImg(RTDeviceBinaryImage *Img) {
1273+
return reinterpret_cast<sycl_device_binary>(
1274+
const_cast<sycl_device_binary>(&Img->getRawData()));
1275+
}
1276+
12571277
template <typename StorageKey>
12581278
RTDeviceBinaryImage *getBinImageFromMultiMap(
12591279
const std::unordered_multimap<StorageKey, RTDeviceBinaryImage *> &ImagesSet,
@@ -1262,16 +1282,51 @@ RTDeviceBinaryImage *getBinImageFromMultiMap(
12621282
if (ItBegin == ItEnd)
12631283
return nullptr;
12641284

1265-
std::vector<sycl_device_binary> RawImgs(std::distance(ItBegin, ItEnd));
1266-
auto It = ItBegin;
1267-
for (unsigned I = 0; It != ItEnd; ++It, ++I)
1268-
RawImgs[I] = reinterpret_cast<sycl_device_binary>(
1269-
const_cast<sycl_device_binary>(&It->second->getRawData()));
1285+
// Here, we aim to select all the device images from the
1286+
// [ItBegin, ItEnd) range that are AOT compiled for Device
1287+
// (checked using info::device::architecture) or JIT compiled.
1288+
// This selection will then be passed to urDeviceSelectBinary
1289+
// for final selection.
1290+
std::string_view ArchName = getArchName(Device);
1291+
std::vector<RTDeviceBinaryImage *> DeviceFilteredImgs;
1292+
DeviceFilteredImgs.reserve(std::distance(ItBegin, ItEnd));
1293+
for (auto It = ItBegin; It != ItEnd; ++It) {
1294+
auto PropRange = It->second->getDeviceRequirements();
1295+
auto PropIt =
1296+
std::find_if(PropRange.begin(), PropRange.end(), [&](const auto &Prop) {
1297+
return Prop->Name == std::string_view("compile_target");
1298+
});
1299+
auto AddImg = [&]() { DeviceFilteredImgs.push_back(It->second); };
12701300

1271-
std::vector<ur_device_binary_t> UrBinaries(RawImgs.size());
1272-
for (uint32_t BinaryCount = 0; BinaryCount < RawImgs.size(); BinaryCount++) {
1273-
UrBinaries[BinaryCount].pDeviceTargetSpec =
1274-
getUrDeviceTarget(RawImgs[BinaryCount]->DeviceTargetSpec);
1301+
// Device image has no compile_target property, so it is JIT compiled.
1302+
if (PropIt == PropRange.end()) {
1303+
AddImg();
1304+
continue;
1305+
}
1306+
1307+
// Device image has the compile_target property, so it is AOT compiled for
1308+
// some device, check if that architecture is Device's architecture.
1309+
auto CompileTargetByteArray = DeviceBinaryProperty(*PropIt).asByteArray();
1310+
CompileTargetByteArray.dropBytes(8);
1311+
std::string_view CompileTarget(
1312+
reinterpret_cast<const char *>(&CompileTargetByteArray[0]),
1313+
CompileTargetByteArray.size());
1314+
// Note: there are no explicit targets for CPUs, so on x86_64
1315+
// we use a spir64_x86_64 compile target image.
1316+
if ((ArchName == CompileTarget) ||
1317+
(ArchName == "x86_64" && CompileTarget == "spir64_x86_64")) {
1318+
AddImg();
1319+
}
1320+
}
1321+
1322+
if (DeviceFilteredImgs.empty())
1323+
return nullptr;
1324+
1325+
std::vector<ur_device_binary_t> UrBinaries(DeviceFilteredImgs.size());
1326+
for (uint32_t BinaryCount = 0; BinaryCount < DeviceFilteredImgs.size();
1327+
BinaryCount++) {
1328+
UrBinaries[BinaryCount].pDeviceTargetSpec = getUrDeviceTarget(
1329+
getRawImg(DeviceFilteredImgs[BinaryCount])->DeviceTargetSpec);
12751330
}
12761331

12771332
uint32_t ImgInd = 0;
@@ -1280,8 +1335,7 @@ RTDeviceBinaryImage *getBinImageFromMultiMap(
12801335
getSyclObjImpl(Context)->getPlugin()->call(
12811336
urDeviceSelectBinary, getSyclObjImpl(Device)->getHandleRef(),
12821337
UrBinaries.data(), UrBinaries.size(), &ImgInd);
1283-
std::advance(ItBegin, ImgInd);
1284-
return ItBegin->second;
1338+
return DeviceFilteredImgs[ImgInd];
12851339
}
12861340

12871341
RTDeviceBinaryImage &
@@ -1310,10 +1364,8 @@ ProgramManager::getDeviceImage(const std::string &KernelName,
13101364
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
13111365
if (auto KernelId = m_KernelName2KernelIDs.find(KernelName);
13121366
KernelId != m_KernelName2KernelIDs.end()) {
1313-
// Kernel ID presence guarantees that we have bin image in the storage.
13141367
Img = getBinImageFromMultiMap(m_KernelIDs2BinImage, KernelId->second,
13151368
Context, Device);
1316-
assert(Img && "No binary image found for kernel id");
13171369
} else {
13181370
Img = getBinImageFromMultiMap(m_ServiceKernels, KernelName, Context,
13191371
Device);

sycl/unittests/helpers/UrImage.hpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -541,7 +541,14 @@ inline UrImage
541541
generateDefaultImage(std::initializer_list<std::string> KernelNames) {
542542
UrPropertySet PropSet;
543543

544-
std::vector<unsigned char> Bin{0, 1, 2, 3, 4, 5}; // Random data
544+
std::string Combined;
545+
for (auto it = KernelNames.begin(); it != KernelNames.end(); ++it) {
546+
if (it != KernelNames.begin())
547+
Combined += ", ";
548+
Combined += *it;
549+
}
550+
std::vector<unsigned char> Bin(Combined.begin(), Combined.end());
551+
Bin.push_back(0);
545552

546553
UrArray<UrOffloadEntry> Entries = makeEmptyKernels(KernelNames);
547554

sycl/unittests/program_manager/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
set(CMAKE_CXX_EXTENSIONS OFF)
22
add_sycl_unittest(ProgramManagerTests OBJECT
3+
CompileTarget.cpp
34
BuildLog.cpp
45
DynamicLinking.cpp
56
itt_annotations.cpp

0 commit comments

Comments
 (0)