Skip to content

Commit 5f82ac6

Browse files
romanovvladErich Keane
authored andcommitted
[SYCL] Use if constexpr instead of enable_if to avoid msvc bug
1 parent 72738e0 commit 5f82ac6

File tree

2 files changed

+58
-98
lines changed

2 files changed

+58
-98
lines changed

sycl/include/CL/sycl/detail/cg_types.hpp

Lines changed: 15 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -184,34 +184,24 @@ constexpr bool isKernelLambdaCallableWithKernelHandler() {
184184

185185
// Helpers for running kernel lambda on the host device
186186

187-
template <typename KernelType,
188-
typename std::enable_if_t<isKernelLambdaCallableWithKernelHandler<
189-
KernelType>()> * = nullptr>
190-
constexpr void runKernelWithoutArg(KernelType KernelName) {
191-
kernel_handler KH;
192-
KernelName(KH);
193-
}
194-
195-
template <typename KernelType,
196-
typename std::enable_if_t<!isKernelLambdaCallableWithKernelHandler<
197-
KernelType>()> * = nullptr>
198-
constexpr void runKernelWithoutArg(KernelType KernelName) {
199-
KernelName();
200-
}
201-
202-
template <typename ArgType, typename KernelType,
203-
typename std::enable_if_t<isKernelLambdaCallableWithKernelHandler<
204-
KernelType, ArgType>()> * = nullptr>
205-
constexpr void runKernelWithArg(KernelType KernelName, ArgType Arg) {
206-
kernel_handler KH;
207-
KernelName(Arg, KH);
187+
template <typename KernelType> void runKernelWithoutArg(KernelType KernelName) {
188+
if constexpr (isKernelLambdaCallableWithKernelHandler<KernelType>()) {
189+
kernel_handler KH;
190+
KernelName(KH);
191+
} else {
192+
KernelName();
193+
}
208194
}
209195

210-
template <typename ArgType, typename KernelType,
211-
typename std::enable_if_t<!isKernelLambdaCallableWithKernelHandler<
212-
KernelType, ArgType>()> * = nullptr>
196+
template <typename ArgType, typename KernelType>
213197
constexpr void runKernelWithArg(KernelType KernelName, ArgType Arg) {
214-
KernelName(Arg);
198+
if constexpr (isKernelLambdaCallableWithKernelHandler<KernelType,
199+
ArgType>()) {
200+
kernel_handler KH;
201+
KernelName(Arg, KH);
202+
} else {
203+
KernelName(Arg);
204+
}
215205
}
216206

217207
// The pure virtual class aimed to store lambda/functors of any type.

sycl/include/CL/sycl/handler.hpp

Lines changed: 43 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -968,80 +968,57 @@ class __SYCL_EXPORT handler {
968968
// Wrappers for kernel_single_task(...)
969969

970970
template <typename KernelName, typename KernelType>
971-
std::enable_if_t<
972-
detail::isKernelLambdaCallableWithKernelHandler<KernelType>(), void>
973-
#ifdef __SYCL_NONCONST_FUNCTOR__
974-
kernel_single_task_wrapper(KernelType KernelFunc) {
975-
#else
976-
kernel_single_task_wrapper(const KernelType &KernelFunc) {
977-
#endif
978-
kernel_handler KH;
979-
kernel_single_task<KernelName>(KernelFunc, KH);
980-
}
981-
982-
template <typename KernelName, typename KernelType>
983-
std::enable_if_t<
984-
!detail::isKernelLambdaCallableWithKernelHandler<KernelType>(), void>
971+
void
985972
#ifdef __SYCL_NONCONST_FUNCTOR__
986973
kernel_single_task_wrapper(KernelType KernelFunc) {
987974
#else
988975
kernel_single_task_wrapper(const KernelType &KernelFunc) {
989976
#endif
990-
kernel_single_task<KernelName>(KernelFunc);
977+
if constexpr (detail::isKernelLambdaCallableWithKernelHandler<
978+
KernelType>()) {
979+
kernel_handler KH;
980+
kernel_single_task<KernelName>(KernelFunc, KH);
981+
} else {
982+
kernel_single_task<KernelName>(KernelFunc);
983+
}
991984
}
992985

993986
// Wrappers for kernel_parallel_for(...)
994987

995988
template <typename KernelName, typename ElementType, typename KernelType>
996-
std::enable_if_t<detail::isKernelLambdaCallableWithKernelHandler<
997-
KernelType, ElementType>(),
998-
void>
999-
#ifdef __SYCL_NONCONST_FUNCTOR__
1000-
kernel_parallel_for_wrapper(KernelType KernelFunc) {
1001-
#else
1002-
kernel_parallel_for_wrapper(const KernelType &KernelFunc) {
1003-
#endif
1004-
kernel_handler KH;
1005-
kernel_parallel_for<KernelName, ElementType>(KernelFunc, KH);
1006-
}
1007-
1008-
template <typename KernelName, typename ElementType, typename KernelType>
1009-
std::enable_if_t<!detail::isKernelLambdaCallableWithKernelHandler<
1010-
KernelType, ElementType>(),
1011-
void>
989+
void
1012990
#ifdef __SYCL_NONCONST_FUNCTOR__
1013991
kernel_parallel_for_wrapper(KernelType KernelFunc) {
1014992
#else
1015993
kernel_parallel_for_wrapper(const KernelType &KernelFunc) {
1016994
#endif
1017-
kernel_parallel_for<KernelName, ElementType>(KernelFunc);
995+
if constexpr (detail::isKernelLambdaCallableWithKernelHandler<
996+
KernelType, ElementType>()) {
997+
kernel_handler KH;
998+
kernel_parallel_for<KernelName, ElementType>(KernelFunc, KH);
999+
}
1000+
else {
1001+
kernel_parallel_for<KernelName, ElementType>(KernelFunc);
1002+
}
10181003
}
10191004

10201005
// Wrappers for kernel_parallel_for_work_group(...)
10211006

10221007
template <typename KernelName, typename ElementType, typename KernelType>
1023-
std::enable_if_t<detail::isKernelLambdaCallableWithKernelHandler<
1024-
KernelType, ElementType>(),
1025-
void>
1026-
#ifdef __SYCL_NONCONST_FUNCTOR__
1027-
kernel_parallel_for_work_group_wrapper(KernelType KernelFunc) {
1028-
#else
1029-
kernel_parallel_for_work_group_wrapper(const KernelType &KernelFunc) {
1030-
#endif
1031-
kernel_handler KH;
1032-
kernel_parallel_for_work_group<KernelName, ElementType>(KernelFunc, KH);
1033-
}
1034-
1035-
template <typename KernelName, typename ElementType, typename KernelType>
1036-
std::enable_if_t<!detail::isKernelLambdaCallableWithKernelHandler<
1037-
KernelType, ElementType>(),
1038-
void>
1008+
void
10391009
#ifdef __SYCL_NONCONST_FUNCTOR__
10401010
kernel_parallel_for_work_group_wrapper(KernelType KernelFunc) {
10411011
#else
10421012
kernel_parallel_for_work_group_wrapper(const KernelType &KernelFunc) {
10431013
#endif
1044-
kernel_parallel_for_work_group<KernelName, ElementType>(KernelFunc);
1014+
if constexpr (detail::isKernelLambdaCallableWithKernelHandler<
1015+
KernelType, ElementType>()) {
1016+
kernel_handler KH;
1017+
kernel_parallel_for_work_group<KernelName, ElementType>(KernelFunc, KH);
1018+
}
1019+
else {
1020+
kernel_parallel_for_work_group<KernelName, ElementType>(KernelFunc);
1021+
}
10451022
}
10461023

10471024
std::shared_ptr<detail::kernel_bundle_impl>
@@ -2289,32 +2266,25 @@ class __SYCL_EXPORT handler {
22892266

22902267
friend class ::MockHandler;
22912268

2292-
template <
2293-
typename TransformedArgType, int Dims, typename KernelType,
2294-
typename std::enable_if_t<detail::isKernelLambdaCallableWithKernelHandler<
2295-
KernelType, TransformedArgType>()> * = nullptr>
2269+
template <typename TransformedArgType, int Dims, typename KernelType>
22962270
auto getRangeRoundedKernelLambda(KernelType KernelFunc,
22972271
range<Dims> NumWorkItems) {
2298-
return [=](TransformedArgType Arg, kernel_handler KH) {
2299-
if (Arg[0] >= NumWorkItems[0])
2300-
return;
2301-
Arg.set_allowed_range(NumWorkItems);
2302-
KernelFunc(Arg, KH);
2303-
};
2304-
}
2305-
2306-
template <typename TransformedArgType, int Dims, typename KernelType,
2307-
typename std::enable_if_t<
2308-
!detail::isKernelLambdaCallableWithKernelHandler<
2309-
KernelType, TransformedArgType>()> * = nullptr>
2310-
auto getRangeRoundedKernelLambda(KernelType KernelFunc,
2311-
range<Dims> NumWorkItems) {
2312-
return [=](TransformedArgType Arg) {
2313-
if (Arg[0] >= NumWorkItems[0])
2314-
return;
2315-
Arg.set_allowed_range(NumWorkItems);
2316-
KernelFunc(Arg);
2317-
};
2272+
if constexpr (detail::isKernelLambdaCallableWithKernelHandler<
2273+
KernelType, TransformedArgType>()) {
2274+
return [=](TransformedArgType Arg, kernel_handler KH) {
2275+
if (Arg[0] >= NumWorkItems[0])
2276+
return;
2277+
Arg.set_allowed_range(NumWorkItems);
2278+
KernelFunc(Arg, KH);
2279+
};
2280+
} else {
2281+
return [=](TransformedArgType Arg) {
2282+
if (Arg[0] >= NumWorkItems[0])
2283+
return;
2284+
Arg.set_allowed_range(NumWorkItems);
2285+
KernelFunc(Arg);
2286+
};
2287+
}
23182288
}
23192289
};
23202290
} // namespace sycl

0 commit comments

Comments
 (0)