Skip to content

Commit 20dbc70

Browse files
[NFC][SYCL] Minor refactor around KernelPropertiesUnpacker (#7125)
1 parent 0e32a28 commit 20dbc70

File tree

1 file changed

+75
-162
lines changed

1 file changed

+75
-162
lines changed

sycl/include/sycl/handler.hpp

Lines changed: 75 additions & 162 deletions
Original file line number | Diff line number | Diff line change
@@ -1347,200 +1347,113 @@ class __SYCL_EXPORT handler {
13471347
#endif
13481348
}
13491349

1350-
template <typename PropertiesT> struct KernelPropertiesUnpacker {
1351-
template <typename KernelName, typename KernelType>
1352-
static void kernel_single_task_unpack(handler *, _KERNELFUNCPARAMTYPE) {}
1353-
1354-
template <typename KernelName, typename KernelType>
1355-
static void kernel_single_task_unpack(handler *, _KERNELFUNCPARAMTYPE,
1356-
kernel_handler) {}
1357-
1358-
template <typename KernelName, typename ElementType, typename KernelType>
1359-
static void kernel_parallel_for_unpack(handler *, _KERNELFUNCPARAMTYPE) {}
1360-
1361-
template <typename KernelName, typename ElementType, typename KernelType>
1362-
static void kernel_parallel_for_unpack(handler *, _KERNELFUNCPARAMTYPE,
1363-
kernel_handler) {}
1350+
template <typename... Props> struct KernelPropertiesUnpackerImpl {
1351+
// Just pass extra Props... as template parameters to the underlying
1352+
// Caller->* member functions. Don't have reflection so try to use
1353+
// templates as much as possible to reduce the amount of boilerplate code
1354+
// needed. All the type checks are expected to be done at the Caller's
1355+
// methods side.
1356+
1357+
template <typename... TypesToForward, typename... ArgsTy>
1358+
static void kernel_single_task_unpack(handler *h, ArgsTy... Args) {
1359+
h->kernel_single_task<TypesToForward..., Props...>(Args...);
1360+
}
13641361

1365-
template <typename KernelName, typename ElementType, typename KernelType>
1366-
static void kernel_parallel_for_work_group_unpack(handler *,
1367-
_KERNELFUNCPARAMTYPE) {}
1362+
template <typename... TypesToForward, typename... ArgsTy>
1363+
static void kernel_parallel_for_unpack(handler *h, ArgsTy... Args) {
1364+
h->kernel_parallel_for<TypesToForward..., Props...>(Args...);
1365+
}
13681366

1369-
template <typename KernelName, typename ElementType, typename KernelType>
1370-
static void kernel_parallel_for_work_group_unpack(handler *,
1371-
_KERNELFUNCPARAMTYPE,
1372-
kernel_handler) {}
1367+
template <typename... TypesToForward, typename... ArgsTy>
1368+
static void kernel_parallel_for_work_group_unpack(handler *h,
1369+
ArgsTy... Args) {
1370+
h->kernel_parallel_for_work_group<TypesToForward..., Props...>(Args...);
1371+
}
1372+
};
13731373

1374-
// This should always fail but must be dependent to avoid always failing.
1375-
// It is defined after the shell members to avoid that they are stripped
1376-
// from the class.
1374+
template <typename PropertiesT>
1375+
struct KernelPropertiesUnpacker : public KernelPropertiesUnpackerImpl<> {
1376+
// This should always fail outside the specialization below but must be
1377+
// dependent to avoid failing even if not instantiated.
13771378
static_assert(
13781379
ext::oneapi::experimental::is_property_list<PropertiesT>::value,
13791380
"Template type is not a property list.");
13801381
};
13811382

13821383
template <typename... Props>
13831384
struct KernelPropertiesUnpacker<
1384-
ext::oneapi::experimental::detail::properties_t<Props...>> {
1385-
template <typename KernelName, typename KernelType>
1386-
static void kernel_single_task_unpack(handler *Caller,
1387-
_KERNELFUNCPARAM(KernelFunc)) {
1388-
Caller->kernel_single_task<KernelName, KernelType, Props...>(KernelFunc);
1389-
}
1390-
1391-
template <typename KernelName, typename KernelType>
1392-
static void kernel_single_task_unpack(handler *Caller,
1393-
_KERNELFUNCPARAM(KernelFunc),
1394-
kernel_handler KH) {
1395-
Caller->kernel_single_task<KernelName, KernelType, Props...>(KernelFunc,
1396-
KH);
1397-
}
1398-
1399-
template <typename KernelName, typename ElementType, typename KernelType>
1400-
static void kernel_parallel_for_unpack(handler *Caller,
1401-
_KERNELFUNCPARAM(KernelFunc)) {
1402-
Caller
1403-
->kernel_parallel_for<KernelName, ElementType, KernelType, Props...>(
1404-
KernelFunc);
1405-
}
1406-
1407-
template <typename KernelName, typename ElementType, typename KernelType>
1408-
static void kernel_parallel_for_unpack(handler *Caller,
1409-
_KERNELFUNCPARAM(KernelFunc),
1410-
kernel_handler KH) {
1411-
Caller
1412-
->kernel_parallel_for<KernelName, ElementType, KernelType, Props...>(
1413-
KernelFunc, KH);
1414-
}
1415-
1416-
template <typename KernelName, typename ElementType, typename KernelType>
1417-
static void
1418-
kernel_parallel_for_work_group_unpack(handler *Caller,
1419-
_KERNELFUNCPARAM(KernelFunc)) {
1420-
Caller->kernel_parallel_for_work_group<KernelName, ElementType,
1421-
KernelType, Props...>(KernelFunc);
1422-
}
1385+
ext::oneapi::experimental::detail::properties_t<Props...>>
1386+
: public KernelPropertiesUnpackerImpl<Props...> {};
14231387

1424-
template <typename KernelName, typename ElementType, typename KernelType>
1425-
static void kernel_parallel_for_work_group_unpack(
1426-
handler *Caller, _KERNELFUNCPARAM(KernelFunc), kernel_handler KH) {
1427-
Caller->kernel_parallel_for_work_group<KernelName, ElementType,
1428-
KernelType, Props...>(KernelFunc,
1429-
KH);
1430-
}
1431-
};
1432-
1433-
// Wrappers for kernel_*** functions above with and without support of
1434-
// additional kernel_handler argument.
1435-
1436-
// NOTE: to support kernel_handler argument in kernel lambdas, only
1437-
// kernel_***_wrapper functions must be called in this code
1438-
1439-
// Wrappers for kernel_single_task(...)
1440-
1441-
template <typename KernelName, typename KernelType,
1442-
typename PropertiesT =
1443-
ext::oneapi::experimental::detail::empty_properties_t>
1444-
std::enable_if_t<detail::KernelLambdaHasKernelHandlerArgT<KernelType>::value>
1445-
kernel_single_task_wrapper(_KERNELFUNCPARAM(KernelFunc)) {
1446-
#ifdef __SYCL_DEVICE_ONLY__
1447-
detail::CheckDeviceCopyable<KernelType>();
1448-
#endif // __SYCL_DEVICE_ONLY__
1449-
kernel_handler KH;
1450-
using MergedPropertiesT =
1451-
typename detail::GetMergedKernelProperties<KernelType,
1452-
PropertiesT>::type;
1453-
KernelPropertiesUnpacker<MergedPropertiesT>::
1454-
template kernel_single_task_unpack<KernelName>(this, KernelFunc, KH);
1455-
}
1456-
1457-
template <typename KernelName, typename KernelType,
1458-
typename PropertiesT =
1459-
ext::oneapi::experimental::detail::empty_properties_t>
1460-
std::enable_if_t<!detail::KernelLambdaHasKernelHandlerArgT<KernelType>::value>
1461-
kernel_single_task_wrapper(_KERNELFUNCPARAM(KernelFunc)) {
1388+
// Helper function to
1389+
//
1390+
// * Make use of the KernelPropertiesUnpacker above
1391+
// * Decide if we need an extra kernel_handler parameter
1392+
//
1393+
// The interface uses a \p Lambda callback to propagate that information back
1394+
// to the caller as we need the caller to communicate:
1395+
//
1396+
// * Name of the method to call
1397+
// * Provide explicit template type parameters for the call
1398+
//
1399+
// Couldn't think of a better way to achieve both.
1400+
template <typename KernelType, typename PropertiesT, bool HasKernelHandlerArg,
1401+
typename FuncTy>
1402+
void unpack(_KERNELFUNCPARAM(KernelFunc), FuncTy Lambda) {
14621403
#ifdef __SYCL_DEVICE_ONLY__
14631404
detail::CheckDeviceCopyable<KernelType>();
14641405
#endif // __SYCL_DEVICE_ONLY__
14651406
using MergedPropertiesT =
14661407
typename detail::GetMergedKernelProperties<KernelType,
14671408
PropertiesT>::type;
1468-
KernelPropertiesUnpacker<MergedPropertiesT>::
1469-
template kernel_single_task_unpack<KernelName>(this, KernelFunc);
1409+
using Unpacker = KernelPropertiesUnpacker<MergedPropertiesT>;
1410+
if constexpr (HasKernelHandlerArg) {
1411+
kernel_handler KH;
1412+
Lambda(Unpacker{}, this, KernelFunc, KH);
1413+
} else {
1414+
Lambda(Unpacker{}, this, KernelFunc);
1415+
}
14701416
}
14711417

1472-
// Wrappers for kernel_parallel_for(...)
1473-
1474-
template <typename KernelName, typename ElementType, typename KernelType,
1475-
typename PropertiesT =
1476-
ext::oneapi::experimental::detail::empty_properties_t>
1477-
std::enable_if_t<
1478-
detail::KernelLambdaHasKernelHandlerArgT<KernelType, ElementType>::value>
1479-
kernel_parallel_for_wrapper(_KERNELFUNCPARAM(KernelFunc)) {
1480-
#ifdef __SYCL_DEVICE_ONLY__
1481-
detail::CheckDeviceCopyable<KernelType>();
1482-
#endif // __SYCL_DEVICE_ONLY__
1483-
kernel_handler KH;
1484-
using MergedPropertiesT =
1485-
typename detail::GetMergedKernelProperties<KernelType,
1486-
PropertiesT>::type;
1487-
KernelPropertiesUnpacker<MergedPropertiesT>::
1488-
template kernel_parallel_for_unpack<KernelName, ElementType>(
1489-
this, KernelFunc, KH);
1490-
}
1418+
// NOTE: to support kernel_handler argument in kernel lambdas, only
1419+
// kernel_***_wrapper functions must be called in this code
14911420

1492-
template <typename KernelName, typename ElementType, typename KernelType,
1421+
template <typename KernelName, typename KernelType,
14931422
typename PropertiesT =
14941423
ext::oneapi::experimental::detail::empty_properties_t>
1495-
std::enable_if_t<
1496-
!detail::KernelLambdaHasKernelHandlerArgT<KernelType, ElementType>::value>
1497-
kernel_parallel_for_wrapper(_KERNELFUNCPARAM(KernelFunc)) {
1498-
#ifdef __SYCL_DEVICE_ONLY__
1499-
detail::CheckDeviceCopyable<KernelType>();
1500-
#endif // __SYCL_DEVICE_ONLY__
1501-
using MergedPropertiesT =
1502-
typename detail::GetMergedKernelProperties<KernelType,
1503-
PropertiesT>::type;
1504-
KernelPropertiesUnpacker<MergedPropertiesT>::
1505-
template kernel_parallel_for_unpack<KernelName, ElementType>(
1506-
this, KernelFunc);
1424+
void kernel_single_task_wrapper(_KERNELFUNCPARAM(KernelFunc)) {
1425+
unpack<KernelType, PropertiesT,
1426+
detail::KernelLambdaHasKernelHandlerArgT<KernelType>::value>(
1427+
KernelFunc, [&](auto Unpacker, auto... args) {
1428+
Unpacker.template kernel_single_task_unpack<KernelName, KernelType>(
1429+
args...);
1430+
});
15071431
}
15081432

1509-
// Wrappers for kernel_parallel_for_work_group(...)
1510-
15111433
template <typename KernelName, typename ElementType, typename KernelType,
15121434
typename PropertiesT =
15131435
ext::oneapi::experimental::detail::empty_properties_t>
1514-
std::enable_if_t<
1515-
detail::KernelLambdaHasKernelHandlerArgT<KernelType, ElementType>::value>
1516-
kernel_parallel_for_work_group_wrapper(_KERNELFUNCPARAM(KernelFunc)) {
1517-
#ifdef __SYCL_DEVICE_ONLY__
1518-
detail::CheckDeviceCopyable<KernelType>();
1519-
#endif // __SYCL_DEVICE_ONLY__
1520-
kernel_handler KH;
1521-
using MergedPropertiesT =
1522-
typename detail::GetMergedKernelProperties<KernelType,
1523-
PropertiesT>::type;
1524-
KernelPropertiesUnpacker<MergedPropertiesT>::
1525-
template kernel_parallel_for_work_group_unpack<KernelName, ElementType>(
1526-
this, KernelFunc, KH);
1436+
void kernel_parallel_for_wrapper(_KERNELFUNCPARAM(KernelFunc)) {
1437+
unpack<KernelType, PropertiesT,
1438+
detail::KernelLambdaHasKernelHandlerArgT<KernelType,
1439+
ElementType>::value>(
1440+
KernelFunc, [&](auto Unpacker, auto... args) {
1441+
Unpacker.template kernel_parallel_for_unpack<KernelName, ElementType,
1442+
KernelType>(args...);
1443+
});
15271444
}
15281445

15291446
template <typename KernelName, typename ElementType, typename KernelType,
15301447
typename PropertiesT =
15311448
ext::oneapi::experimental::detail::empty_properties_t>
1532-
std::enable_if_t<
1533-
!detail::KernelLambdaHasKernelHandlerArgT<KernelType, ElementType>::value>
1534-
kernel_parallel_for_work_group_wrapper(_KERNELFUNCPARAM(KernelFunc)) {
1535-
#ifdef __SYCL_DEVICE_ONLY__
1536-
detail::CheckDeviceCopyable<KernelType>();
1537-
#endif // __SYCL_DEVICE_ONLY__
1538-
using MergedPropertiesT =
1539-
typename detail::GetMergedKernelProperties<KernelType,
1540-
PropertiesT>::type;
1541-
KernelPropertiesUnpacker<MergedPropertiesT>::
1542-
template kernel_parallel_for_work_group_unpack<KernelName, ElementType>(
1543-
this, KernelFunc);
1449+
void kernel_parallel_for_work_group_wrapper(_KERNELFUNCPARAM(KernelFunc)) {
1450+
unpack<KernelType, PropertiesT,
1451+
detail::KernelLambdaHasKernelHandlerArgT<KernelType,
1452+
ElementType>::value>(
1453+
KernelFunc, [&](auto Unpacker, auto... args) {
1454+
Unpacker.template kernel_parallel_for_work_group_unpack<
1455+
KernelName, ElementType, KernelType>(args...);
1456+
});
15441457
}
15451458

15461459
/// Defines and invokes a SYCL kernel function as a function object type.

0 commit comments

Comments
 (0)