Skip to content

Commit 6d64961

Browse files
[SYCL][NFCI] Use wrap_kernel directly in the user-facing API (#18065)
No value in thin wrappers requiring extra instantiation for every kernel.
1 parent 4009536 commit 6d64961

File tree

1 file changed

+25
-92
lines changed

1 file changed

+25
-92
lines changed

sycl/include/sycl/handler.hpp

Lines changed: 25 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1259,35 +1259,6 @@ class __SYCL_EXPORT handler {
12591259
}
12601260
}
12611261

1262-
/// Defines and invokes a SYCL kernel function for the specified nd_range.
1263-
///
1264-
/// The SYCL kernel function is defined as a lambda function or a named
1265-
/// function object type and given an id or item for indexing in the indexing
1266-
/// space defined by range.
1267-
/// If it is a named function object and the function object type is
1268-
/// globally visible, there is no need for the developer to provide
1269-
/// a kernel name for it.
1270-
///
1271-
/// \param ExecutionRange is a ND-range defining global and local sizes as
1272-
/// well as offset.
1273-
/// \param Properties is the properties.
1274-
/// \param KernelFunc is a SYCL kernel function.
1275-
template <typename KernelName, typename KernelType, int Dims,
1276-
typename PropertiesT>
1277-
void parallel_for_impl(nd_range<Dims> ExecutionRange, PropertiesT Props,
1278-
const KernelType &KernelFunc) {
1279-
using LambdaArgType =
1280-
sycl::detail::lambda_arg_type<KernelType, nd_item<Dims>>;
1281-
static_assert(
1282-
std::is_convertible_v<sycl::nd_item<Dims>, LambdaArgType>,
1283-
"Kernel argument of a sycl::parallel_for with sycl::nd_range "
1284-
"must be either sycl::nd_item or be convertible from sycl::nd_item");
1285-
using TransformedArgType = sycl::nd_item<Dims>;
1286-
1287-
wrap_kernel<WrapAs::parallel_for, KernelName, TransformedArgType, Dims>(
1288-
KernelFunc, Props, ExecutionRange);
1289-
}
1290-
12911262
/// Defines and invokes a SYCL kernel function for the specified range.
12921263
///
12931264
/// The SYCL kernel function is defined as SYCL kernel object. The kernel
@@ -1337,56 +1308,6 @@ class __SYCL_EXPORT handler {
13371308
#endif
13381309
}
13391310

1340-
/// Hierarchical kernel invocation method of a kernel defined as a lambda
1341-
/// encoding the body of each work-group to launch.
1342-
///
1343-
/// Lambda may contain multiple calls to parallel_for_work_item(...) methods
1344-
/// representing the execution on each work-item. Launches NumWorkGroups
1345-
/// work-groups of runtime-defined size.
1346-
///
1347-
/// \param NumWorkGroups is a range describing the number of work-groups in
1348-
/// each dimension.
1349-
/// \param KernelFunc is a lambda representing kernel.
1350-
template <
1351-
typename KernelName, typename KernelType, int Dims,
1352-
typename PropertiesT = ext::oneapi::experimental::empty_properties_t>
1353-
void parallel_for_work_group_lambda_impl(range<Dims> NumWorkGroups,
1354-
PropertiesT Props,
1355-
const KernelType &KernelFunc) {
1356-
using LambdaArgType =
1357-
sycl::detail::lambda_arg_type<KernelType, group<Dims>>;
1358-
wrap_kernel<WrapAs::parallel_for_work_group, KernelName, LambdaArgType,
1359-
Dims,
1360-
/*SetNumWorkGroups=*/true>(KernelFunc, Props, NumWorkGroups);
1361-
}
1362-
1363-
/// Hierarchical kernel invocation method of a kernel defined as a lambda
1364-
/// encoding the body of each work-group to launch.
1365-
///
1366-
/// Lambda may contain multiple calls to parallel_for_work_item(...) methods
1367-
/// representing the execution on each work-item. Launches NumWorkGroups
1368-
/// work-groups of WorkGroupSize size.
1369-
///
1370-
/// \param NumWorkGroups is a range describing the number of work-groups in
1371-
/// each dimension.
1372-
/// \param WorkGroupSize is a range describing the size of work-groups in
1373-
/// each dimension.
1374-
/// \param KernelFunc is a lambda representing kernel.
1375-
template <
1376-
typename KernelName, typename KernelType, int Dims,
1377-
typename PropertiesT = ext::oneapi::experimental::empty_properties_t>
1378-
void parallel_for_work_group_lambda_impl(range<Dims> NumWorkGroups,
1379-
range<Dims> WorkGroupSize,
1380-
PropertiesT Props,
1381-
const KernelType &KernelFunc) {
1382-
using LambdaArgType =
1383-
sycl::detail::lambda_arg_type<KernelType, group<Dims>>;
1384-
nd_range<Dims> ExecRange =
1385-
nd_range<Dims>(NumWorkGroups * WorkGroupSize, WorkGroupSize);
1386-
wrap_kernel<WrapAs::parallel_for_work_group, KernelName, LambdaArgType,
1387-
Dims>(KernelFunc, Props, ExecRange);
1388-
}
1389-
13901311
#ifdef SYCL_LANGUAGE_VERSION
13911312
#ifndef __INTEL_SYCL_USE_INTEGRATION_HEADERS
13921313
#define __SYCL_KERNEL_ATTR__ [[clang::sycl_kernel_entry_point(KernelName)]]
@@ -1988,9 +1909,10 @@ class __SYCL_EXPORT handler {
19881909
int Dims>
19891910
void parallel_for_work_group(range<Dims> NumWorkGroups,
19901911
const KernelType &KernelFunc) {
1991-
parallel_for_work_group_lambda_impl<KernelName>(
1992-
NumWorkGroups, ext::oneapi::experimental::empty_properties_t{},
1993-
KernelFunc);
1912+
wrap_kernel<WrapAs::parallel_for_work_group, KernelName,
1913+
detail::lambda_arg_type<KernelType, group<Dims>>, Dims,
1914+
/*SetNumWorkGroups=*/true>(KernelFunc, {} /*Props*/,
1915+
NumWorkGroups);
19941916
}
19951917

19961918
/// Hierarchical kernel invocation method of a kernel defined as a lambda
@@ -2010,9 +1932,10 @@ class __SYCL_EXPORT handler {
20101932
void parallel_for_work_group(range<Dims> NumWorkGroups,
20111933
range<Dims> WorkGroupSize,
20121934
const KernelType &KernelFunc) {
2013-
parallel_for_work_group_lambda_impl<KernelName>(
2014-
NumWorkGroups, WorkGroupSize,
2015-
ext::oneapi::experimental::empty_properties_t{}, KernelFunc);
1935+
wrap_kernel<WrapAs::parallel_for_work_group, KernelName,
1936+
detail::lambda_arg_type<KernelType, group<Dims>>, Dims>(
1937+
KernelFunc, {} /*Props*/,
1938+
nd_range<Dims>{NumWorkGroups * WorkGroupSize, WorkGroupSize});
20161939
}
20171940

20181941
/// Invokes a SYCL kernel.
@@ -2300,7 +2223,16 @@ class __SYCL_EXPORT handler {
23002223
PropertiesT>::value> parallel_for(nd_range<Dims> Range,
23012224
PropertiesT Properties,
23022225
const KernelType &KernelFunc) {
2303-
parallel_for_impl<KernelName>(Range, Properties, std::move(KernelFunc));
2226+
using LambdaArgType =
2227+
sycl::detail::lambda_arg_type<KernelType, nd_item<Dims>>;
2228+
static_assert(
2229+
std::is_convertible_v<sycl::nd_item<Dims>, LambdaArgType>,
2230+
"Kernel argument of a sycl::parallel_for with sycl::nd_range "
2231+
"must be either sycl::nd_item or be convertible from sycl::nd_item");
2232+
using TransformedArgType = sycl::nd_item<Dims>;
2233+
2234+
wrap_kernel<WrapAs::parallel_for, KernelName, TransformedArgType, Dims>(
2235+
KernelFunc, Properties, Range);
23042236
}
23052237

23062238
/// Reductions @{
@@ -2431,9 +2363,9 @@ class __SYCL_EXPORT handler {
24312363
"member function instead.")
24322364
void parallel_for_work_group(range<Dims> NumWorkGroups, PropertiesT Props,
24332365
const KernelType &KernelFunc) {
2434-
parallel_for_work_group_lambda_impl<KernelName, KernelType, Dims,
2435-
PropertiesT>(NumWorkGroups, Props,
2436-
KernelFunc);
2366+
wrap_kernel<WrapAs::parallel_for_work_group, KernelName,
2367+
detail::lambda_arg_type<KernelType, group<Dims>>, Dims,
2368+
/*SetNumWorkGroups=*/true>(KernelFunc, Props, NumWorkGroups);
24372369
}
24382370

24392371
template <typename KernelName = detail::auto_name, typename KernelType,
@@ -2445,9 +2377,10 @@ class __SYCL_EXPORT handler {
24452377
void parallel_for_work_group(range<Dims> NumWorkGroups,
24462378
range<Dims> WorkGroupSize, PropertiesT Props,
24472379
const KernelType &KernelFunc) {
2448-
parallel_for_work_group_lambda_impl<KernelName, KernelType, Dims,
2449-
PropertiesT>(
2450-
NumWorkGroups, WorkGroupSize, Props, KernelFunc);
2380+
wrap_kernel<WrapAs::parallel_for_work_group, KernelName,
2381+
detail::lambda_arg_type<KernelType, group<Dims>>, Dims>(
2382+
KernelFunc, Props,
2383+
nd_range<Dims>{NumWorkGroups * WorkGroupSize, WorkGroupSize});
24512384
}
24522385

24532386
// Explicit copy operations API

0 commit comments

Comments
 (0)