Skip to content

Commit 61e5101

Browse files
[SYCL] Filter out unneeded device images with lower state than requested (#8523)
When fetching device images compatible with non-input states, we can ignore an image if another one with a higher state is available for all the possible kernel-device pairs. This patch adds the logic for filtering out such unnecessary images so that we can avoid JIT compilation if both AOT and SPIRV images are present.
1 parent 3be2e42 commit 61e5101

File tree

7 files changed

+354
-51
lines changed

7 files changed

+354
-51
lines changed

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 102 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1683,46 +1683,120 @@ ProgramManager::getSYCLDeviceImagesWithCompatibleState(
16831683
}
16841684
assert(BinImages.size() > 0 && "Expected to find at least one device image");
16851685

1686+
// Ignore images with incompatible state. Image is considered compatible
1687+
// with a target state if an image is already in the target state or can
1688+
// be brought to target state by compiling/linking/building.
1689+
//
1690+
// Example: an image in "executable" state is not compatible with
1691+
// "input" target state - there is no operation to convert the image
1692+
// to "input" state. An image in "input" state is compatible with
1693+
// "executable" target state because it can be built to get into
1694+
// "executable" state.
1695+
for (auto It = BinImages.begin(); It != BinImages.end();) {
1696+
if (getBinImageState(*It) > TargetState)
1697+
It = BinImages.erase(It);
1698+
else
1699+
++It;
1700+
}
1701+
16861702
std::vector<device_image_plain> SYCLDeviceImages;
1687-
for (RTDeviceBinaryImage *BinImage : BinImages) {
1688-
const bundle_state ImgState = getBinImageState(BinImage);
1689-
1690-
// Ignore images with incompatible state. Image is considered compatible
1691-
// with a target state if an image is already in the target state or can
1692-
// be brought to target state by compiling/linking/building.
1693-
//
1694-
// Example: an image in "executable" state is not compatible with
1695-
// "input" target state - there is no operation to convert the image it
1696-
// to "input" state. An image in "input" state is compatible with
1697-
// "executable" target state because it can be built to get into
1698-
// "executable" state.
1699-
if (ImgState > TargetState)
1700-
continue;
17011703

1702-
for (const sycl::device &Dev : Devs) {
1704+
// If a non-input state is requested, we can filter out some compatible
1705+
// images and return only those with the highest compatible state for each
1706+
// device-kernel pair. This map tracks how many kernel-device pairs need each
1707+
// image, so that any unneeded ones are skipped.
1708+
// TODO this has no effect if the requested state is input, consider having
1709+
// a separate branch for that case to avoid unnecessary tracking work.
1710+
struct DeviceBinaryImageInfo {
1711+
std::shared_ptr<std::vector<sycl::kernel_id>> KernelIDs;
1712+
bundle_state State = bundle_state::input;
1713+
int RequirementCounter = 0;
1714+
};
1715+
std::unordered_map<RTDeviceBinaryImage *, DeviceBinaryImageInfo> ImageInfoMap;
1716+
1717+
for (const sycl::device &Dev : Devs) {
1718+
// Track the highest image state for each requested kernel.
1719+
using StateImagesPairT =
1720+
std::pair<bundle_state, std::vector<RTDeviceBinaryImage *>>;
1721+
using KernelImageMapT =
1722+
std::map<kernel_id, StateImagesPairT, LessByNameComp>;
1723+
KernelImageMapT KernelImageMap;
1724+
if (!KernelIDs.empty())
1725+
for (const kernel_id &KernelID : KernelIDs)
1726+
KernelImageMap.insert({KernelID, {}});
1727+
1728+
for (RTDeviceBinaryImage *BinImage : BinImages) {
17031729
if (!compatibleWithDevice(BinImage, Dev) ||
17041730
!doesDevSupportDeviceRequirements(Dev, *BinImage))
17051731
continue;
17061732

1707-
std::shared_ptr<std::vector<sycl::kernel_id>> KernelIDs;
1708-
// Collect kernel names for the image
1709-
{
1710-
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
1711-
KernelIDs = m_BinImg2KernelIDs[BinImage];
1712-
// If the image does not contain any non-service kernels we can skip it.
1713-
if (!KernelIDs || KernelIDs->empty())
1714-
continue;
1733+
auto InsertRes = ImageInfoMap.insert({BinImage, {}});
1734+
DeviceBinaryImageInfo &ImgInfo = InsertRes.first->second;
1735+
if (InsertRes.second) {
1736+
ImgInfo.State = getBinImageState(BinImage);
1737+
// Collect kernel names for the image
1738+
{
1739+
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
1740+
ImgInfo.KernelIDs = m_BinImg2KernelIDs[BinImage];
1741+
}
17151742
}
1743+
const bundle_state ImgState = ImgInfo.State;
1744+
const std::shared_ptr<std::vector<sycl::kernel_id>> &ImageKernelIDs =
1745+
ImgInfo.KernelIDs;
1746+
int &ImgRequirementCounter = ImgInfo.RequirementCounter;
17161747

1717-
DeviceImageImplPtr Impl = std::make_shared<detail::device_image_impl>(
1718-
BinImage, Ctx, Devs, ImgState, KernelIDs, /*PIProgram=*/nullptr);
1748+
// If the image does not contain any non-service kernels we can skip it.
1749+
if (!ImageKernelIDs || ImageKernelIDs->empty())
1750+
continue;
17191751

1720-
SYCLDeviceImages.push_back(
1721-
createSyclObjFromImpl<device_image_plain>(Impl));
1722-
break;
1752+
// Update tracked information.
1753+
for (kernel_id &KernelID : *ImageKernelIDs) {
1754+
StateImagesPairT *StateImagesPair;
1755+
// If only specific kernels are requested, ignore the rest.
1756+
if (!KernelIDs.empty()) {
1757+
auto It = KernelImageMap.find(KernelID);
1758+
if (It == KernelImageMap.end())
1759+
continue;
1760+
StateImagesPair = &It->second;
1761+
} else
1762+
StateImagesPair = &KernelImageMap[KernelID];
1763+
1764+
auto &[KernelImagesState, KernelImages] = *StateImagesPair;
1765+
1766+
if (KernelImages.empty()) {
1767+
KernelImagesState = ImgState;
1768+
KernelImages.push_back(BinImage);
1769+
++ImgRequirementCounter;
1770+
} else if (KernelImagesState < ImgState) {
1771+
for (RTDeviceBinaryImage *Img : KernelImages) {
1772+
auto It = ImageInfoMap.find(Img);
1773+
assert(It != ImageInfoMap.end());
1774+
assert(It->second.RequirementCounter > 0);
1775+
--(It->second.RequirementCounter);
1776+
}
1777+
KernelImages.clear();
1778+
KernelImages.push_back(BinImage);
1779+
KernelImagesState = ImgState;
1780+
++ImgRequirementCounter;
1781+
} else if (KernelImagesState == ImgState) {
1782+
KernelImages.push_back(BinImage);
1783+
++ImgRequirementCounter;
1784+
}
1785+
}
17231786
}
17241787
}
17251788

1789+
for (const auto &ImgInfoPair : ImageInfoMap) {
1790+
if (ImgInfoPair.second.RequirementCounter == 0)
1791+
continue;
1792+
1793+
DeviceImageImplPtr Impl = std::make_shared<detail::device_image_impl>(
1794+
ImgInfoPair.first, Ctx, Devs, ImgInfoPair.second.State,
1795+
ImgInfoPair.second.KernelIDs, /*PIProgram=*/nullptr);
1796+
1797+
SYCLDeviceImages.push_back(createSyclObjFromImpl<device_image_plain>(Impl));
1798+
}
1799+
17261800
return SYCLDeviceImages;
17271801
}
17281802

sycl/unittests/SYCL2020/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ add_sycl_unittest(SYCL2020Tests OBJECT
44
GetNativeOpenCL.cpp
55
SpecializationConstant.cpp
66
KernelBundle.cpp
7+
KernelBundleStateFiltering.cpp
78
KernelID.cpp
89
HasExtension.cpp
910
IsCompatible.cpp
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
//==---- KernelBundleStateFiltering.cpp --- Kernel bundle unit test --------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include <detail/device_impl.hpp>
10+
#include <detail/kernel_bundle_impl.hpp>
11+
#include <sycl/sycl.hpp>
12+
13+
#include <helpers/MockKernelInfo.hpp>
14+
#include <helpers/PiImage.hpp>
15+
#include <helpers/PiMock.hpp>
16+
17+
#include <gtest/gtest.h>
18+
19+
#include <algorithm>
20+
#include <set>
21+
#include <vector>
22+
23+
class KernelA;
24+
class KernelB;
25+
class KernelC;
26+
class KernelD;
27+
class KernelE;
28+
// Mock KernelInfo specializations for the five test kernel types. Each one
// only supplies the kernel name returned by getName(); every other member is
// inherited from unittest::MockKernelInfoBase.
namespace sycl {
29+
__SYCL_INLINE_VER_NAMESPACE(_V1) {
30+
namespace detail {
31+
// KernelA: listed for both image 0 and image 1 in the image table below.
template <> struct KernelInfo<KernelA> : public unittest::MockKernelInfoBase {
32+
static constexpr const char *getName() { return "KernelA"; }
33+
};
34+
// KernelB: listed only for image 0.
template <> struct KernelInfo<KernelB> : public unittest::MockKernelInfoBase {
35+
static constexpr const char *getName() { return "KernelB"; }
36+
};
37+
// KernelC: listed for images 2 and 3.
template <> struct KernelInfo<KernelC> : public unittest::MockKernelInfoBase {
38+
static constexpr const char *getName() { return "KernelC"; }
39+
};
40+
// KernelD: listed only for image 4.
template <> struct KernelInfo<KernelD> : public unittest::MockKernelInfoBase {
41+
static constexpr const char *getName() { return "KernelD"; }
42+
};
43+
// KernelE: listed for images 5, 6 and 7.
template <> struct KernelInfo<KernelE> : public unittest::MockKernelInfoBase {
44+
static constexpr const char *getName() { return "KernelE"; }
45+
};
46+
} // namespace detail
47+
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
48+
} // namespace sycl
49+
50+
namespace {
51+
52+
std::set<const void *> TrackedImages;
53+
/// Creates a mock device image that exposes the given kernels and records the
/// address of its binary in TrackedImages.
///
/// The binary of every generated image is a single byte holding a running
/// creation counter, so the PI hooks in this file can recover the image index
/// by reading the first byte of a tracked binary.
///
/// \param KernelNames      names of the (empty) kernels the image contains.
/// \param BinaryType       PI binary format of the image.
/// \param DeviceTargetSpec device target spec string stored in the image.
/// \return the newly built image.
sycl::unittest::PiImage
generateDefaultImage(std::initializer_list<std::string> KernelNames,
                     pi_device_binary_type BinaryType,
                     const char *DeviceTargetSpec) {
  using namespace sycl::unittest;

  PiPropertySet NoProperties;

  // One-byte payload: the position of this image in creation order.
  static unsigned char NextImageIdx = 0;
  std::vector<unsigned char> Payload;
  Payload.push_back(NextImageIdx++);

  PiArray<PiOffloadEntry> KernelEntries = makeEmptyKernels(KernelNames);

  PiImage Result{BinaryType, // Format
                 DeviceTargetSpec,
                 "", // Compile options
                 "", // Link options
                 std::move(Payload),
                 std::move(KernelEntries),
                 std::move(NoProperties)};

  // Remember the binary so the hooks can tell our images from any others.
  TrackedImages.insert(Result.getBinaryPtr());

  return Result;
}
78+
79+
// Image 0: input, KernelA KernelB
80+
// Image 1: exe, KernelA
81+
// Image 2: input, KernelC
82+
// Image 3: exe, KernelC
83+
// Image 4: input, KernelD
84+
// Image 5: input, KernelE
85+
// Image 6: exe, KernelE
86+
// Image 7: exe, KernelE
87+
// Device images used by the test, listed in index order.
// generateDefaultImage stamps each binary with a running counter, so element
// i of this array is "Image i" in the table above; reordering the
// initializers changes the indices the test checks against.
sycl::unittest::PiImage Imgs[] = {
88+
generateDefaultImage({"KernelA", "KernelB"}, PI_DEVICE_BINARY_TYPE_SPIRV,
89+
__SYCL_PI_DEVICE_BINARY_TARGET_SPIRV64),
90+
generateDefaultImage({"KernelA"}, PI_DEVICE_BINARY_TYPE_NATIVE,
91+
__SYCL_PI_DEVICE_BINARY_TARGET_SPIRV64_X86_64),
92+
generateDefaultImage({"KernelC"}, PI_DEVICE_BINARY_TYPE_SPIRV,
93+
__SYCL_PI_DEVICE_BINARY_TARGET_SPIRV64),
94+
generateDefaultImage({"KernelC"}, PI_DEVICE_BINARY_TYPE_NATIVE,
95+
__SYCL_PI_DEVICE_BINARY_TARGET_SPIRV64_X86_64),
96+
generateDefaultImage({"KernelD"}, PI_DEVICE_BINARY_TYPE_SPIRV,
97+
__SYCL_PI_DEVICE_BINARY_TARGET_SPIRV64),
98+
generateDefaultImage({"KernelE"}, PI_DEVICE_BINARY_TYPE_SPIRV,
99+
__SYCL_PI_DEVICE_BINARY_TARGET_SPIRV64),
100+
generateDefaultImage({"KernelE"}, PI_DEVICE_BINARY_TYPE_NATIVE,
101+
__SYCL_PI_DEVICE_BINARY_TARGET_SPIRV64_X86_64),
102+
generateDefaultImage({"KernelE"}, PI_DEVICE_BINARY_TYPE_NATIVE,
103+
__SYCL_PI_DEVICE_BINARY_TARGET_SPIRV64_X86_64)};
104+
105+
// Wraps all of Imgs in a PiImageArray.
// NOTE(review): presumably this is what registers the images with the runtime
// for the lifetime of the test binary - confirm against helpers/PiImage.hpp.
sycl::unittest::PiImageArray<std::size(Imgs)> ImgArray{Imgs};
106+
// Indices of tracked images seen by the program-creation hooks since the last
// verifyImageUse() call, which sorts, checks and then clears this list.
std::vector<unsigned char> UsedImageIndices;
107+
108+
void redefinedPiProgramCreateCommon(const void *bin) {
109+
if (TrackedImages.count(bin) != 0) {
110+
unsigned char ImgIdx = *reinterpret_cast<const unsigned char *>(bin);
111+
UsedImageIndices.push_back(ImgIdx);
112+
}
113+
}
114+
115+
/// Hook run after piProgramCreate: records which tracked image (if any) the
/// IL pointer \p il belongs to. All other arguments are ignored and the call
/// always reports PI_SUCCESS.
pi_result redefinedPiProgramCreate(pi_context /*context*/, const void *il,
                                   size_t /*length*/,
                                   pi_program * /*res_program*/) {
  const void *ProgramIL = il;
  redefinedPiProgramCreateCommon(ProgramIL);
  return PI_SUCCESS;
}
120+
121+
/// Hook run after piProgramCreateWithBinary: records which tracked image (if
/// any) the first entry of \p binaries belongs to. Only binaries[0] is
/// inspected; every other argument is ignored and the call always reports
/// PI_SUCCESS.
pi_result redefinedPiProgramCreateWithBinary(
    pi_context /*context*/, pi_uint32 /*num_devices*/,
    const pi_device * /*device_list*/, const size_t * /*lengths*/,
    const unsigned char **binaries, size_t /*num_metadata_entries*/,
    const pi_device_binary_property * /*metadata*/,
    pi_int32 * /*binary_status*/, pi_program * /*ret_program*/) {
  const unsigned char *FirstBinary = binaries[0];
  redefinedPiProgramCreateCommon(FirstBinary);
  return PI_SUCCESS;
}
129+
130+
/// Replacement for piDevicesGet that exposes exactly two fake devices, whose
/// handles are the integer values 1 and 2 cast to pi_device.
///
/// \param platform     ignored.
/// \param device_type  ignored.
/// \param num_entries  capacity of \p devices, in handles.
/// \param devices      optional output buffer; receives at most
///                     min(num_entries, 2) handles.
/// \param num_devices  optional output; always set to 2 when non-null.
/// \return always PI_SUCCESS.
pi_result redefinedDevicesGet(pi_platform platform, pi_device_type device_type,
                              pi_uint32 num_entries, pi_device *devices,
                              pi_uint32 *num_devices) {
  constexpr pi_uint32 MockDeviceCount = 2;

  if (num_devices)
    *num_devices = MockDeviceCount;

  if (devices) {
    // Never write past the caller's buffer: it only holds num_entries
    // handles, which may be fewer than the number of mock devices.
    const pi_uint32 NumToWrite =
        num_entries < MockDeviceCount ? num_entries : MockDeviceCount;
    for (size_t Idx = 0; Idx < NumToWrite; ++Idx)
      devices[Idx] = reinterpret_cast<pi_device>(Idx + 1);
  }

  return PI_SUCCESS;
}
143+
144+
/// Replacement for piextDeviceSelectBinary used to make one image
/// device-specific: tracked image 3 is rejected for the device whose handle is
/// 2, and every other image/device combination is accepted.
///
/// Expects to be called with exactly one candidate binary and, on success,
/// selects it by writing index 0 to \p selected_binary_ind.
pi_result redefinedExtDeviceSelectBinary(pi_device device,
                                         pi_device_binary *binaries,
                                         pi_uint32 num_binaries,
                                         pi_uint32 *selected_binary_ind) {
  EXPECT_EQ(num_binaries, 1U);

  // The first byte is only read for tracked binaries, whose payload is known
  // to be the one-byte image index.
  const auto *ImgStart = binaries[0]->BinaryStart;
  const bool IsImage3 = TrackedImages.count(ImgStart) != 0 && *ImgStart == 3;
  const bool OnSecondDevice = device == reinterpret_cast<pi_device>(2);

  // Treat image 3 as incompatible with one of the devices.
  if (IsImage3 && OnSecondDevice)
    return PI_ERROR_INVALID_BINARY;

  *selected_binary_ind = 0;
  return PI_SUCCESS;
}
158+
159+
/// Checks that the set of tracked images handed to the program-creation hooks
/// since the previous call equals \p ExpectedImages, then resets the record.
///
/// \p ExpectedImages must be given in ascending order; the recorded indices
/// are sorted before comparison, so the order in which images were used does
/// not matter.
void verifyImageUse(const std::vector<unsigned char> &ExpectedImages) {
  // Guard against a mis-written expectation list.
  EXPECT_TRUE(
      std::is_sorted(std::begin(ExpectedImages), std::end(ExpectedImages)));

  std::sort(std::begin(UsedImageIndices), std::end(UsedImageIndices));
  EXPECT_EQ(UsedImageIndices, ExpectedImages);

  // Start the next scenario from a clean slate.
  UsedImageIndices.clear();
}
165+
166+
// Verifies that, when an executable kernel bundle is requested, an input
// (SPIR-V) image is only used for a kernel if no executable image provides
// that kernel for every requested device.
TEST(KernelBundle, DeviceImageStateFiltering) {
  using sycl::detail::PiApiKind;
  constexpr sycl::bundle_state Exe = sycl::bundle_state::executable;

  sycl::unittest::PiMock Mock;
  Mock.redefineAfter<PiApiKind::piProgramCreate>(redefinedPiProgramCreate);
  Mock.redefineAfter<PiApiKind::piProgramCreateWithBinary>(
      redefinedPiProgramCreateWithBinary);

  // Scenario 1: no kernel ids specified.
  // Image 0 is still required because it is the only image with KernelB;
  // input images 2 and 5 are superseded by executable images.
  {
    const sycl::device FirstDev = Mock.getPlatform().get_devices()[0];
    sycl::context SingleDevCtx{FirstDev};

    sycl::kernel_bundle<Exe> Bundle =
        sycl::get_kernel_bundle<Exe>(SingleDevCtx, {FirstDev});
    verifyImageUse({0, 1, 3, 4, 6, 7});
  }

  const std::vector<sycl::kernel_id> RequestedKernels{
      sycl::get_kernel_id<KernelA>(), sycl::get_kernel_id<KernelC>(),
      sycl::get_kernel_id<KernelD>()};

  // Scenario 2: request specific kernel ids.
  // KernelB is not requested, so image 0 is no longer needed.
  {
    const sycl::device FirstDev = Mock.getPlatform().get_devices()[0];
    sycl::context SingleDevCtx{FirstDev};

    sycl::kernel_bundle<Exe> Bundle = sycl::get_kernel_bundle<Exe>(
        SingleDevCtx, {FirstDev}, RequestedKernels);
    verifyImageUse({1, 3, 4});
  }

  // Scenario 3: some executable images are unsupported by one of the devices.
  // Image 3 is rejected for the second device, so input image 2 has to be
  // kept to provide KernelC there.
  {
    Mock.redefine<PiApiKind::piDevicesGet>(redefinedDevicesGet);
    Mock.redefine<PiApiKind::piextDeviceSelectBinary>(
        redefinedExtDeviceSelectBinary);
    const std::vector<sycl::device> AllDevs = Mock.getPlatform().get_devices();
    sycl::context MultiDevCtx{AllDevs};

    sycl::kernel_bundle<Exe> Bundle =
        sycl::get_kernel_bundle<Exe>(MultiDevCtx, AllDevs, RequestedKernels);
    verifyImageUse({1, 2, 3, 4});
  }
}
213+
} // namespace
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
// Guard against multiple inclusion: this header defines a struct, so pulling
// it in twice in one translation unit would otherwise be a redefinition.
#pragma once

#include <sycl/detail/kernel_desc.hpp>

#include <cstdint>

namespace sycl {
__SYCL_INLINE_VER_NAMESPACE(_V1) {
namespace unittest {
/// Base for mock KernelInfo specializations in unit tests.
///
/// Supplies fixed values for every member defined here, so a test kernel only
/// has to add what this struct does not provide (e.g. getName()).
struct MockKernelInfoBase {
  /// Mock kernels take no parameters.
  static constexpr unsigned getNumParams() { return 0; }
  /// Returns the same default-constructed dummy descriptor for every index.
  static const detail::kernel_param_desc_t &getParamDesc(int) {
    static detail::kernel_param_desc_t Dummy;
    return Dummy;
  }
  /// Mock kernels are never reported as ESIMD kernels.
  static constexpr bool isESIMD() { return false; }
  /// Always false for mock kernels.
  static constexpr bool callsThisItem() { return false; }
  /// Always false for mock kernels.
  static constexpr bool callsAnyThisFreeFunction() { return false; }
  /// Fixed placeholder kernel size of 1.
  static constexpr int64_t getKernelSize() { return 1; }
};

} // namespace unittest
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
} // namespace sycl

0 commit comments

Comments
 (0)