Skip to content

Commit db8fda2

Browse files
[SYCL][NFCI] More refactoring around "kernel wrapping" (#18015)
Follows what I started in #17838.
1 parent 085d07e commit db8fda2

File tree

1 file changed

+78
-194
lines changed

1 file changed

+78
-194
lines changed

sycl/include/sycl/handler.hpp

Lines changed: 78 additions & 194 deletions
Original file line numberDiff line numberDiff line change
@@ -1268,10 +1268,6 @@ class __SYCL_EXPORT handler {
12681268
typename PropertiesT>
12691269
void parallel_for_impl(nd_range<Dims> ExecutionRange, PropertiesT Props,
12701270
const KernelType &KernelFunc) {
1271-
// TODO: Properties may change the kernel function, so in order to avoid
1272-
// conflicts they should be included in the name.
1273-
using NameT =
1274-
typename detail::get_kernel_name_t<KernelName, KernelType>::name;
12751271
using LambdaArgType =
12761272
sycl::detail::lambda_arg_type<KernelType, nd_item<Dims>>;
12771273
static_assert(
@@ -1280,21 +1276,8 @@ class __SYCL_EXPORT handler {
12801276
"must be either sycl::nd_item or be convertible from sycl::nd_item");
12811277
using TransformedArgType = sycl::nd_item<Dims>;
12821278

1283-
(void)ExecutionRange;
1284-
(void)Props;
1285-
KernelWrapper<WrapAs::parallel_for, NameT, KernelType, TransformedArgType,
1286-
PropertiesT>::wrap(this, KernelFunc);
1287-
#ifndef __SYCL_DEVICE_ONLY__
1288-
throwIfActionIsCreated();
1289-
verifyUsedKernelBundleInternal(
1290-
detail::string_view{detail::getKernelName<NameT>()});
1291-
detail::checkValueRange<Dims>(ExecutionRange);
1292-
setNDRangeDescriptor(std::move(ExecutionRange));
1293-
processProperties<detail::isKernelESIMD<NameT>(), PropertiesT>(Props);
1294-
StoreLambda<NameT, KernelType, Dims, TransformedArgType>(
1295-
std::move(KernelFunc));
1296-
setType(detail::CGType::Kernel);
1297-
#endif
1279+
wrap_kernel<WrapAs::parallel_for, KernelName, TransformedArgType, Dims>(
1280+
KernelFunc, nullptr /*Kernel*/, Props, ExecutionRange);
12981281
}
12991282

13001283
/// Defines and invokes a SYCL kernel function for the specified range.
@@ -1362,26 +1345,12 @@ class __SYCL_EXPORT handler {
13621345
void parallel_for_work_group_lambda_impl(range<Dims> NumWorkGroups,
13631346
PropertiesT Props,
13641347
const KernelType &KernelFunc) {
1365-
// TODO: Properties may change the kernel function, so in order to avoid
1366-
// conflicts they should be included in the name.
1367-
using NameT =
1368-
typename detail::get_kernel_name_t<KernelName, KernelType>::name;
13691348
using LambdaArgType =
13701349
sycl::detail::lambda_arg_type<KernelType, group<Dims>>;
1371-
(void)NumWorkGroups;
1372-
(void)Props;
1373-
KernelWrapper<WrapAs::parallel_for_work_group, NameT, KernelType,
1374-
LambdaArgType, PropertiesT>::wrap(this, KernelFunc);
1375-
#ifndef __SYCL_DEVICE_ONLY__
1376-
throwIfActionIsCreated();
1377-
verifyUsedKernelBundleInternal(
1378-
detail::string_view{detail::getKernelName<NameT>()});
1379-
processProperties<detail::isKernelESIMD<NameT>(), PropertiesT>(Props);
1380-
detail::checkValueRange<Dims>(NumWorkGroups);
1381-
setNDRangeDescriptor(NumWorkGroups, /*SetNumWorkGroups=*/true);
1382-
StoreLambda<NameT, KernelType, Dims, LambdaArgType>(std::move(KernelFunc));
1383-
setType(detail::CGType::Kernel);
1384-
#endif // __SYCL_DEVICE_ONLY__
1350+
wrap_kernel<WrapAs::parallel_for_work_group, KernelName, LambdaArgType,
1351+
Dims,
1352+
/*SetNumWorkGroups=*/true>(KernelFunc, nullptr /*Kernel*/,
1353+
Props, NumWorkGroups);
13851354
}
13861355

13871356
/// Hierarchical kernel invocation method of a kernel defined as a lambda
@@ -1403,29 +1372,12 @@ class __SYCL_EXPORT handler {
14031372
range<Dims> WorkGroupSize,
14041373
PropertiesT Props,
14051374
const KernelType &KernelFunc) {
1406-
// TODO: Properties may change the kernel function, so in order to avoid
1407-
// conflicts they should be included in the name.
1408-
using NameT =
1409-
typename detail::get_kernel_name_t<KernelName, KernelType>::name;
14101375
using LambdaArgType =
14111376
sycl::detail::lambda_arg_type<KernelType, group<Dims>>;
1412-
(void)NumWorkGroups;
1413-
(void)WorkGroupSize;
1414-
(void)Props;
1415-
KernelWrapper<WrapAs::parallel_for_work_group, NameT, KernelType,
1416-
LambdaArgType, PropertiesT>::wrap(this, KernelFunc);
1417-
#ifndef __SYCL_DEVICE_ONLY__
1418-
throwIfActionIsCreated();
1419-
verifyUsedKernelBundleInternal(
1420-
detail::string_view{detail::getKernelName<NameT>()});
1421-
processProperties<detail::isKernelESIMD<NameT>(), PropertiesT>(Props);
14221377
nd_range<Dims> ExecRange =
14231378
nd_range<Dims>(NumWorkGroups * WorkGroupSize, WorkGroupSize);
1424-
detail::checkValueRange<Dims>(ExecRange);
1425-
setNDRangeDescriptor(std::move(ExecRange));
1426-
StoreLambda<NameT, KernelType, Dims, LambdaArgType>(std::move(KernelFunc));
1427-
setType(detail::CGType::Kernel);
1428-
#endif // __SYCL_DEVICE_ONLY__
1379+
wrap_kernel<WrapAs::parallel_for_work_group, KernelName, LambdaArgType,
1380+
Dims>(KernelFunc, nullptr /*Kernel*/, Props, ExecRange);
14291381
}
14301382

14311383
#ifdef SYCL_LANGUAGE_VERSION
@@ -1637,6 +1589,59 @@ class __SYCL_EXPORT handler {
16371589
}
16381590
};
16391591

1592+
template <
1593+
WrapAs WrapAsVal, typename KernelName, typename ElementType = void,
1594+
int Dims = 1, bool SetNumWorkGroups = false,
1595+
typename PropertiesT = ext::oneapi::experimental::empty_properties_t,
1596+
typename KernelType, typename MaybeKernelTy, typename... RangeParams>
1597+
void wrap_kernel(const KernelType &KernelFunc, MaybeKernelTy &&MaybeKernel,
1598+
const PropertiesT &Props,
1599+
[[maybe_unused]] RangeParams &&...params) {
1600+
// TODO: Properties may change the kernel function, so in order to avoid
1601+
// conflicts they should be included in the name.
1602+
using NameT =
1603+
typename detail::get_kernel_name_t<KernelName, KernelType>::name;
1604+
(void)Props;
1605+
(void)MaybeKernel;
1606+
static_assert(std::is_same_v<MaybeKernelTy, kernel> ||
1607+
std::is_same_v<MaybeKernelTy, std::nullptr_t>);
1608+
KernelWrapper<WrapAsVal, NameT, KernelType, ElementType, PropertiesT>::wrap(
1609+
this, KernelFunc);
1610+
#ifndef __SYCL_DEVICE_ONLY__
1611+
throwIfActionIsCreated();
1612+
if constexpr (std::is_same_v<MaybeKernelTy, kernel>) {
1613+
// Ignore any set kernel bundles and use the one associated with the
1614+
// kernel.
1615+
setHandlerKernelBundle(MaybeKernel);
1616+
}
1617+
verifyUsedKernelBundleInternal(
1618+
detail::string_view{detail::getKernelName<NameT>()});
1619+
setType(detail::CGType::Kernel);
1620+
1621+
detail::checkValueRange<Dims>(params...);
1622+
if constexpr (SetNumWorkGroups) {
1623+
setNDRangeDescriptor(std::move(params)...,
1624+
/*SetNumWorkGroups=*/true);
1625+
} else {
1626+
setNDRangeDescriptor(std::move(params)...);
1627+
}
1628+
1629+
if constexpr (std::is_same_v<MaybeKernelTy, std::nullptr_t>) {
1630+
StoreLambda<NameT, KernelType, Dims, ElementType>(std::move(KernelFunc));
1631+
} else {
1632+
MKernel = detail::getSyclObjImpl(std::move(MaybeKernel));
1633+
if (!lambdaAndKernelHaveEqualName<NameT>()) {
1634+
extractArgsAndReqs();
1635+
MKernelName = getKernelName();
1636+
} else {
1637+
StoreLambda<NameT, KernelType, Dims, ElementType>(
1638+
std::move(KernelFunc));
1639+
}
1640+
}
1641+
processProperties<detail::isKernelESIMD<NameT>(), PropertiesT>(Props);
1642+
#endif
1643+
}
1644+
16401645
// NOTE: to support kernel_handler argument in kernel lambdas, only
16411646
// KernelWrapper<...>::wrap() must be called in this code.
16421647

@@ -1652,25 +1657,10 @@ class __SYCL_EXPORT handler {
16521657
typename PropertiesT = ext::oneapi::experimental::empty_properties_t>
16531658
void single_task_lambda_impl(PropertiesT Props,
16541659
const KernelType &KernelFunc) {
1655-
(void)Props;
1656-
// TODO: Properties may change the kernel function, so in order to avoid
1657-
// conflicts they should be included in the name.
1658-
using NameT =
1659-
typename detail::get_kernel_name_t<KernelName, KernelType>::name;
1660-
1661-
KernelWrapper<WrapAs::single_task, NameT, KernelType, void,
1662-
PropertiesT>::wrap(this, KernelFunc);
1660+
wrap_kernel<WrapAs::single_task, KernelName>(KernelFunc, nullptr /*Kernel*/,
1661+
Props, range<1>{1});
16631662
#ifndef __SYCL_DEVICE_ONLY__
1664-
throwIfActionIsCreated();
16651663
throwOnKernelParameterMisuse<KernelName, KernelType>();
1666-
verifyUsedKernelBundleInternal(
1667-
detail::string_view{detail::getKernelName<NameT>()});
1668-
// No need to check if range is out of INT_MAX limits as it's compile-time
1669-
// known constant.
1670-
setNDRangeDescriptor(range<1>{1});
1671-
processProperties<detail::isKernelESIMD<NameT>(), PropertiesT>(Props);
1672-
StoreLambda<NameT, KernelType, /*Dims*/ 1, void>(KernelFunc);
1673-
setType(detail::CGType::Kernel);
16741664
#endif
16751665
}
16761666

@@ -1954,26 +1944,13 @@ class __SYCL_EXPORT handler {
19541944
__SYCL2020_DEPRECATED("offsets are deprecated in SYCL2020")
19551945
void parallel_for(range<Dims> NumWorkItems, id<Dims> WorkItemOffset,
19561946
const KernelType &KernelFunc) {
1957-
using NameT =
1958-
typename detail::get_kernel_name_t<KernelName, KernelType>::name;
19591947
using LambdaArgType = sycl::detail::lambda_arg_type<KernelType, item<Dims>>;
19601948
using TransformedArgType = std::conditional_t<
19611949
std::is_integral<LambdaArgType>::value && Dims == 1, item<Dims>,
19621950
typename TransformUserItemType<Dims, LambdaArgType>::type>;
1963-
(void)NumWorkItems;
1964-
(void)WorkItemOffset;
1965-
KernelWrapper<WrapAs::parallel_for, NameT, KernelType,
1966-
TransformedArgType>::wrap(this, KernelFunc);
1967-
#ifndef __SYCL_DEVICE_ONLY__
1968-
throwIfActionIsCreated();
1969-
verifyUsedKernelBundleInternal(
1970-
detail::string_view{detail::getKernelName<NameT>()});
1971-
detail::checkValueRange<Dims>(NumWorkItems, WorkItemOffset);
1972-
setNDRangeDescriptor(std::move(NumWorkItems), std::move(WorkItemOffset));
1973-
StoreLambda<NameT, KernelType, Dims, TransformedArgType>(
1974-
std::move(KernelFunc));
1975-
setType(detail::CGType::Kernel);
1976-
#endif
1951+
wrap_kernel<WrapAs::parallel_for, KernelName, TransformedArgType, Dims>(
1952+
KernelFunc, nullptr /*Kernel*/, {} /*Props*/, NumWorkItems,
1953+
WorkItemOffset);
19771954
}
19781955

19791956
/// Hierarchical kernel invocation method of a kernel defined as a lambda
@@ -2134,28 +2111,9 @@ class __SYCL_EXPORT handler {
21342111
const KernelType &KernelFunc) {
21352112
// Ignore any set kernel bundles and use the one associated with the kernel
21362113
setHandlerKernelBundle(Kernel);
2137-
using NameT =
2138-
typename detail::get_kernel_name_t<KernelName, KernelType>::name;
21392114
using LambdaArgType = sycl::detail::lambda_arg_type<KernelType, item<Dims>>;
2140-
(void)Kernel;
2141-
(void)NumWorkItems;
2142-
KernelWrapper<WrapAs::parallel_for, NameT, KernelType, LambdaArgType>::wrap(
2143-
this, KernelFunc);
2144-
#ifndef __SYCL_DEVICE_ONLY__
2145-
throwIfActionIsCreated();
2146-
verifyUsedKernelBundleInternal(
2147-
detail::string_view{detail::getKernelName<NameT>()});
2148-
detail::checkValueRange<Dims>(NumWorkItems);
2149-
setNDRangeDescriptor(std::move(NumWorkItems));
2150-
MKernel = detail::getSyclObjImpl(std::move(Kernel));
2151-
setType(detail::CGType::Kernel);
2152-
if (!lambdaAndKernelHaveEqualName<NameT>()) {
2153-
extractArgsAndReqs();
2154-
MKernelName = getKernelName();
2155-
} else
2156-
StoreLambda<NameT, KernelType, Dims, LambdaArgType>(
2157-
std::move(KernelFunc));
2158-
#endif
2115+
wrap_kernel<WrapAs::parallel_for, KernelName, LambdaArgType, Dims>(
2116+
KernelFunc, Kernel, {} /*Props*/, NumWorkItems);
21592117
}
21602118

21612119
/// Defines and invokes a SYCL kernel function for the specified range and
@@ -2172,31 +2130,9 @@ class __SYCL_EXPORT handler {
21722130
__SYCL2020_DEPRECATED("offsets are deprecated in SYCL 2020")
21732131
void parallel_for(kernel Kernel, range<Dims> NumWorkItems,
21742132
id<Dims> WorkItemOffset, const KernelType &KernelFunc) {
2175-
using NameT =
2176-
typename detail::get_kernel_name_t<KernelName, KernelType>::name;
21772133
using LambdaArgType = sycl::detail::lambda_arg_type<KernelType, item<Dims>>;
2178-
(void)Kernel;
2179-
(void)NumWorkItems;
2180-
(void)WorkItemOffset;
2181-
KernelWrapper<WrapAs::parallel_for, NameT, KernelType, LambdaArgType>::wrap(
2182-
this, KernelFunc);
2183-
#ifndef __SYCL_DEVICE_ONLY__
2184-
throwIfActionIsCreated();
2185-
// Ignore any set kernel bundles and use the one associated with the kernel
2186-
setHandlerKernelBundle(Kernel);
2187-
verifyUsedKernelBundleInternal(
2188-
detail::string_view{detail::getKernelName<NameT>()});
2189-
detail::checkValueRange<Dims>(NumWorkItems, WorkItemOffset);
2190-
setNDRangeDescriptor(std::move(NumWorkItems), std::move(WorkItemOffset));
2191-
MKernel = detail::getSyclObjImpl(std::move(Kernel));
2192-
setType(detail::CGType::Kernel);
2193-
if (!lambdaAndKernelHaveEqualName<NameT>()) {
2194-
extractArgsAndReqs();
2195-
MKernelName = getKernelName();
2196-
} else
2197-
StoreLambda<NameT, KernelType, Dims, LambdaArgType>(
2198-
std::move(KernelFunc));
2199-
#endif
2134+
wrap_kernel<WrapAs::parallel_for, KernelName, LambdaArgType, Dims>(
2135+
KernelFunc, Kernel, {} /*Props*/, NumWorkItems, WorkItemOffset);
22002136
}
22012137

22022138
/// Defines and invokes a SYCL kernel function for the specified range and
@@ -2212,31 +2148,10 @@ class __SYCL_EXPORT handler {
22122148
int Dims>
22132149
void parallel_for(kernel Kernel, nd_range<Dims> NDRange,
22142150
const KernelType &KernelFunc) {
2215-
using NameT =
2216-
typename detail::get_kernel_name_t<KernelName, KernelType>::name;
22172151
using LambdaArgType =
22182152
sycl::detail::lambda_arg_type<KernelType, nd_item<Dims>>;
2219-
(void)Kernel;
2220-
(void)NDRange;
2221-
KernelWrapper<WrapAs::parallel_for, NameT, KernelType, LambdaArgType>::wrap(
2222-
this, KernelFunc);
2223-
#ifndef __SYCL_DEVICE_ONLY__
2224-
throwIfActionIsCreated();
2225-
// Ignore any set kernel bundles and use the one associated with the kernel
2226-
setHandlerKernelBundle(Kernel);
2227-
verifyUsedKernelBundleInternal(
2228-
detail::string_view{detail::getKernelName<NameT>()});
2229-
detail::checkValueRange<Dims>(NDRange);
2230-
setNDRangeDescriptor(std::move(NDRange));
2231-
MKernel = detail::getSyclObjImpl(std::move(Kernel));
2232-
setType(detail::CGType::Kernel);
2233-
if (!lambdaAndKernelHaveEqualName<NameT>()) {
2234-
extractArgsAndReqs();
2235-
MKernelName = getKernelName();
2236-
} else
2237-
StoreLambda<NameT, KernelType, Dims, LambdaArgType>(
2238-
std::move(KernelFunc));
2239-
#endif
2153+
wrap_kernel<WrapAs::parallel_for, KernelName, LambdaArgType, Dims>(
2154+
KernelFunc, Kernel, {} /*Props*/, NDRange);
22402155
}
22412156

22422157
/// Hierarchical kernel invocation method of a kernel.
@@ -2256,26 +2171,12 @@ class __SYCL_EXPORT handler {
22562171
int Dims>
22572172
void parallel_for_work_group(kernel Kernel, range<Dims> NumWorkGroups,
22582173
const KernelType &KernelFunc) {
2259-
using NameT =
2260-
typename detail::get_kernel_name_t<KernelName, KernelType>::name;
22612174
using LambdaArgType =
22622175
sycl::detail::lambda_arg_type<KernelType, group<Dims>>;
2263-
(void)Kernel;
2264-
(void)NumWorkGroups;
2265-
KernelWrapper<WrapAs::parallel_for_work_group, NameT, KernelType,
2266-
LambdaArgType>::wrap(this, KernelFunc);
2267-
#ifndef __SYCL_DEVICE_ONLY__
2268-
throwIfActionIsCreated();
2269-
// Ignore any set kernel bundles and use the one associated with the kernel
2270-
setHandlerKernelBundle(Kernel);
2271-
verifyUsedKernelBundleInternal(
2272-
detail::string_view{detail::getKernelName<NameT>()});
2273-
detail::checkValueRange<Dims>(NumWorkGroups);
2274-
setNDRangeDescriptor(NumWorkGroups, /*SetNumWorkGroups=*/true);
2275-
MKernel = detail::getSyclObjImpl(std::move(Kernel));
2276-
StoreLambda<NameT, KernelType, Dims, LambdaArgType>(std::move(KernelFunc));
2277-
setType(detail::CGType::Kernel);
2278-
#endif // __SYCL_DEVICE_ONLY__
2176+
wrap_kernel<WrapAs::parallel_for_work_group, KernelName, LambdaArgType,
2177+
Dims,
2178+
/*SetNumWorkGroups*/ true>(KernelFunc, Kernel, {} /*Props*/,
2179+
NumWorkGroups);
22792180
}
22802181

22812182
/// Hierarchical kernel invocation method of a kernel.
@@ -2298,29 +2199,12 @@ class __SYCL_EXPORT handler {
22982199
void parallel_for_work_group(kernel Kernel, range<Dims> NumWorkGroups,
22992200
range<Dims> WorkGroupSize,
23002201
const KernelType &KernelFunc) {
2301-
using NameT =
2302-
typename detail::get_kernel_name_t<KernelName, KernelType>::name;
23032202
using LambdaArgType =
23042203
sycl::detail::lambda_arg_type<KernelType, group<Dims>>;
2305-
(void)Kernel;
2306-
(void)NumWorkGroups;
2307-
(void)WorkGroupSize;
2308-
KernelWrapper<WrapAs::parallel_for_work_group, NameT, KernelType,
2309-
LambdaArgType>::wrap(this, KernelFunc);
2310-
#ifndef __SYCL_DEVICE_ONLY__
2311-
throwIfActionIsCreated();
2312-
// Ignore any set kernel bundles and use the one associated with the kernel
2313-
setHandlerKernelBundle(Kernel);
2314-
verifyUsedKernelBundleInternal(
2315-
detail::string_view{detail::getKernelName<NameT>()});
23162204
nd_range<Dims> ExecRange =
23172205
nd_range<Dims>(NumWorkGroups * WorkGroupSize, WorkGroupSize);
2318-
detail::checkValueRange<Dims>(ExecRange);
2319-
setNDRangeDescriptor(std::move(ExecRange));
2320-
MKernel = detail::getSyclObjImpl(std::move(Kernel));
2321-
StoreLambda<NameT, KernelType, Dims, LambdaArgType>(std::move(KernelFunc));
2322-
setType(detail::CGType::Kernel);
2323-
#endif // __SYCL_DEVICE_ONLY__
2206+
wrap_kernel<WrapAs::parallel_for_work_group, KernelName, LambdaArgType,
2207+
Dims>(KernelFunc, Kernel, {} /*Props*/, ExecRange);
23242208
}
23252209

23262210
template <typename KernelName = detail::auto_name, typename KernelType,

0 commit comments

Comments
 (0)