Skip to content

[SYCL] Filter implicit kernel bundle images #5285

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
7 changes: 6 additions & 1 deletion sycl/include/CL/sycl/kernel_bundle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -588,8 +588,13 @@ template <typename KernelName> bool is_compatible(const device &Dev) {

namespace detail {

// TODO: This is no longer in use. Remove when ABI break is allowed.
__SYCL_EXPORT std::shared_ptr<detail::kernel_bundle_impl>
join_impl(const std::vector<detail::KernelBundleImplPtr> &Bundles);

__SYCL_EXPORT std::shared_ptr<detail::kernel_bundle_impl>
join_impl(const std::vector<detail::KernelBundleImplPtr> &Bundles,
bundle_state State);
}

/// \returns a new kernel bundle that represents the union of all the device
Expand All @@ -604,7 +609,7 @@ join(const std::vector<sycl::kernel_bundle<State>> &Bundles) {
KernelBundleImpls.push_back(detail::getSyclObjImpl(Bundle));

std::shared_ptr<detail::kernel_bundle_impl> Impl =
detail::join_impl(KernelBundleImpls);
detail::join_impl(KernelBundleImpls, State);
return detail::createSyclObjFromImpl<kernel_bundle<State>>(Impl);
}

Expand Down
55 changes: 39 additions & 16 deletions sycl/source/detail/kernel_bundle_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class kernel_bundle_impl {

public:
kernel_bundle_impl(context Ctx, std::vector<device> Devs, bundle_state State)
: MContext(std::move(Ctx)), MDevices(std::move(Devs)) {
: MContext(std::move(Ctx)), MDevices(std::move(Devs)), MState(State) {

common_ctor_checks(State);

Expand All @@ -89,7 +89,7 @@ class kernel_bundle_impl {

// Interop constructor used by make_kernel
kernel_bundle_impl(context Ctx, std::vector<device> Devs)
: MContext(Ctx), MDevices(Devs) {
: MContext(Ctx), MDevices(Devs), MState(bundle_state::executable) {
if (!checkAllDevicesAreInContext(Devs, Ctx))
throw sycl::exception(
make_error_code(errc::invalid),
Expand All @@ -111,7 +111,8 @@ class kernel_bundle_impl {
kernel_bundle_impl(const kernel_bundle<bundle_state::input> &InputBundle,
std::vector<device> Devs, const property_list &PropList,
bundle_state TargetState)
: MContext(InputBundle.get_context()), MDevices(std::move(Devs)) {
: MContext(InputBundle.get_context()), MDevices(std::move(Devs)),
MState(TargetState) {

MSpecConstValues = getSyclObjImpl(InputBundle)->get_spec_const_map_ref();

Expand Down Expand Up @@ -161,7 +162,7 @@ class kernel_bundle_impl {
kernel_bundle_impl(
const std::vector<kernel_bundle<bundle_state::object>> &ObjectBundles,
std::vector<device> Devs, const property_list &PropList)
: MDevices(std::move(Devs)) {
: MDevices(std::move(Devs)), MState(bundle_state::executable) {

if (MDevices.empty())
throw sycl::exception(make_error_code(errc::invalid),
Expand Down Expand Up @@ -241,7 +242,7 @@ class kernel_bundle_impl {
kernel_bundle_impl(context Ctx, std::vector<device> Devs,
const std::vector<kernel_id> &KernelIDs,
bundle_state State)
: MContext(std::move(Ctx)), MDevices(std::move(Devs)) {
: MContext(std::move(Ctx)), MDevices(std::move(Devs)), MState(State) {

// TODO: Add a check that all kernel ids are compatible with at least one
// device in Devs
Expand All @@ -253,7 +254,7 @@ class kernel_bundle_impl {

kernel_bundle_impl(context Ctx, std::vector<device> Devs,
const DevImgSelectorImpl &Selector, bundle_state State)
: MContext(std::move(Ctx)), MDevices(std::move(Devs)) {
: MContext(std::move(Ctx)), MDevices(std::move(Devs)), MState(State) {

common_ctor_checks(State);

Expand All @@ -262,7 +263,9 @@ class kernel_bundle_impl {
}

// C'tor matches sycl::join API
kernel_bundle_impl(const std::vector<detail::KernelBundleImplPtr> &Bundles) {
kernel_bundle_impl(const std::vector<detail::KernelBundleImplPtr> &Bundles,
bundle_state State)
: MState(State) {
if (Bundles.empty())
return;

Expand Down Expand Up @@ -480,22 +483,41 @@ class kernel_bundle_impl {

size_t size() const noexcept { return MDeviceImages.size(); }

bundle_state get_bundle_state() const {
// Interop kernel-bundles are always in executable state
if (MIsInterop)
return bundle_state::executable;
// All device images are expected to have the same state
return MDeviceImages.empty()
? bundle_state::input
: detail::getSyclObjImpl(MDeviceImages[0])->get_state();
}
bundle_state get_bundle_state() const { return MState; }

const SpecConstMapT &get_spec_const_map_ref() const noexcept {
return MSpecConstValues;
}

bool isInterop() const { return MIsInterop; }

bool add_kernel(const kernel_id &KernelID, const device &Dev) {
// Skip if kernel is already there
if (has_kernel(KernelID, Dev))
return true;

// First try and get images in current bundle state
const bundle_state BundleState = get_bundle_state();
std::vector<device_image_plain> NewDevImgs =
detail::ProgramManager::getInstance().getSYCLDeviceImages(
MContext, {Dev}, {KernelID}, BundleState);

// No images found so we report as not inserted
if (NewDevImgs.empty())
return false;

// Propagate already set specialization constants to the new images
for (device_image_plain &DevImg : NewDevImgs)
for (auto SpecConst : MSpecConstValues)
getSyclObjImpl(DevImg)->set_specialization_constant_raw_value(
SpecConst.first.c_str(), SpecConst.second.data());

// Add the images to the collection
MDeviceImages.insert(MDeviceImages.end(), NewDevImgs.begin(),
NewDevImgs.end());
return true;
}

private:
context MContext;
std::vector<device> MDevices;
Expand All @@ -504,6 +526,7 @@ class kernel_bundle_impl {
// from any device image.
SpecConstMapT MSpecConstValues;
bool MIsInterop = false;
bundle_state MState;
};

} // namespace detail
Expand Down
4 changes: 4 additions & 0 deletions sycl/source/detail/program_manager/program_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1536,6 +1536,10 @@ std::vector<device_image_plain> ProgramManager::getSYCLDeviceImages(
std::vector<device_image_plain> ProgramManager::getSYCLDeviceImages(
const context &Ctx, const std::vector<device> &Devs,
const std::vector<kernel_id> &KernelIDs, bundle_state TargetState) {
// Fast path for when no kernel IDs are requested
if (KernelIDs.empty())
return {};

{
std::lock_guard<std::mutex> BuiltInKernelIDsGuard(m_BuiltInKernelIDsMutex);

Expand Down
41 changes: 35 additions & 6 deletions sycl/source/handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,10 @@ handler::getOrInsertHandlerKernelBundle(bool Insert) const {

// No kernel bundle yet, create one
if (!KernelBundleImpPtr && Insert) {
KernelBundleImpPtr = detail::getSyclObjImpl(
get_kernel_bundle<bundle_state::input>(MQueue->get_context()));
if (KernelBundleImpPtr->empty()) {
KernelBundleImpPtr = detail::getSyclObjImpl(
get_kernel_bundle<bundle_state::executable>(MQueue->get_context()));
}
// Create an empty kernel bundle to add kernels to later
KernelBundleImpPtr =
detail::getSyclObjImpl(get_kernel_bundle<bundle_state::input>(
MQueue->get_context(), {MQueue->get_device()}, {}));

detail::ExtendedMemberT EMember = {
detail::ExtendedMembersType::HANDLER_KERNEL_BUNDLE, KernelBundleImpPtr};
Expand Down Expand Up @@ -169,6 +167,33 @@ event handler::finalize() {
// If there were uses of set_specialization_constant build the kernel_bundle
KernelBundleImpPtr = getOrInsertHandlerKernelBundle(/*Insert=*/false);
if (KernelBundleImpPtr) {
// Make sure implicit non-interop kernel bundles have the kernel
if (!KernelBundleImpPtr->isInterop() &&
!getHandlerImpl()->isStateExplicitKernelBundle()) {
kernel_id KernelID =
detail::ProgramManager::getInstance().getSYCLKernelID(MKernelName);
bool KernelInserted =
KernelBundleImpPtr->add_kernel(KernelID, MQueue->get_device());
// If kernel was not inserted and the bundle is in input mode we try
// building it and trying to find the kernel in executable mode
if (!KernelInserted &&
KernelBundleImpPtr->get_bundle_state() == bundle_state::input) {
auto KernelBundle =
detail::createSyclObjFromImpl<kernel_bundle<bundle_state::input>>(
KernelBundleImpPtr);
kernel_bundle<bundle_state::executable> ExecKernelBundle =
build(KernelBundle);
KernelBundleImpPtr = detail::getSyclObjImpl(ExecKernelBundle);
setHandlerKernelBundle(KernelBundleImpPtr);
KernelInserted =
KernelBundleImpPtr->add_kernel(KernelID, MQueue->get_device());
}
// If the kernel was not found in executable mode we throw an exception
if (!KernelInserted)
throw sycl::exception(make_error_code(errc::runtime),
"Failed to add kernel to kernel bundle.");
}

switch (KernelBundleImpPtr->get_bundle_state()) {
case bundle_state::input: {
// Underlying level expects kernel_bundle to be in executable state
Expand Down Expand Up @@ -618,6 +643,10 @@ void handler::verifyUsedKernelBundle(const std::string &KernelName) {
if (!UsedKernelBundleImplPtr)
return;

// Implicit kernel bundles are populated late so we ignore them
if (!getHandlerImpl()->isStateExplicitKernelBundle())
return;

kernel_id KernelID = detail::get_kernel_id_impl(KernelName);
device Dev = detail::getDeviceFromHandler(*this);
if (!UsedKernelBundleImplPtr->has_kernel(KernelID, Dev))
Expand Down
9 changes: 8 additions & 1 deletion sycl/source/kernel_bundle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,14 @@ get_empty_interop_kernel_bundle_impl(const context &Ctx,

std::shared_ptr<detail::kernel_bundle_impl>
join_impl(const std::vector<detail::KernelBundleImplPtr> &Bundles) {
return std::make_shared<detail::kernel_bundle_impl>(Bundles);
return std::make_shared<detail::kernel_bundle_impl>(Bundles,
bundle_state::input);
}

std::shared_ptr<detail::kernel_bundle_impl>
join_impl(const std::vector<detail::KernelBundleImplPtr> &Bundles,
bundle_state State) {
return std::make_shared<detail::kernel_bundle_impl>(Bundles, State);
}

bool has_kernel_bundle_impl(const context &Ctx, const std::vector<device> &Devs,
Expand Down
1 change: 1 addition & 0 deletions sycl/test/abi/sycl_symbols_linux.dump
Original file line number Diff line number Diff line change
Expand Up @@ -3903,6 +3903,7 @@ _ZN2cl4sycl6detail6OSUtil16getCurrentDSODirB5cxx11Ev
_ZN2cl4sycl6detail6OSUtil17getOSModuleHandleEPKv
_ZN2cl4sycl6detail6OSUtil7makeDirEPKc
_ZN2cl4sycl6detail9join_implERKSt6vectorISt10shared_ptrINS1_18kernel_bundle_implEESaIS5_EE
_ZN2cl4sycl6detail9join_implERKSt6vectorISt10shared_ptrINS1_18kernel_bundle_implEESaIS5_EENS0_12bundle_stateE
_ZN2cl4sycl6detail9link_implERKSt6vectorINS0_13kernel_bundleILNS0_12bundle_stateE1EEESaIS5_EERKS2_INS0_6deviceESaISA_EERKNS0_13property_listE
_ZN2cl4sycl6device11get_devicesENS0_4info11device_typeE
_ZN2cl4sycl6deviceC1EP13_cl_device_id
Expand Down
1 change: 1 addition & 0 deletions sycl/test/abi/sycl_symbols_windows.dump
Original file line number Diff line number Diff line change
Expand Up @@ -2398,6 +2398,7 @@
?is_in_order@queue@sycl@cl@@QEBA_NXZ
?is_specialization_constant_set@kernel_bundle_plain@detail@sycl@cl@@IEBA_NPEBD@Z
?join_impl@detail@sycl@cl@@YA?AV?$shared_ptr@Vkernel_bundle_impl@detail@sycl@cl@@@std@@AEBV?$vector@V?$shared_ptr@Vkernel_bundle_impl@detail@sycl@cl@@@std@@V?$allocator@V?$shared_ptr@Vkernel_bundle_impl@detail@sycl@cl@@@std@@@2@@5@@Z
?join_impl@detail@sycl@cl@@YA?AV?$shared_ptr@Vkernel_bundle_impl@detail@sycl@cl@@@std@@AEBV?$vector@V?$shared_ptr@Vkernel_bundle_impl@detail@sycl@cl@@@std@@V?$allocator@V?$shared_ptr@Vkernel_bundle_impl@detail@sycl@cl@@@std@@@2@@5@W4bundle_state@23@@Z
?ldexp@__host_std@cl@@YA?AV?$vec@M$00@sycl@2@V342@V?$vec@H$00@42@@Z
?ldexp@__host_std@cl@@YA?AV?$vec@M$01@sycl@2@V342@V?$vec@H$01@42@@Z
?ldexp@__host_std@cl@@YA?AV?$vec@M$02@sycl@2@V342@V?$vec@H$02@42@@Z
Expand Down