Skip to content

Commit 0dbcd3a

Browse files
[NFC][SYCL] Pass context_impl by raw ptr/ref in device_image_impl.hpp (#18981)
Part of the ongoing refactoring to prefer raw ptr/ref for SYCL RT objects by default with explicit `shared_from_this` when lifetimes need to be extended.
1 parent 9cf0dcc commit 0dbcd3a

File tree

1 file changed

+16
-20
lines changed

1 file changed

+16
-20
lines changed

sycl/source/detail/device_image_impl.hpp

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,10 @@ constexpr uint8_t ImageOriginKernelCompiler = 1 << 2;
5858
class ManagedDeviceGlobalsRegistry {
5959
public:
6060
ManagedDeviceGlobalsRegistry(
61-
const std::shared_ptr<context_impl> &ContextImpl,
62-
const std::string &Prefix, std::vector<std::string> &&DeviceGlobalNames,
61+
context_impl &ContextImpl, const std::string &Prefix,
62+
std::vector<std::string> &&DeviceGlobalNames,
6363
std::vector<std::unique_ptr<std::byte[]>> &&DeviceGlobalAllocations)
64-
: MContextImpl{ContextImpl}, MPrefix{Prefix},
64+
: MContextImpl{ContextImpl.shared_from_this()}, MPrefix{Prefix},
6565
MDeviceGlobalNames{std::move(DeviceGlobalNames)},
6666
MDeviceGlobalAllocations{std::move(DeviceGlobalAllocations)} {}
6767

@@ -704,12 +704,11 @@ class device_image_impl {
704704
assert(MRTCBinInfo);
705705
assert(MOrigins & ImageOriginKernelCompiler);
706706

707-
const std::shared_ptr<sycl::detail::context_impl> &ContextImpl =
708-
getSyclObjImpl(MContext);
707+
sycl::detail::context_impl &ContextImpl = *getSyclObjImpl(MContext);
709708

710709
for (const auto &SyclDev : Devices) {
711710
device_impl &DevImpl = *getSyclObjImpl(SyclDev);
712-
if (!ContextImpl->hasDevice(DevImpl)) {
711+
if (!ContextImpl.hasDevice(DevImpl)) {
713712
throw sycl::exception(make_error_code(errc::invalid),
714713
"device not part of kernel_bundle context");
715714
}
@@ -742,7 +741,7 @@ class device_image_impl {
742741
Devices, BuildOptions, *SourceStrPtr, UrProgram);
743742
}
744743

745-
const AdapterPtr &Adapter = ContextImpl->getAdapter();
744+
const AdapterPtr &Adapter = ContextImpl.getAdapter();
746745

747746
if (!FetchedFromCache)
748747
UrProgram = createProgramFromSource(Devices, BuildOptions, LogPtr);
@@ -752,7 +751,7 @@ class device_image_impl {
752751
UrProgram, DeviceVec.size(), DeviceVec.data(), XsFlags.c_str());
753752
if (Res == UR_RESULT_ERROR_UNSUPPORTED_FEATURE) {
754753
Res = Adapter->call_nocheck<UrApiKind::urProgramBuild>(
755-
ContextImpl->getHandleRef(), UrProgram, XsFlags.c_str());
754+
ContextImpl.getHandleRef(), UrProgram, XsFlags.c_str());
756755
}
757756
Adapter->checkUrResult<errc::build>(Res);
758757

@@ -796,12 +795,11 @@ class device_image_impl {
796795
"compile is only available for kernel_bundle<bundle_state::source> "
797796
"when the source language was sycl.");
798797

799-
std::shared_ptr<sycl::detail::context_impl> ContextImpl =
800-
getSyclObjImpl(MContext);
798+
sycl::detail::context_impl &ContextImpl = *getSyclObjImpl(MContext);
801799

802800
for (const auto &SyclDev : Devices) {
803801
detail::device_impl &DevImpl = *getSyclObjImpl(SyclDev);
804-
if (!ContextImpl->hasDevice(DevImpl)) {
802+
if (!ContextImpl.hasDevice(DevImpl)) {
805803
throw sycl::exception(make_error_code(errc::invalid),
806804
"device not part of kernel_bundle context");
807805
}
@@ -873,9 +871,8 @@ class device_image_impl {
873871
const std::vector<device> Devices,
874872
const std::vector<sycl::detail::string_view> &BuildOptions,
875873
const std::string &SourceStr, ur_program_handle_t &UrProgram) const {
876-
const std::shared_ptr<sycl::detail::context_impl> &ContextImpl =
877-
getSyclObjImpl(MContext);
878-
const AdapterPtr &Adapter = ContextImpl->getAdapter();
874+
sycl::detail::context_impl &ContextImpl = *getSyclObjImpl(MContext);
875+
const AdapterPtr &Adapter = ContextImpl.getAdapter();
879876

880877
std::string UserArgs = syclex::detail::userArgsAsString(BuildOptions);
881878

@@ -904,7 +901,7 @@ class device_image_impl {
904901
Properties.pMetadatas = nullptr;
905902

906903
Adapter->call<UrApiKind::urProgramCreateWithBinary>(
907-
ContextImpl->getHandleRef(), DeviceHandles.size(), DeviceHandles.data(),
904+
ContextImpl.getHandleRef(), DeviceHandles.size(), DeviceHandles.data(),
908905
Lengths.data(), Binaries.data(), &Properties, &UrProgram);
909906

910907
return true;
@@ -1133,7 +1130,7 @@ class device_image_impl {
11331130
}
11341131

11351132
auto DGRegs = std::make_shared<ManagedDeviceGlobalsRegistry>(
1136-
getSyclObjImpl(MContext), std::string{Prefix},
1133+
*getSyclObjImpl(MContext), std::string{Prefix},
11371134
std::move(DeviceGlobalNames), std::move(DeviceGlobalAllocations));
11381135

11391136
// Mark the image as input so the program manager will bring it into
@@ -1196,9 +1193,8 @@ class device_image_impl {
11961193
createProgramFromSource(const std::vector<device> Devices,
11971194
const std::vector<sycl::detail::string_view> &Options,
11981195
std::string *LogPtr) const {
1199-
const std::shared_ptr<sycl::detail::context_impl> &ContextImpl =
1200-
getSyclObjImpl(MContext);
1201-
const AdapterPtr &Adapter = ContextImpl->getAdapter();
1196+
sycl::detail::context_impl &ContextImpl = *getSyclObjImpl(MContext);
1197+
const AdapterPtr &Adapter = ContextImpl.getAdapter();
12021198
const auto spirv = [&]() -> std::vector<uint8_t> {
12031199
switch (MRTCBinInfo->MLanguage) {
12041200
case syclex::source_language::opencl: {
@@ -1235,7 +1231,7 @@ class device_image_impl {
12351231
}();
12361232

12371233
ur_program_handle_t UrProgram = nullptr;
1238-
Adapter->call<UrApiKind::urProgramCreateWithIL>(ContextImpl->getHandleRef(),
1234+
Adapter->call<UrApiKind::urProgramCreateWithIL>(ContextImpl.getHandleRef(),
12391235
spirv.data(), spirv.size(),
12401236
nullptr, &UrProgram);
12411237
// program created by urProgramCreateWithIL is implicitly retained.

0 commit comments

Comments
 (0)