Skip to content

Commit 50e798f

Browse files
author
Andrew Lamzed-Short
authored
[SYCL] Implemented templated kernel_bundle::has_kernel functions (#8070)
Added missing template-based `has_kernel` functions in `kernel_bundle`. Tried to make implementation align with existing `kernel_bundle_plain`-based methods as possible.
1 parent e833806 commit 50e798f

File tree

2 files changed

+37
-1
lines changed

2 files changed

+37
-1
lines changed

sycl/include/sycl/kernel_bundle.hpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ namespace detail {
3636
class kernel_id_impl;
3737
}
3838

39+
template <typename KernelName> kernel_id get_kernel_id();
40+
3941
/// Objects of the class identify kernel is some kernel_bundle related APIs
4042
///
4143
/// \ingroup sycl_api
@@ -236,6 +238,19 @@ class kernel_bundle : public detail::kernel_bundle_plain,
236238
return kernel_bundle_plain::has_kernel(KernelID, Dev);
237239
}
238240

241+
/// \returns true only if the kernel bundle contains the kernel identified by
242+
/// KernelName.
243+
template <typename KernelName> bool has_kernel() const noexcept {
244+
return has_kernel(get_kernel_id<KernelName>());
245+
}
246+
247+
/// \returns true only if the kernel bundle contains the kernel identified by
248+
/// KernelName and if that kernel is compatible with the device Dev.
249+
template <typename KernelName>
250+
bool has_kernel(const device &Dev) const noexcept {
251+
return has_kernel(get_kernel_id<KernelName>(), Dev);
252+
}
253+
239254
/// \returns a vector of kernel_id's that contained in the kernel_bundle
240255
std::vector<kernel_id> get_kernel_ids() const {
241256
return kernel_bundle_plain::get_kernel_ids();
@@ -355,6 +370,8 @@ __SYCL_EXPORT kernel_id get_kernel_id_impl(std::string KernelName);
355370

356371
/// \returns the kernel_id associated with the KernelName
357372
template <typename KernelName> kernel_id get_kernel_id() {
373+
// FIXME: This must fail at link-time if KernelName not in any available
374+
// translation units.
358375
using KI = sycl::detail::KernelInfo<KernelName>;
359376
return detail::get_kernel_id_impl(KI::getName());
360377
}

sycl/unittests/SYCL2020/KernelID.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,26 @@ TEST(KernelID, KernelIDHasKernel) {
264264
EXPECT_TRUE(InputBundle7.has_kernel(TestKernel3ID));
265265
}
266266

267-
TEST(KernelID, InvalidKernelName) {
267+
TEST(KernelID, HasKernelTemplated) {
268+
sycl::unittest::PiMock Mock;
269+
sycl::platform Plt = Mock.getPlatform();
270+
271+
const sycl::device Dev = Plt.get_devices()[0];
272+
sycl::context Ctx{Dev};
273+
sycl::queue Queue{Ctx, Dev};
274+
275+
sycl::kernel_id TestKernel1ID = sycl::get_kernel_id<TestKernel1>();
276+
277+
std::vector<sycl::kernel_id> KernelIDs1 = {TestKernel1ID};
278+
auto InputBundle1 =
279+
sycl::get_kernel_bundle<sycl::bundle_state::input>(Ctx, KernelIDs1);
280+
281+
EXPECT_TRUE(InputBundle1.has_kernel<TestKernel1>());
282+
EXPECT_FALSE(InputBundle1.has_kernel<TestKernel2>());
283+
EXPECT_TRUE(InputBundle1.has_kernel<TestKernel3>());
284+
}
285+
286+
TEST(KernelID, GetKernelIDInvalidKernelName) {
268287
sycl::unittest::PiMock Mock;
269288
sycl::platform Plt = Mock.getPlatform();
270289

0 commit comments

Comments
 (0)