@@ -583,8 +583,7 @@ static const char *getUrDeviceTarget(const char *URDeviceTarget) {
583
583
}
584
584
585
585
static bool compatibleWithDevice (RTDeviceBinaryImage *BinImage,
586
- const device &Dev) {
587
- detail::device_impl &DeviceImpl = *detail::getSyclObjImpl (Dev);
586
+ const device_impl &DeviceImpl) {
588
587
auto &Adapter = DeviceImpl.getAdapter ();
589
588
590
589
const ur_device_handle_t &URDeviceHandle = DeviceImpl.getHandleRef ();
@@ -621,7 +620,7 @@ bool ProgramManager::isSpecialDeviceImage(RTDeviceBinaryImage *BinImage) {
621
620
}
622
621
623
622
bool ProgramManager::isSpecialDeviceImageShouldBeUsed (
624
- RTDeviceBinaryImage *BinImage, const device &Dev ) {
623
+ RTDeviceBinaryImage *BinImage, const device_impl &DeviceImpl ) {
625
624
// Decide whether a devicelib image should be used.
626
625
int Bfloat16DeviceLibVersion = -1 ;
627
626
if (m_Bfloat16DeviceLibImages[0 ].get () == BinImage)
@@ -640,7 +639,6 @@ bool ProgramManager::isSpecialDeviceImageShouldBeUsed(
640
639
// more devicelib images in this way.
641
640
enum { DEVICELIB_FALLBACK = 0 , DEVICELIB_NATIVE };
642
641
ur_bool_t NativeBF16Supported = false ;
643
- detail::device_impl &DeviceImpl = *detail::getSyclObjImpl (Dev);
644
642
ur_result_t CallSuccessful =
645
643
DeviceImpl.getAdapter ()->call_nocheck <UrApiKind::urDeviceGetInfo>(
646
644
DeviceImpl.getHandleRef (),
@@ -658,15 +656,15 @@ bool ProgramManager::isSpecialDeviceImageShouldBeUsed(
658
656
return false ;
659
657
}
660
658
661
- static bool checkLinkingSupport (const device &Dev ,
659
+ static bool checkLinkingSupport (const device_impl &DeviceImpl ,
662
660
const RTDeviceBinaryImage &Img) {
663
661
const char *Target = Img.getRawData ().DeviceTargetSpec ;
664
662
// TODO replace with extension checks once implemented in UR.
665
663
if (strcmp (Target, __SYCL_DEVICE_BINARY_TARGET_SPIRV64) == 0 ) {
666
664
return true ;
667
665
}
668
666
if (strcmp (Target, __SYCL_DEVICE_BINARY_TARGET_SPIRV64_GEN) == 0 ) {
669
- return Dev .is_gpu () && Dev. get_backend () == backend::opencl;
667
+ return DeviceImpl .is_gpu () && DeviceImpl. getBackend () == backend::opencl;
670
668
}
671
669
return false ;
672
670
}
@@ -701,7 +699,8 @@ ProgramManager::collectDeviceImageDepsForImportedSymbols(
701
699
HandledSymbols.insert (ISProp->Name );
702
700
}
703
701
ur::DeviceBinaryType Format = MainImg.getFormat ();
704
- if (!WorkList.empty () && !checkLinkingSupport (Dev, MainImg))
702
+ if (!WorkList.empty () &&
703
+ !checkLinkingSupport (*getSyclObjImpl (Dev).get (), MainImg))
705
704
throw exception (make_error_code (errc::feature_not_supported),
706
705
" Cannot resolve external symbols, linking is unsupported "
707
706
" for the backend" );
@@ -715,10 +714,10 @@ ProgramManager::collectDeviceImageDepsForImportedSymbols(
715
714
RTDeviceBinaryImage *Img = It->second ;
716
715
if (Img->getFormat () != Format ||
717
716
!doesDevSupportDeviceRequirements (Dev, *Img) ||
718
- !compatibleWithDevice (Img, Dev))
717
+ !compatibleWithDevice (Img, * getSyclObjImpl ( Dev). get () ))
719
718
continue ;
720
719
if (isSpecialDeviceImage (Img) &&
721
- !isSpecialDeviceImageShouldBeUsed (Img, Dev))
720
+ !isSpecialDeviceImageShouldBeUsed (Img, * getSyclObjImpl ( Dev). get () ))
722
721
continue ;
723
722
DeviceImagesToLink.insert (Img);
724
723
Found = true ;
@@ -2415,14 +2414,14 @@ kernel_id ProgramManager::getSYCLKernelID(KernelNameStrRefT KernelName) {
2415
2414
" No kernel found with the specified name" );
2416
2415
}
2417
2416
2418
- bool ProgramManager::hasCompatibleImage (const device &Dev ) {
2417
+ bool ProgramManager::hasCompatibleImage (const device_impl &DeviceImpl ) {
2419
2418
std::lock_guard<std::mutex> Guard (m_KernelIDsMutex);
2420
2419
2421
2420
return std::any_of (
2422
2421
m_BinImg2KernelIDs.cbegin (), m_BinImg2KernelIDs.cend (),
2423
2422
[&](std::pair<RTDeviceBinaryImage *,
2424
2423
std::shared_ptr<std::vector<kernel_id>>>
2425
- Elem) { return compatibleWithDevice (Elem.first , Dev ); });
2424
+ Elem) { return compatibleWithDevice (Elem.first , DeviceImpl ); });
2426
2425
}
2427
2426
2428
2427
std::vector<kernel_id> ProgramManager::getAllSYCLKernelIDs () {
@@ -2555,7 +2554,7 @@ device_image_plain ProgramManager::getDeviceImageFromBinaryImage(
2555
2554
RTDeviceBinaryImage *BinImage, const context &Ctx, const device &Dev) {
2556
2555
const bundle_state ImgState = getBinImageState (BinImage);
2557
2556
2558
- assert (compatibleWithDevice (BinImage, Dev));
2557
+ assert (compatibleWithDevice (BinImage, * getSyclObjImpl ( Dev). get () ));
2559
2558
2560
2559
std::shared_ptr<std::vector<sycl::kernel_id>> KernelIDs;
2561
2560
// Collect kernel names for the image.
@@ -2640,7 +2639,7 @@ ProgramManager::getSYCLDeviceImagesWithCompatibleState(
2640
2639
KernelImageMap.insert ({KernelID, {}});
2641
2640
2642
2641
for (RTDeviceBinaryImage *BinImage : BinImages) {
2643
- if (!compatibleWithDevice (BinImage, Dev) ||
2642
+ if (!compatibleWithDevice (BinImage, * getSyclObjImpl ( Dev). get () ) ||
2644
2643
!doesDevSupportDeviceRequirements (Dev, *BinImage))
2645
2644
continue ;
2646
2645
0 commit comments