@@ -1254,6 +1254,26 @@ void CheckJITCompilationForImage(const RTDeviceBinaryImage *const &Image,
1254
1254
}
1255
1255
}
1256
1256
1257
+ const char *getArchName (const device &Device) {
1258
+ namespace syclex = sycl::ext::oneapi::experimental;
1259
+ auto Arch = Device.get_info <syclex::info::device::architecture>();
1260
+ switch (Arch) {
1261
+ #define __SYCL_ARCHITECTURE (ARCH, VAL ) \
1262
+ case syclex::architecture::ARCH: \
1263
+ return #ARCH;
1264
+ #define __SYCL_ARCHITECTURE_ALIAS (ARCH, VAL )
1265
+ #include < sycl/ext/oneapi/experimental/architectures.def>
1266
+ #undef __SYCL_ARCHITECTURE
1267
+ #undef __SYCL_ARCHITECTURE_ALIAS
1268
+ }
1269
+ return " unknown" ;
1270
+ }
1271
+
1272
+ sycl_device_binary getRawImg (RTDeviceBinaryImage *Img) {
1273
+ return reinterpret_cast <sycl_device_binary>(
1274
+ const_cast <sycl_device_binary>(&Img->getRawData ()));
1275
+ }
1276
+
1257
1277
template <typename StorageKey>
1258
1278
RTDeviceBinaryImage *getBinImageFromMultiMap (
1259
1279
const std::unordered_multimap<StorageKey, RTDeviceBinaryImage *> &ImagesSet,
@@ -1262,16 +1282,51 @@ RTDeviceBinaryImage *getBinImageFromMultiMap(
1262
1282
if (ItBegin == ItEnd)
1263
1283
return nullptr ;
1264
1284
1265
- std::vector<sycl_device_binary> RawImgs (std::distance (ItBegin, ItEnd));
1266
- auto It = ItBegin;
1267
- for (unsigned I = 0 ; It != ItEnd; ++It, ++I)
1268
- RawImgs[I] = reinterpret_cast <sycl_device_binary>(
1269
- const_cast <sycl_device_binary>(&It->second ->getRawData ()));
1285
+ // Here, we aim to select all the device images from the
1286
+ // [ItBegin, ItEnd) range that are AOT compiled for Device
1287
+ // (checked using info::device::architecture) or JIT compiled.
1288
+ // This selection will then be passed to urDeviceSelectBinary
1289
+ // for final selection.
1290
+ std::string_view ArchName = getArchName (Device);
1291
+ std::vector<RTDeviceBinaryImage *> DeviceFilteredImgs;
1292
+ DeviceFilteredImgs.reserve (std::distance (ItBegin, ItEnd));
1293
+ for (auto It = ItBegin; It != ItEnd; ++It) {
1294
+ auto PropRange = It->second ->getDeviceRequirements ();
1295
+ auto PropIt =
1296
+ std::find_if (PropRange.begin (), PropRange.end (), [&](const auto &Prop) {
1297
+ return Prop->Name == std::string_view (" compile_target" );
1298
+ });
1299
+ auto AddImg = [&]() { DeviceFilteredImgs.push_back (It->second ); };
1270
1300
1271
- std::vector<ur_device_binary_t > UrBinaries (RawImgs.size ());
1272
- for (uint32_t BinaryCount = 0 ; BinaryCount < RawImgs.size (); BinaryCount++) {
1273
- UrBinaries[BinaryCount].pDeviceTargetSpec =
1274
- getUrDeviceTarget (RawImgs[BinaryCount]->DeviceTargetSpec );
1301
+ // Device image has no compile_target property, so it is JIT compiled.
1302
+ if (PropIt == PropRange.end ()) {
1303
+ AddImg ();
1304
+ continue ;
1305
+ }
1306
+
1307
+ // Device image has the compile_target property, so it is AOT compiled for
1308
+ // some device, check if that architecture is Device's architecture.
1309
+ auto CompileTargetByteArray = DeviceBinaryProperty (*PropIt).asByteArray ();
1310
+ CompileTargetByteArray.dropBytes (8 );
1311
+ std::string_view CompileTarget (
1312
+ reinterpret_cast <const char *>(&CompileTargetByteArray[0 ]),
1313
+ CompileTargetByteArray.size ());
1314
+ // Note: there are no explicit targets for CPUs, so on x86_64,
1315
+ // so we use a spir64_x86_64 compile target image.
1316
+ if ((ArchName == CompileTarget) ||
1317
+ (ArchName == " x86_64" && CompileTarget == " spir64_x86_64" )) {
1318
+ AddImg ();
1319
+ }
1320
+ }
1321
+
1322
+ if (DeviceFilteredImgs.empty ())
1323
+ return nullptr ;
1324
+
1325
+ std::vector<ur_device_binary_t > UrBinaries (DeviceFilteredImgs.size ());
1326
+ for (uint32_t BinaryCount = 0 ; BinaryCount < DeviceFilteredImgs.size ();
1327
+ BinaryCount++) {
1328
+ UrBinaries[BinaryCount].pDeviceTargetSpec = getUrDeviceTarget (
1329
+ getRawImg (DeviceFilteredImgs[BinaryCount])->DeviceTargetSpec );
1275
1330
}
1276
1331
1277
1332
uint32_t ImgInd = 0 ;
@@ -1280,8 +1335,7 @@ RTDeviceBinaryImage *getBinImageFromMultiMap(
1280
1335
getSyclObjImpl (Context)->getPlugin ()->call (
1281
1336
urDeviceSelectBinary, getSyclObjImpl (Device)->getHandleRef (),
1282
1337
UrBinaries.data (), UrBinaries.size (), &ImgInd);
1283
- std::advance (ItBegin, ImgInd);
1284
- return ItBegin->second ;
1338
+ return DeviceFilteredImgs[ImgInd];
1285
1339
}
1286
1340
1287
1341
RTDeviceBinaryImage &
@@ -1310,10 +1364,8 @@ ProgramManager::getDeviceImage(const std::string &KernelName,
1310
1364
std::lock_guard<std::mutex> KernelIDsGuard (m_KernelIDsMutex);
1311
1365
if (auto KernelId = m_KernelName2KernelIDs.find (KernelName);
1312
1366
KernelId != m_KernelName2KernelIDs.end ()) {
1313
- // Kernel ID presence guarantees that we have bin image in the storage.
1314
1367
Img = getBinImageFromMultiMap (m_KernelIDs2BinImage, KernelId->second ,
1315
1368
Context, Device);
1316
- assert (Img && " No binary image found for kernel id" );
1317
1369
} else {
1318
1370
Img = getBinImageFromMultiMap (m_ServiceKernels, KernelName, Context,
1319
1371
Device);
0 commit comments