@@ -58,10 +58,10 @@ constexpr uint8_t ImageOriginKernelCompiler = 1 << 2;
58
58
class ManagedDeviceGlobalsRegistry {
59
59
public:
60
60
ManagedDeviceGlobalsRegistry (
61
- const std::shared_ptr<context_impl> &ContextImpl ,
62
- const std::string &Prefix, std::vector<std::string> &&DeviceGlobalNames,
61
+ context_impl &ContextImpl, const std::string &Prefix ,
62
+ std::vector<std::string> &&DeviceGlobalNames,
63
63
std::vector<std::unique_ptr<std::byte[]>> &&DeviceGlobalAllocations)
64
- : MContextImpl{ContextImpl}, MPrefix{Prefix},
64
+ : MContextImpl{ContextImpl. shared_from_this () }, MPrefix{Prefix},
65
65
MDeviceGlobalNames{std::move (DeviceGlobalNames)},
66
66
MDeviceGlobalAllocations{std::move (DeviceGlobalAllocations)} {}
67
67
@@ -704,12 +704,11 @@ class device_image_impl {
704
704
assert (MRTCBinInfo);
705
705
assert (MOrigins & ImageOriginKernelCompiler);
706
706
707
- const std::shared_ptr<sycl::detail::context_impl> &ContextImpl =
708
- getSyclObjImpl (MContext);
707
+ sycl::detail::context_impl &ContextImpl = *getSyclObjImpl (MContext);
709
708
710
709
for (const auto &SyclDev : Devices) {
711
710
device_impl &DevImpl = *getSyclObjImpl (SyclDev);
712
- if (!ContextImpl-> hasDevice (DevImpl)) {
711
+ if (!ContextImpl. hasDevice (DevImpl)) {
713
712
throw sycl::exception (make_error_code (errc::invalid),
714
713
" device not part of kernel_bundle context" );
715
714
}
@@ -742,7 +741,7 @@ class device_image_impl {
742
741
Devices, BuildOptions, *SourceStrPtr, UrProgram);
743
742
}
744
743
745
- const AdapterPtr &Adapter = ContextImpl-> getAdapter ();
744
+ const AdapterPtr &Adapter = ContextImpl. getAdapter ();
746
745
747
746
if (!FetchedFromCache)
748
747
UrProgram = createProgramFromSource (Devices, BuildOptions, LogPtr);
@@ -752,7 +751,7 @@ class device_image_impl {
752
751
UrProgram, DeviceVec.size (), DeviceVec.data (), XsFlags.c_str ());
753
752
if (Res == UR_RESULT_ERROR_UNSUPPORTED_FEATURE) {
754
753
Res = Adapter->call_nocheck <UrApiKind::urProgramBuild>(
755
- ContextImpl-> getHandleRef (), UrProgram, XsFlags.c_str ());
754
+ ContextImpl. getHandleRef (), UrProgram, XsFlags.c_str ());
756
755
}
757
756
Adapter->checkUrResult <errc::build>(Res);
758
757
@@ -796,12 +795,11 @@ class device_image_impl {
796
795
" compile is only available for kernel_bundle<bundle_state::source> "
797
796
" when the source language was sycl." );
798
797
799
- std::shared_ptr<sycl::detail::context_impl> ContextImpl =
800
- getSyclObjImpl (MContext);
798
+ sycl::detail::context_impl &ContextImpl = *getSyclObjImpl (MContext);
801
799
802
800
for (const auto &SyclDev : Devices) {
803
801
detail::device_impl &DevImpl = *getSyclObjImpl (SyclDev);
804
- if (!ContextImpl-> hasDevice (DevImpl)) {
802
+ if (!ContextImpl. hasDevice (DevImpl)) {
805
803
throw sycl::exception (make_error_code (errc::invalid),
806
804
" device not part of kernel_bundle context" );
807
805
}
@@ -873,9 +871,8 @@ class device_image_impl {
873
871
const std::vector<device> Devices,
874
872
const std::vector<sycl::detail::string_view> &BuildOptions,
875
873
const std::string &SourceStr, ur_program_handle_t &UrProgram) const {
876
- const std::shared_ptr<sycl::detail::context_impl> &ContextImpl =
877
- getSyclObjImpl (MContext);
878
- const AdapterPtr &Adapter = ContextImpl->getAdapter ();
874
+ sycl::detail::context_impl &ContextImpl = *getSyclObjImpl (MContext);
875
+ const AdapterPtr &Adapter = ContextImpl.getAdapter ();
879
876
880
877
std::string UserArgs = syclex::detail::userArgsAsString (BuildOptions);
881
878
@@ -904,7 +901,7 @@ class device_image_impl {
904
901
Properties.pMetadatas = nullptr ;
905
902
906
903
Adapter->call <UrApiKind::urProgramCreateWithBinary>(
907
- ContextImpl-> getHandleRef (), DeviceHandles.size (), DeviceHandles.data (),
904
+ ContextImpl. getHandleRef (), DeviceHandles.size (), DeviceHandles.data (),
908
905
Lengths.data (), Binaries.data (), &Properties, &UrProgram);
909
906
910
907
return true ;
@@ -1133,7 +1130,7 @@ class device_image_impl {
1133
1130
}
1134
1131
1135
1132
auto DGRegs = std::make_shared<ManagedDeviceGlobalsRegistry>(
1136
- getSyclObjImpl (MContext), std::string{Prefix},
1133
+ * getSyclObjImpl (MContext), std::string{Prefix},
1137
1134
std::move (DeviceGlobalNames), std::move (DeviceGlobalAllocations));
1138
1135
1139
1136
// Mark the image as input so the program manager will bring it into
@@ -1196,9 +1193,8 @@ class device_image_impl {
1196
1193
createProgramFromSource (const std::vector<device> Devices,
1197
1194
const std::vector<sycl::detail::string_view> &Options,
1198
1195
std::string *LogPtr) const {
1199
- const std::shared_ptr<sycl::detail::context_impl> &ContextImpl =
1200
- getSyclObjImpl (MContext);
1201
- const AdapterPtr &Adapter = ContextImpl->getAdapter ();
1196
+ sycl::detail::context_impl &ContextImpl = *getSyclObjImpl (MContext);
1197
+ const AdapterPtr &Adapter = ContextImpl.getAdapter ();
1202
1198
const auto spirv = [&]() -> std::vector<uint8_t > {
1203
1199
switch (MRTCBinInfo->MLanguage ) {
1204
1200
case syclex::source_language::opencl: {
@@ -1235,7 +1231,7 @@ class device_image_impl {
1235
1231
}();
1236
1232
1237
1233
ur_program_handle_t UrProgram = nullptr ;
1238
- Adapter->call <UrApiKind::urProgramCreateWithIL>(ContextImpl-> getHandleRef (),
1234
+ Adapter->call <UrApiKind::urProgramCreateWithIL>(ContextImpl. getHandleRef (),
1239
1235
spirv.data (), spirv.size (),
1240
1236
nullptr , &UrProgram);
1241
1237
// program created by urProgramCreateWithIL is implicitly retained.
0 commit comments