Skip to content

Commit 5d1dedd

Browse files
[NFCI][SYCL] Refactor handler::unpack (#17838)
1 parent 7d565e7 commit 5d1dedd

File tree

1 file changed

+88
-129
lines changed

1 file changed

+88
-129
lines changed

sycl/include/sycl/handler.hpp

Lines changed: 88 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -1208,8 +1208,8 @@ class __SYCL_EXPORT handler {
12081208
using KName = std::conditional_t<std::is_same<KernelType, NameT>::value,
12091209
decltype(Wrapper), NameWT>;
12101210

1211-
kernel_parallel_for_wrapper<KName, TransformedArgType, decltype(Wrapper),
1212-
PropertiesT>(Wrapper);
1211+
KernelWrapper<WrapAs::parallel_for, KName, decltype(Wrapper),
1212+
TransformedArgType, PropertiesT>::wrap(this, Wrapper);
12131213
#ifndef __SYCL_DEVICE_ONLY__
12141214
verifyUsedKernelBundleInternal(
12151215
detail::string_view{detail::getKernelName<NameT>()});
@@ -1234,8 +1234,8 @@ class __SYCL_EXPORT handler {
12341234
#ifndef __SYCL_FORCE_PARALLEL_FOR_RANGE_ROUNDING__
12351235
// If parallel_for range rounding is forced then only range rounded
12361236
// kernel is generated
1237-
kernel_parallel_for_wrapper<NameT, TransformedArgType, KernelType,
1238-
PropertiesT>(KernelFunc);
1237+
KernelWrapper<WrapAs::parallel_for, NameT, KernelType, TransformedArgType,
1238+
PropertiesT>::wrap(this, KernelFunc);
12391239
#ifndef __SYCL_DEVICE_ONLY__
12401240
verifyUsedKernelBundleInternal(
12411241
detail::string_view{detail::getKernelName<NameT>()});
@@ -1283,8 +1283,8 @@ class __SYCL_EXPORT handler {
12831283

12841284
(void)ExecutionRange;
12851285
(void)Props;
1286-
kernel_parallel_for_wrapper<NameT, TransformedArgType, KernelType,
1287-
PropertiesT>(KernelFunc);
1286+
KernelWrapper<WrapAs::parallel_for, NameT, KernelType, TransformedArgType,
1287+
PropertiesT>::wrap(this, KernelFunc);
12881288
#ifndef __SYCL_DEVICE_ONLY__
12891289
throwIfActionIsCreated();
12901290
verifyUsedKernelBundleInternal(
@@ -1371,8 +1371,8 @@ class __SYCL_EXPORT handler {
13711371
sycl::detail::lambda_arg_type<KernelType, group<Dims>>;
13721372
(void)NumWorkGroups;
13731373
(void)Props;
1374-
kernel_parallel_for_work_group_wrapper<NameT, LambdaArgType, KernelType,
1375-
PropertiesT>(KernelFunc);
1374+
KernelWrapper<WrapAs::parallel_for_work_group, NameT, KernelType,
1375+
LambdaArgType, PropertiesT>::wrap(this, KernelFunc);
13761376
#ifndef __SYCL_DEVICE_ONLY__
13771377
throwIfActionIsCreated();
13781378
verifyUsedKernelBundleInternal(
@@ -1413,8 +1413,8 @@ class __SYCL_EXPORT handler {
14131413
(void)NumWorkGroups;
14141414
(void)WorkGroupSize;
14151415
(void)Props;
1416-
kernel_parallel_for_work_group_wrapper<NameT, LambdaArgType, KernelType,
1417-
PropertiesT>(KernelFunc);
1416+
KernelWrapper<WrapAs::parallel_for_work_group, NameT, KernelType,
1417+
LambdaArgType, PropertiesT>::wrap(this, KernelFunc);
14181418
#ifndef __SYCL_DEVICE_ONLY__
14191419
throwIfActionIsCreated();
14201420
verifyUsedKernelBundleInternal(
@@ -1554,127 +1554,79 @@ class __SYCL_EXPORT handler {
15541554
#endif
15551555
}
15561556

1557-
template <typename... Props> struct KernelPropertiesUnpackerImpl {
1558-
// Just pass extra Props... as template parameters to the underlying
1559-
// Caller->* member functions. Don't have reflection so try to use
1560-
// templates as much as possible to reduce the amount of boilerplate code
1561-
// needed. All the type checks are expected to be done at the Caller's
1562-
// methods side.
1563-
1564-
template <typename... TypesToForward, typename... ArgsTy>
1565-
static void kernel_single_task_unpack(handler *h, ArgsTy &&...Args) {
1566-
h->kernel_single_task<TypesToForward..., Props...>(
1567-
std::forward<ArgsTy>(Args)...);
1568-
}
1569-
1570-
template <typename... TypesToForward, typename... ArgsTy>
1571-
static void kernel_parallel_for_unpack(handler *h, ArgsTy &&...Args) {
1572-
h->kernel_parallel_for<TypesToForward..., Props...>(
1573-
std::forward<ArgsTy>(Args)...);
1574-
}
1575-
1576-
template <typename... TypesToForward, typename... ArgsTy>
1577-
static void kernel_parallel_for_work_group_unpack(handler *h,
1578-
ArgsTy &&...Args) {
1579-
h->kernel_parallel_for_work_group<TypesToForward..., Props...>(
1580-
std::forward<ArgsTy>(Args)...);
1581-
}
1582-
};
1583-
1584-
template <typename PropertiesT>
1585-
struct KernelPropertiesUnpacker : public KernelPropertiesUnpackerImpl<> {
1586-
// This should always fail outside the specialization below but must be
1587-
// dependent to avoid failing even if not instantiated.
1588-
static_assert(
1589-
ext::oneapi::experimental::is_property_list<PropertiesT>::value,
1590-
"Template type is not a property list.");
1591-
};
1592-
1593-
template <typename... Props>
1594-
struct KernelPropertiesUnpacker<
1595-
ext::oneapi::experimental::detail::properties_t<Props...>>
1596-
: public KernelPropertiesUnpackerImpl<Props...> {};
1597-
1598-
// Helper function to
1599-
//
1600-
// * Make use of the KernelPropertiesUnpacker above
1601-
// * Decide if we need an extra kernel_handler parameter
1557+
// The KernelWrapper below has two purposes.
16021558
//
1603-
// The interface uses a \p Lambda callback to propagate that information back
1604-
// to the caller as we need the caller to communicate:
1559+
// First, from SYCL 2020, Table 129 (Member functions of the `handler` class)
1560+
// > The callable ... can optionally take a `kernel_handler` ... in
1561+
// > which case the SYCL runtime will construct an instance of
1562+
// > `kernel_handler` and pass it to the callable.
16051563
//
1606-
// * Name of the method to call
1607-
// * Provide explicit template type parameters for the call
1564+
// Note: "..." due to slight wording variability between
1565+
// single_task/parallel_for (e.g. only parameter vs last). This helper class
1566+
// calls `kernel_*` entry points (both hardcoded names known to FE and special
1567+
// device-specific entry point attributes) with proper arguments (with/without
1568+
// `kernel_handler` argument, depending on the signature of the SYCL kernel
1569+
// function).
16081570
//
1609-
// Couldn't think of a better way to achieve both.
1610-
template <typename KernelName, typename KernelType, typename PropertiesT,
1611-
bool HasKernelHandlerArg, typename FuncTy>
1612-
void unpack(const KernelType &KernelFunc, FuncTy Lambda) {
1613-
#ifdef __SYCL_DEVICE_ONLY__
1614-
detail::CheckDeviceCopyable<KernelType>();
1615-
#endif // __SYCL_DEVICE_ONLY__
1616-
using MergedPropertiesT =
1617-
typename detail::GetMergedKernelProperties<KernelType,
1618-
PropertiesT>::type;
1619-
using Unpacker = KernelPropertiesUnpacker<MergedPropertiesT>;
1620-
#ifndef __SYCL_DEVICE_ONLY__
1621-
// If there are properties provided by get method then process them.
1622-
if constexpr (ext::oneapi::experimental::detail::
1623-
HasKernelPropertiesGetMethod<const KernelType &>::value) {
1624-
processProperties<detail::isKernelESIMD<KernelName>()>(
1625-
KernelFunc.get(ext::oneapi::experimental::properties_tag{}));
1626-
}
1627-
#endif
1628-
if constexpr (HasKernelHandlerArg) {
1629-
kernel_handler KH;
1630-
Lambda(Unpacker{}, this, KernelFunc, KH);
1631-
} else {
1632-
Lambda(Unpacker{}, this, KernelFunc);
1633-
}
1634-
}
1571+
// Second, it performs a few checks and some properties processing (including
1572+
// the one provided via `sycl_ext_oneapi_kernel_properties` extension by
1573+
// embedding them into the kernel's type).
16351574

1636-
// NOTE: to support kernel_handler argument in kernel lambdas, only
1637-
// kernel_***_wrapper functions must be called in this code
1575+
enum class WrapAs { single_task, parallel_for, parallel_for_work_group };
16381576

16391577
template <
1640-
typename KernelName, typename KernelType,
1641-
typename PropertiesT = ext::oneapi::experimental::empty_properties_t>
1642-
void kernel_single_task_wrapper(const KernelType &KernelFunc) {
1643-
unpack<KernelName, KernelType, PropertiesT,
1644-
detail::KernelLambdaHasKernelHandlerArgT<KernelType>::value>(
1645-
KernelFunc, [&](auto Unpacker, auto &&...args) {
1646-
Unpacker.template kernel_single_task_unpack<KernelName, KernelType>(
1578+
WrapAs WrapAsVal, typename KernelName, typename KernelType,
1579+
typename ElementType,
1580+
typename PropertiesT = ext::oneapi::experimental::empty_properties_t,
1581+
typename MergedPropertiesT = typename detail::GetMergedKernelProperties<
1582+
KernelType, PropertiesT>::type>
1583+
struct KernelWrapper;
1584+
template <WrapAs WrapAsVal, typename KernelName, typename KernelType,
1585+
typename ElementType, typename PropertiesT, typename... MergedProps>
1586+
struct KernelWrapper<
1587+
WrapAsVal, KernelName, KernelType, ElementType, PropertiesT,
1588+
ext::oneapi::experimental::detail::properties_t<MergedProps...>> {
1589+
static void wrap(handler *h, const KernelType &KernelFunc) {
1590+
#ifdef __SYCL_DEVICE_ONLY__
1591+
detail::CheckDeviceCopyable<KernelType>();
1592+
#else
1593+
// If there are properties provided by get method then process them.
1594+
if constexpr (ext::oneapi::experimental::detail::
1595+
HasKernelPropertiesGetMethod<
1596+
const KernelType &>::value) {
1597+
h->processProperties<detail::isKernelESIMD<KernelName>()>(
1598+
KernelFunc.get(ext::oneapi::experimental::properties_tag{}));
1599+
}
1600+
#endif
1601+
auto L = [&](auto &&...args) {
1602+
if constexpr (WrapAsVal == WrapAs::single_task) {
1603+
h->kernel_single_task<KernelName, KernelType, MergedProps...>(
16471604
std::forward<decltype(args)>(args)...);
1648-
});
1649-
}
1650-
1651-
template <
1652-
typename KernelName, typename ElementType, typename KernelType,
1653-
typename PropertiesT = ext::oneapi::experimental::empty_properties_t>
1654-
void kernel_parallel_for_wrapper(const KernelType &KernelFunc) {
1655-
unpack<KernelName, KernelType, PropertiesT,
1656-
detail::KernelLambdaHasKernelHandlerArgT<KernelType,
1657-
ElementType>::value>(
1658-
KernelFunc, [&](auto Unpacker, auto &&...args) {
1659-
Unpacker.template kernel_parallel_for_unpack<KernelName, ElementType,
1660-
KernelType>(
1605+
} else if constexpr (WrapAsVal == WrapAs::parallel_for) {
1606+
h->kernel_parallel_for<KernelName, ElementType, KernelType,
1607+
MergedProps...>(
16611608
std::forward<decltype(args)>(args)...);
1662-
});
1663-
}
1664-
1665-
template <
1666-
typename KernelName, typename ElementType, typename KernelType,
1667-
typename PropertiesT = ext::oneapi::experimental::empty_properties_t>
1668-
void kernel_parallel_for_work_group_wrapper(const KernelType &KernelFunc) {
1669-
unpack<KernelName, KernelType, PropertiesT,
1670-
detail::KernelLambdaHasKernelHandlerArgT<KernelType,
1671-
ElementType>::value>(
1672-
KernelFunc, [&](auto Unpacker, auto &&...args) {
1673-
Unpacker.template kernel_parallel_for_work_group_unpack<
1674-
KernelName, ElementType, KernelType>(
1609+
} else if constexpr (WrapAsVal == WrapAs::parallel_for_work_group) {
1610+
h->kernel_parallel_for_work_group<KernelName, ElementType, KernelType,
1611+
MergedProps...>(
16751612
std::forward<decltype(args)>(args)...);
1676-
});
1677-
}
1613+
} else {
1614+
// Always false, but template-dependent.
1615+
static_assert(WrapAsVal != WrapAsVal, "Unexpected WrapAsVal");
1616+
}
1617+
};
1618+
if constexpr (detail::KernelLambdaHasKernelHandlerArgT<
1619+
KernelType, ElementType>::value) {
1620+
kernel_handler KH;
1621+
L(KernelFunc, KH);
1622+
} else {
1623+
L(KernelFunc);
1624+
}
1625+
}
1626+
};
1627+
1628+
// NOTE: to support kernel_handler argument in kernel lambdas, only
1629+
// KernelWrapper<...>::wrap() must be called in this code.
16781630

16791631
/// Defines and invokes a SYCL kernel function as a function object type.
16801632
///
@@ -1694,7 +1646,8 @@ class __SYCL_EXPORT handler {
16941646
using NameT =
16951647
typename detail::get_kernel_name_t<KernelName, KernelType>::name;
16961648

1697-
kernel_single_task_wrapper<NameT, KernelType, PropertiesT>(KernelFunc);
1649+
KernelWrapper<WrapAs::single_task, NameT, KernelType, void,
1650+
PropertiesT>::wrap(this, KernelFunc);
16981651
#ifndef __SYCL_DEVICE_ONLY__
16991652
throwIfActionIsCreated();
17001653
throwOnKernelParameterMisuse<KernelName, KernelType>();
@@ -1997,7 +1950,8 @@ class __SYCL_EXPORT handler {
19971950
typename TransformUserItemType<Dims, LambdaArgType>::type>;
19981951
(void)NumWorkItems;
19991952
(void)WorkItemOffset;
2000-
kernel_parallel_for_wrapper<NameT, TransformedArgType>(KernelFunc);
1953+
KernelWrapper<WrapAs::parallel_for, NameT, KernelType,
1954+
TransformedArgType>::wrap(this, KernelFunc);
20011955
#ifndef __SYCL_DEVICE_ONLY__
20021956
throwIfActionIsCreated();
20031957
verifyUsedKernelBundleInternal(
@@ -2173,7 +2127,8 @@ class __SYCL_EXPORT handler {
21732127
using LambdaArgType = sycl::detail::lambda_arg_type<KernelType, item<Dims>>;
21742128
(void)Kernel;
21752129
(void)NumWorkItems;
2176-
kernel_parallel_for_wrapper<NameT, LambdaArgType>(KernelFunc);
2130+
KernelWrapper<WrapAs::parallel_for, NameT, KernelType, LambdaArgType>::wrap(
2131+
this, KernelFunc);
21772132
#ifndef __SYCL_DEVICE_ONLY__
21782133
throwIfActionIsCreated();
21792134
verifyUsedKernelBundleInternal(
@@ -2211,7 +2166,8 @@ class __SYCL_EXPORT handler {
22112166
(void)Kernel;
22122167
(void)NumWorkItems;
22132168
(void)WorkItemOffset;
2214-
kernel_parallel_for_wrapper<NameT, LambdaArgType>(KernelFunc);
2169+
KernelWrapper<WrapAs::parallel_for, NameT, KernelType, LambdaArgType>::wrap(
2170+
this, KernelFunc);
22152171
#ifndef __SYCL_DEVICE_ONLY__
22162172
throwIfActionIsCreated();
22172173
// Ignore any set kernel bundles and use the one associated with the kernel
@@ -2250,7 +2206,8 @@ class __SYCL_EXPORT handler {
22502206
sycl::detail::lambda_arg_type<KernelType, nd_item<Dims>>;
22512207
(void)Kernel;
22522208
(void)NDRange;
2253-
kernel_parallel_for_wrapper<NameT, LambdaArgType>(KernelFunc);
2209+
KernelWrapper<WrapAs::parallel_for, NameT, KernelType, LambdaArgType>::wrap(
2210+
this, KernelFunc);
22542211
#ifndef __SYCL_DEVICE_ONLY__
22552212
throwIfActionIsCreated();
22562213
// Ignore any set kernel bundles and use the one associated with the kernel
@@ -2293,7 +2250,8 @@ class __SYCL_EXPORT handler {
22932250
sycl::detail::lambda_arg_type<KernelType, group<Dims>>;
22942251
(void)Kernel;
22952252
(void)NumWorkGroups;
2296-
kernel_parallel_for_work_group_wrapper<NameT, LambdaArgType>(KernelFunc);
2253+
KernelWrapper<WrapAs::parallel_for_work_group, NameT, KernelType,
2254+
LambdaArgType>::wrap(this, KernelFunc);
22972255
#ifndef __SYCL_DEVICE_ONLY__
22982256
throwIfActionIsCreated();
22992257
// Ignore any set kernel bundles and use the one associated with the kernel
@@ -2335,7 +2293,8 @@ class __SYCL_EXPORT handler {
23352293
(void)Kernel;
23362294
(void)NumWorkGroups;
23372295
(void)WorkGroupSize;
2338-
kernel_parallel_for_work_group_wrapper<NameT, LambdaArgType>(KernelFunc);
2296+
KernelWrapper<WrapAs::parallel_for_work_group, NameT, KernelType,
2297+
LambdaArgType>::wrap(this, KernelFunc);
23392298
#ifndef __SYCL_DEVICE_ONLY__
23402299
throwIfActionIsCreated();
23412300
// Ignore any set kernel bundles and use the one associated with the kernel

0 commit comments

Comments
 (0)