Skip to content

Commit 74c7854

Browse files
[SYCL] enable_shared_from_this for kernel_bundle_impl (#18899)
This allow us to get rid of passing self to kernel_bundle_impl methods and to return raw pointer to kernel bundle instead of shared_ptr.
1 parent 51aeea6 commit 74c7854

18 files changed

+235
-160
lines changed

sycl/include/sycl/handler.hpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -894,7 +894,7 @@ class __SYCL_EXPORT handler {
894894
// If the kernel lambda is callable with a kernel_handler argument, manifest
895895
// the associated kernel handler.
896896
if constexpr (IsCallableWithKernelHandler) {
897-
getOrInsertHandlerKernelBundle(/*Insert=*/true);
897+
getOrInsertHandlerKernelBundlePtr(/*Insert=*/true);
898898
}
899899
}
900900

@@ -1709,13 +1709,26 @@ class __SYCL_EXPORT handler {
17091709
void setStateSpecConstSet();
17101710
bool isStateExplicitKernelBundle() const;
17111711

1712+
#ifndef __INTEL_PREVIEW_BREAKING_CHANGES
17121713
std::shared_ptr<detail::kernel_bundle_impl>
17131714
getOrInsertHandlerKernelBundle(bool Insert) const;
1715+
#endif
1716+
1717+
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
1718+
// Rename to just getOrInsertHandlerKernelBundle
1719+
#endif
1720+
detail::kernel_bundle_impl *
1721+
getOrInsertHandlerKernelBundlePtr(bool Insert) const;
17141722

17151723
void setHandlerKernelBundle(kernel Kernel);
17161724

1725+
#ifndef __INTEL_PREVIEW_BREAKING_CHANGES
17171726
void setHandlerKernelBundle(
17181727
const std::shared_ptr<detail::kernel_bundle_impl> &NewKernelBundleImpPtr);
1728+
#endif
1729+
1730+
template <typename SharedPtrT>
1731+
void setHandlerKernelBundle(SharedPtrT &&NewKernelBundleImpPtr);
17191732

17201733
void SetHostTask(std::function<void()> &&Func);
17211734
void SetHostTask(std::function<void(interop_handle)> &&Func);
@@ -1763,6 +1776,8 @@ class __SYCL_EXPORT handler {
17631776
/// called.
17641777
void setUserFacingNodeType(ext::oneapi::experimental::node_type Type);
17651778

1779+
kernel_bundle<bundle_state::input> getKernelBundle() const;
1780+
17661781
public:
17671782
handler(const handler &) = delete;
17681783
handler(handler &&) = delete;

sycl/include/sycl/kernel_bundle.hpp

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1330,12 +1330,7 @@ void handler::set_specialization_constant(
13301330

13311331
setStateSpecConstSet();
13321332

1333-
std::shared_ptr<detail::kernel_bundle_impl> KernelBundleImplPtr =
1334-
getOrInsertHandlerKernelBundle(/*Insert=*/true);
1335-
1336-
detail::createSyclObjFromImpl<kernel_bundle<bundle_state::input>>(
1337-
std::move(KernelBundleImplPtr))
1338-
.set_specialization_constant<SpecName>(Value);
1333+
getKernelBundle().set_specialization_constant<SpecName>(Value);
13391334
}
13401335

13411336
template <auto &SpecName>
@@ -1347,12 +1342,7 @@ handler::get_specialization_constant() const {
13471342
"Specialization constants cannot be read after "
13481343
"explicitly setting the used kernel bundle");
13491344

1350-
std::shared_ptr<detail::kernel_bundle_impl> KernelBundleImplPtr =
1351-
getOrInsertHandlerKernelBundle(/*Insert=*/true);
1352-
1353-
return detail::createSyclObjFromImpl<kernel_bundle<bundle_state::input>>(
1354-
std::move(KernelBundleImplPtr))
1355-
.get_specialization_constant<SpecName>();
1345+
return getKernelBundle().get_specialization_constant<SpecName>();
13561346
}
13571347

13581348
} // namespace _V1

sycl/source/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,7 @@ set(SYCL_COMMON_SOURCES
269269
"detail/host_pipe_map.cpp"
270270
"detail/device_global_map.cpp"
271271
"detail/device_global_map_entry.cpp"
272+
"detail/device_image_impl.cpp"
272273
"detail/device_impl.cpp"
273274
"detail/error_handling/error_handling.cpp"
274275
"detail/event_impl.cpp"

sycl/source/backend.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ make_kernel_bundle(ur_native_handle_t NativeHandle,
306306
ImageOriginInterop);
307307
device_image_plain DevImg{DevImgImpl};
308308

309-
return std::make_shared<kernel_bundle_impl>(TargetContext, Devices, DevImg);
309+
return kernel_bundle_impl::create(TargetContext, Devices, DevImg);
310310
}
311311

312312
// TODO: Unused. Remove when allowed.
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
//==----------------- device_image_impl.cpp - SYCL device_image_impl -------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include <detail/device_image_impl.hpp>
10+
#include <detail/kernel_bundle_impl.hpp>
11+
12+
namespace sycl {
13+
inline namespace _V1 {
14+
namespace detail {
15+
16+
std::shared_ptr<kernel_impl> device_image_impl::tryGetSourceBasedKernel(
17+
std::string_view Name, const context &Context,
18+
const kernel_bundle_impl &OwnerBundle,
19+
const std::shared_ptr<device_image_impl> &Self) const {
20+
if (!(getOriginMask() & ImageOriginKernelCompiler))
21+
return nullptr;
22+
23+
assert(MRTCBinInfo);
24+
std::string AdjustedName = adjustKernelName(Name);
25+
if (MRTCBinInfo->MLanguage == syclex::source_language::sycl) {
26+
auto &PM = ProgramManager::getInstance();
27+
for (const std::string &Prefix : MRTCBinInfo->MPrefixes) {
28+
auto KID = PM.tryGetSYCLKernelID(Prefix + AdjustedName);
29+
30+
if (!KID || !has_kernel(*KID))
31+
continue;
32+
33+
auto UrProgram = get_ur_program_ref();
34+
auto [UrKernel, CacheMutex, ArgMask] =
35+
PM.getOrCreateKernel(Context, AdjustedName,
36+
/*PropList=*/{}, UrProgram);
37+
return std::make_shared<kernel_impl>(UrKernel, *getSyclObjImpl(Context),
38+
Self, OwnerBundle.shared_from_this(),
39+
ArgMask, UrProgram, CacheMutex);
40+
}
41+
return nullptr;
42+
}
43+
44+
ur_program_handle_t UrProgram = get_ur_program_ref();
45+
const AdapterPtr &Adapter = getSyclObjImpl(Context)->getAdapter();
46+
ur_kernel_handle_t UrKernel = nullptr;
47+
Adapter->call<UrApiKind::urKernelCreate>(UrProgram, AdjustedName.c_str(),
48+
&UrKernel);
49+
// Kernel created by urKernelCreate is implicitly retained.
50+
51+
return std::make_shared<kernel_impl>(
52+
UrKernel, *detail::getSyclObjImpl(Context), Self,
53+
OwnerBundle.shared_from_this(), /*ArgMask=*/nullptr, UrProgram,
54+
/*CacheMutex=*/nullptr);
55+
}
56+
57+
} // namespace detail
58+
} // namespace _V1
59+
} // namespace sycl

sycl/source/detail/device_image_impl.hpp

Lines changed: 4 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -617,45 +617,10 @@ class device_image_impl {
617617
MRTCBinInfo->MKernelNames.end();
618618
}
619619

620-
std::shared_ptr<kernel_impl> tryGetSourceBasedKernel(
621-
std::string_view Name, const context &Context,
622-
const std::shared_ptr<kernel_bundle_impl> &OwnerBundle,
623-
const std::shared_ptr<device_image_impl> &Self) const {
624-
if (!(getOriginMask() & ImageOriginKernelCompiler))
625-
return nullptr;
626-
627-
assert(MRTCBinInfo);
628-
std::string AdjustedName = adjustKernelName(Name);
629-
if (MRTCBinInfo->MLanguage == syclex::source_language::sycl) {
630-
auto &PM = ProgramManager::getInstance();
631-
for (const std::string &Prefix : MRTCBinInfo->MPrefixes) {
632-
auto KID = PM.tryGetSYCLKernelID(Prefix + AdjustedName);
633-
634-
if (!KID || !has_kernel(*KID))
635-
continue;
636-
637-
auto UrProgram = get_ur_program_ref();
638-
auto [UrKernel, CacheMutex, ArgMask] =
639-
PM.getOrCreateKernel(Context, AdjustedName,
640-
/*PropList=*/{}, UrProgram);
641-
return std::make_shared<kernel_impl>(UrKernel, *getSyclObjImpl(Context),
642-
Self, OwnerBundle, ArgMask,
643-
UrProgram, CacheMutex);
644-
}
645-
return nullptr;
646-
}
647-
648-
ur_program_handle_t UrProgram = get_ur_program_ref();
649-
const AdapterPtr &Adapter = getSyclObjImpl(Context)->getAdapter();
650-
ur_kernel_handle_t UrKernel = nullptr;
651-
Adapter->call<UrApiKind::urKernelCreate>(UrProgram, AdjustedName.c_str(),
652-
&UrKernel);
653-
// Kernel created by urKernelCreate is implicitly retained.
654-
655-
return std::make_shared<kernel_impl>(
656-
UrKernel, *detail::getSyclObjImpl(Context), Self, OwnerBundle,
657-
/*ArgMask=*/nullptr, UrProgram, /*CacheMutex=*/nullptr);
658-
}
620+
std::shared_ptr<kernel_impl>
621+
tryGetSourceBasedKernel(std::string_view Name, const context &Context,
622+
const kernel_bundle_impl &OwnerBundle,
623+
const std::shared_ptr<device_image_impl> &Self) const;
659624

660625
bool hasDeviceGlobalName(const std::string &Name) const noexcept {
661626
if (!MRTCBinInfo.has_value())

sycl/source/detail/graph_impl.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -860,7 +860,7 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx,
860860
std::tie(CmdTraceEvent, InstanceID) = emitKernelInstrumentationData(
861861
StreamID, CGExec->MSyclKernel, CodeLoc, CGExec->MIsTopCodeLoc,
862862
CGExec->MKernelName.data(), CGExec->MKernelNameBasedCachePtr, nullptr,
863-
CGExec->MNDRDesc, CGExec->MKernelBundle, CGExec->MArgs);
863+
CGExec->MNDRDesc, CGExec->MKernelBundle.get(), CGExec->MArgs);
864864
if (CmdTraceEvent)
865865
sycl::detail::emitInstrumentationGeneral(
866866
StreamID, InstanceID, CmdTraceEvent, xpti::trace_task_begin, nullptr);
@@ -1536,8 +1536,7 @@ void exec_graph_impl::populateURKernelUpdateStructs(
15361536
EliminatedArgMask = Kernel->getKernelArgMask();
15371537
} else if (auto SyclKernelImpl =
15381538
KernelBundleImplPtr
1539-
? KernelBundleImplPtr->tryGetKernel(ExecCG.MKernelName,
1540-
KernelBundleImplPtr)
1539+
? KernelBundleImplPtr->tryGetKernel(ExecCG.MKernelName)
15411540
: std::shared_ptr<kernel_impl>{nullptr}) {
15421541
UrKernel = SyclKernelImpl->getHandleRef();
15431542
EliminatedArgMask = SyclKernelImpl->getKernelArgMask();

sycl/source/detail/helpers.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,7 @@ retrieveKernelBinary(queue_impl &Queue, KernelNameStrRefT KernelName,
7373
DeviceImage = KernelCG->MSyclKernel->getDeviceImage()->get_bin_image_ref();
7474
Program = KernelCG->MSyclKernel->getDeviceImage()->get_ur_program_ref();
7575
} else if (auto SyclKernelImpl =
76-
KernelBundleImpl ? KernelBundleImpl->tryGetKernel(
77-
KernelName, KernelBundleImpl)
76+
KernelBundleImpl ? KernelBundleImpl->tryGetKernel(KernelName)
7877
: std::shared_ptr<kernel_impl>{nullptr}) {
7978
// Retrieve the device image from the kernel bundle.
8079
DeviceImage = SyclKernelImpl->getDeviceImage()->get_bin_image_ref();

0 commit comments

Comments
 (0)