Skip to content

[NFCI][SYCL] Refactor handler::unpack #17838

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 4, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
217 changes: 88 additions & 129 deletions sycl/include/sycl/handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1206,8 +1206,8 @@ class __SYCL_EXPORT handler {
using KName = std::conditional_t<std::is_same<KernelType, NameT>::value,
decltype(Wrapper), NameWT>;

kernel_parallel_for_wrapper<KName, TransformedArgType, decltype(Wrapper),
PropertiesT>(Wrapper);
KernelWrapper<WrapAs::parallel_for, KName, decltype(Wrapper),
TransformedArgType, PropertiesT>::wrap(this, Wrapper);
#ifndef __SYCL_DEVICE_ONLY__
verifyUsedKernelBundleInternal(
detail::string_view{detail::getKernelName<NameT>()});
Expand All @@ -1232,8 +1232,8 @@ class __SYCL_EXPORT handler {
#ifndef __SYCL_FORCE_PARALLEL_FOR_RANGE_ROUNDING__
// If parallel_for range rounding is forced then only range rounded
// kernel is generated
kernel_parallel_for_wrapper<NameT, TransformedArgType, KernelType,
PropertiesT>(KernelFunc);
KernelWrapper<WrapAs::parallel_for, NameT, KernelType, TransformedArgType,
PropertiesT>::wrap(this, KernelFunc);
#ifndef __SYCL_DEVICE_ONLY__
verifyUsedKernelBundleInternal(
detail::string_view{detail::getKernelName<NameT>()});
Expand Down Expand Up @@ -1281,8 +1281,8 @@ class __SYCL_EXPORT handler {

(void)ExecutionRange;
(void)Props;
kernel_parallel_for_wrapper<NameT, TransformedArgType, KernelType,
PropertiesT>(KernelFunc);
KernelWrapper<WrapAs::parallel_for, NameT, KernelType, TransformedArgType,
PropertiesT>::wrap(this, KernelFunc);
#ifndef __SYCL_DEVICE_ONLY__
throwIfActionIsCreated();
verifyUsedKernelBundleInternal(
Expand Down Expand Up @@ -1369,8 +1369,8 @@ class __SYCL_EXPORT handler {
sycl::detail::lambda_arg_type<KernelType, group<Dims>>;
(void)NumWorkGroups;
(void)Props;
kernel_parallel_for_work_group_wrapper<NameT, LambdaArgType, KernelType,
PropertiesT>(KernelFunc);
KernelWrapper<WrapAs::parallel_for_work_group, NameT, KernelType,
LambdaArgType, PropertiesT>::wrap(this, KernelFunc);
#ifndef __SYCL_DEVICE_ONLY__
throwIfActionIsCreated();
verifyUsedKernelBundleInternal(
Expand Down Expand Up @@ -1411,8 +1411,8 @@ class __SYCL_EXPORT handler {
(void)NumWorkGroups;
(void)WorkGroupSize;
(void)Props;
kernel_parallel_for_work_group_wrapper<NameT, LambdaArgType, KernelType,
PropertiesT>(KernelFunc);
KernelWrapper<WrapAs::parallel_for_work_group, NameT, KernelType,
LambdaArgType, PropertiesT>::wrap(this, KernelFunc);
#ifndef __SYCL_DEVICE_ONLY__
throwIfActionIsCreated();
verifyUsedKernelBundleInternal(
Expand Down Expand Up @@ -1552,127 +1552,79 @@ class __SYCL_EXPORT handler {
#endif
}

template <typename... Props> struct KernelPropertiesUnpackerImpl {
// Just pass extra Props... as template parameters to the underlying
// Caller->* member functions. Don't have reflection so try to use
// templates as much as possible to reduce the amount of boilerplate code
// needed. All the type checks are expected to be done at the Caller's
// methods side.

template <typename... TypesToForward, typename... ArgsTy>
static void kernel_single_task_unpack(handler *h, ArgsTy &&...Args) {
h->kernel_single_task<TypesToForward..., Props...>(
std::forward<ArgsTy>(Args)...);
}

template <typename... TypesToForward, typename... ArgsTy>
static void kernel_parallel_for_unpack(handler *h, ArgsTy &&...Args) {
h->kernel_parallel_for<TypesToForward..., Props...>(
std::forward<ArgsTy>(Args)...);
}

template <typename... TypesToForward, typename... ArgsTy>
static void kernel_parallel_for_work_group_unpack(handler *h,
ArgsTy &&...Args) {
h->kernel_parallel_for_work_group<TypesToForward..., Props...>(
std::forward<ArgsTy>(Args)...);
}
};

template <typename PropertiesT>
struct KernelPropertiesUnpacker : public KernelPropertiesUnpackerImpl<> {
// This should always fail outside the specialization below but must be
// dependent to avoid failing even if not instantiated.
static_assert(
ext::oneapi::experimental::is_property_list<PropertiesT>::value,
"Template type is not a property list.");
};

template <typename... Props>
struct KernelPropertiesUnpacker<
ext::oneapi::experimental::detail::properties_t<Props...>>
: public KernelPropertiesUnpackerImpl<Props...> {};

// Helper function to
//
// * Make use of the KernelPropertiesUnpacker above
// * Decide if we need an extra kernel_handler parameter
// The KernelWrapper below has two purposes.
//
// The interface uses a \p Lambda callback to propagate that information back
// to the caller as we need the caller to communicate:
// First, from SYCL 2020, Table 129 (Member functions of the `handler ` class)
// > The callable ... can optionally take a `kernel_handler` ... in
// which > case the SYCL runtime will construct an instance of
// `kernel_handler` > and pass it to the callable.
//
// * Name of the method to call
// * Provide explicit template type parameters for the call
// Note: "..." due to slight wording variability between
// single_task/parallel_for (e.g. only parameter vs last). This helper class
// calls `kernel_*` entry points (both hardcoded names known to FE and special
// device-specific entry point attributes) with proper arguments (with/without
// `kernel_handler` argument, depending on the signature of the SYCL kernel
// function).
//
// Couldn't think of a better way to achieve both.
template <typename KernelName, typename KernelType, typename PropertiesT,
bool HasKernelHandlerArg, typename FuncTy>
void unpack(const KernelType &KernelFunc, FuncTy Lambda) {
#ifdef __SYCL_DEVICE_ONLY__
detail::CheckDeviceCopyable<KernelType>();
#endif // __SYCL_DEVICE_ONLY__
using MergedPropertiesT =
typename detail::GetMergedKernelProperties<KernelType,
PropertiesT>::type;
using Unpacker = KernelPropertiesUnpacker<MergedPropertiesT>;
#ifndef __SYCL_DEVICE_ONLY__
// If there are properties provided by get method then process them.
if constexpr (ext::oneapi::experimental::detail::
HasKernelPropertiesGetMethod<const KernelType &>::value) {
processProperties<detail::isKernelESIMD<KernelName>()>(
KernelFunc.get(ext::oneapi::experimental::properties_tag{}));
}
#endif
if constexpr (HasKernelHandlerArg) {
kernel_handler KH;
Lambda(Unpacker{}, this, KernelFunc, KH);
} else {
Lambda(Unpacker{}, this, KernelFunc);
}
}
// Second, it performs a few checks and some properties processing (including
// the one provided via `sycl_ext_oneapi_kernel_properties` extension by
// embedding them into the kernel's type).

// NOTE: to support kernel_handler argument in kernel lambdas, only
// kernel_***_wrapper functions must be called in this code
enum class WrapAs { single_task, parallel_for, parallel_for_work_group };

template <
typename KernelName, typename KernelType,
typename PropertiesT = ext::oneapi::experimental::empty_properties_t>
void kernel_single_task_wrapper(const KernelType &KernelFunc) {
unpack<KernelName, KernelType, PropertiesT,
detail::KernelLambdaHasKernelHandlerArgT<KernelType>::value>(
KernelFunc, [&](auto Unpacker, auto &&...args) {
Unpacker.template kernel_single_task_unpack<KernelName, KernelType>(
WrapAs WrapAsVal, typename KernelName, typename KernelType,
typename ElementType,
typename PropertiesT = ext::oneapi::experimental::empty_properties_t,
typename MergedPropertiesT = typename detail::GetMergedKernelProperties<
KernelType, PropertiesT>::type>
struct KernelWrapper;
template <WrapAs WrapAsVal, typename KernelName, typename KernelType,
typename ElementType, typename PropertiesT, typename... MergedProps>
struct KernelWrapper<
WrapAsVal, KernelName, KernelType, ElementType, PropertiesT,
ext::oneapi::experimental::detail::properties_t<MergedProps...>> {
static void wrap(handler *h, const KernelType &KernelFunc) {
#ifdef __SYCL_DEVICE_ONLY__
detail::CheckDeviceCopyable<KernelType>();
#else
// If there are properties provided by get method then process them.
if constexpr (ext::oneapi::experimental::detail::
HasKernelPropertiesGetMethod<
const KernelType &>::value) {
h->processProperties<detail::isKernelESIMD<KernelName>()>(
KernelFunc.get(ext::oneapi::experimental::properties_tag{}));
}
#endif
auto L = [&](auto &&...args) {
if constexpr (WrapAsVal == WrapAs::single_task) {
h->kernel_single_task<KernelName, KernelType, MergedProps...>(
std::forward<decltype(args)>(args)...);
});
}

template <
typename KernelName, typename ElementType, typename KernelType,
typename PropertiesT = ext::oneapi::experimental::empty_properties_t>
void kernel_parallel_for_wrapper(const KernelType &KernelFunc) {
unpack<KernelName, KernelType, PropertiesT,
detail::KernelLambdaHasKernelHandlerArgT<KernelType,
ElementType>::value>(
KernelFunc, [&](auto Unpacker, auto &&...args) {
Unpacker.template kernel_parallel_for_unpack<KernelName, ElementType,
KernelType>(
} else if constexpr (WrapAsVal == WrapAs::parallel_for) {
h->kernel_parallel_for<KernelName, ElementType, KernelType,
MergedProps...>(
std::forward<decltype(args)>(args)...);
});
}

template <
typename KernelName, typename ElementType, typename KernelType,
typename PropertiesT = ext::oneapi::experimental::empty_properties_t>
void kernel_parallel_for_work_group_wrapper(const KernelType &KernelFunc) {
unpack<KernelName, KernelType, PropertiesT,
detail::KernelLambdaHasKernelHandlerArgT<KernelType,
ElementType>::value>(
KernelFunc, [&](auto Unpacker, auto &&...args) {
Unpacker.template kernel_parallel_for_work_group_unpack<
KernelName, ElementType, KernelType>(
} else if constexpr (WrapAsVal == WrapAs::parallel_for_work_group) {
h->kernel_parallel_for_work_group<KernelName, ElementType, KernelType,
MergedProps...>(
std::forward<decltype(args)>(args)...);
});
}
} else {
// Always false, but template-dependent.
static_assert(WrapAsVal != WrapAsVal, "Unexpected WrapAsVal");
}
};
if constexpr (detail::KernelLambdaHasKernelHandlerArgT<
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

gah, clang-format is just the worst.

KernelType, ElementType>::value) {
kernel_handler KH;
L(KernelFunc, KH);
} else {
L(KernelFunc);
}
}
};

// NOTE: to support kernel_handler argument in kernel lambdas, only
// KernelWrapper<...>::wrap() must be called in this code.

/// Defines and invokes a SYCL kernel function as a function object type.
///
Expand All @@ -1692,7 +1644,8 @@ class __SYCL_EXPORT handler {
using NameT =
typename detail::get_kernel_name_t<KernelName, KernelType>::name;

kernel_single_task_wrapper<NameT, KernelType, PropertiesT>(KernelFunc);
KernelWrapper<WrapAs::single_task, NameT, KernelType, void,
PropertiesT>::wrap(this, KernelFunc);
#ifndef __SYCL_DEVICE_ONLY__
throwIfActionIsCreated();
throwOnKernelParameterMisuse<KernelName, KernelType>();
Expand Down Expand Up @@ -1995,7 +1948,8 @@ class __SYCL_EXPORT handler {
typename TransformUserItemType<Dims, LambdaArgType>::type>;
(void)NumWorkItems;
(void)WorkItemOffset;
kernel_parallel_for_wrapper<NameT, TransformedArgType>(KernelFunc);
KernelWrapper<WrapAs::parallel_for, NameT, KernelType,
TransformedArgType>::wrap(this, KernelFunc);
#ifndef __SYCL_DEVICE_ONLY__
throwIfActionIsCreated();
verifyUsedKernelBundleInternal(
Expand Down Expand Up @@ -2171,7 +2125,8 @@ class __SYCL_EXPORT handler {
using LambdaArgType = sycl::detail::lambda_arg_type<KernelType, item<Dims>>;
(void)Kernel;
(void)NumWorkItems;
kernel_parallel_for_wrapper<NameT, LambdaArgType>(KernelFunc);
KernelWrapper<WrapAs::parallel_for, NameT, KernelType, LambdaArgType>::wrap(
this, KernelFunc);
#ifndef __SYCL_DEVICE_ONLY__
throwIfActionIsCreated();
verifyUsedKernelBundleInternal(
Expand Down Expand Up @@ -2209,7 +2164,8 @@ class __SYCL_EXPORT handler {
(void)Kernel;
(void)NumWorkItems;
(void)WorkItemOffset;
kernel_parallel_for_wrapper<NameT, LambdaArgType>(KernelFunc);
KernelWrapper<WrapAs::parallel_for, NameT, KernelType, LambdaArgType>::wrap(
this, KernelFunc);
#ifndef __SYCL_DEVICE_ONLY__
throwIfActionIsCreated();
// Ignore any set kernel bundles and use the one associated with the kernel
Expand Down Expand Up @@ -2248,7 +2204,8 @@ class __SYCL_EXPORT handler {
sycl::detail::lambda_arg_type<KernelType, nd_item<Dims>>;
(void)Kernel;
(void)NDRange;
kernel_parallel_for_wrapper<NameT, LambdaArgType>(KernelFunc);
KernelWrapper<WrapAs::parallel_for, NameT, KernelType, LambdaArgType>::wrap(
this, KernelFunc);
#ifndef __SYCL_DEVICE_ONLY__
throwIfActionIsCreated();
// Ignore any set kernel bundles and use the one associated with the kernel
Expand Down Expand Up @@ -2291,7 +2248,8 @@ class __SYCL_EXPORT handler {
sycl::detail::lambda_arg_type<KernelType, group<Dims>>;
(void)Kernel;
(void)NumWorkGroups;
kernel_parallel_for_work_group_wrapper<NameT, LambdaArgType>(KernelFunc);
KernelWrapper<WrapAs::parallel_for_work_group, NameT, KernelType,
LambdaArgType>::wrap(this, KernelFunc);
#ifndef __SYCL_DEVICE_ONLY__
throwIfActionIsCreated();
// Ignore any set kernel bundles and use the one associated with the kernel
Expand Down Expand Up @@ -2333,7 +2291,8 @@ class __SYCL_EXPORT handler {
(void)Kernel;
(void)NumWorkGroups;
(void)WorkGroupSize;
kernel_parallel_for_work_group_wrapper<NameT, LambdaArgType>(KernelFunc);
KernelWrapper<WrapAs::parallel_for_work_group, NameT, KernelType,
LambdaArgType>::wrap(this, KernelFunc);
#ifndef __SYCL_DEVICE_ONLY__
throwIfActionIsCreated();
// Ignore any set kernel bundles and use the one associated with the kernel
Expand Down