Skip to content

Commit 3205368

Browse files
[SYCL] group algorithm routines with broadened supported types (#4910)
Signed-off-by: Chris Perkins <[email protected]>
1 parent 43f2921 commit 3205368

File tree

1 file changed

+24
-7
lines changed

1 file changed

+24
-7
lines changed

sycl/include/CL/sycl/group_algorithm.hpp

Lines changed: 24 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -386,9 +386,12 @@ joint_none_of(Group g, Ptr first, Ptr last, Predicate pred) {
386386
}
387387

388388
// ---- shift_group_left
389+
// TODO: remove check for detail::is_vec<T> once sycl::vec is trivially
390+
// copyable.
389391
template <typename Group, typename T>
390392
detail::enable_if_t<(std::is_same<std::decay_t<Group>, sub_group>::value &&
391-
detail::is_arithmetic<T>::value),
393+
(std::is_trivially_copyable<T>::value ||
394+
detail::is_vec<T>::value)),
392395
T>
393396
shift_group_left(Group, T x, typename Group::linear_id_type delta = 1) {
394397
#ifdef __SYCL_DEVICE_ONLY__
@@ -402,9 +405,12 @@ shift_group_left(Group, T x, typename Group::linear_id_type delta = 1) {
402405
}
403406

404407
// ---- shift_group_right
408+
// TODO: remove check for detail::is_vec<T> once sycl::vec is trivially
409+
// copyable.
405410
template <typename Group, typename T>
406411
detail::enable_if_t<(std::is_same<std::decay_t<Group>, sub_group>::value &&
407-
detail::is_arithmetic<T>::value),
412+
(std::is_trivially_copyable<T>::value ||
413+
detail::is_vec<T>::value)),
408414
T>
409415
shift_group_right(Group, T x, typename Group::linear_id_type delta = 1) {
410416
#ifdef __SYCL_DEVICE_ONLY__
@@ -418,9 +424,12 @@ shift_group_right(Group, T x, typename Group::linear_id_type delta = 1) {
418424
}
419425

420426
// ---- permute_group_by_xor
427+
// TODO: remove check for detail::is_vec<T> once sycl::vec is trivially
428+
// copyable.
421429
template <typename Group, typename T>
422430
detail::enable_if_t<(std::is_same<std::decay_t<Group>, sub_group>::value &&
423-
detail::is_arithmetic<T>::value),
431+
(std::is_trivially_copyable<T>::value ||
432+
detail::is_vec<T>::value)),
424433
T>
425434
permute_group_by_xor(Group, T x, typename Group::linear_id_type mask) {
426435
#ifdef __SYCL_DEVICE_ONLY__
@@ -434,9 +443,12 @@ permute_group_by_xor(Group, T x, typename Group::linear_id_type mask) {
434443
}
435444

436445
// ---- select_from_group
446+
// TODO: remove check for detail::is_vec<T> once sycl::vec is trivially
447+
// copyable.
437448
template <typename Group, typename T>
438449
detail::enable_if_t<(std::is_same<std::decay_t<Group>, sub_group>::value &&
439-
detail::is_arithmetic<T>::value),
450+
(std::is_trivially_copyable<T>::value ||
451+
detail::is_vec<T>::value)),
440452
T>
441453
select_from_group(Group, T x, typename Group::id_type local_id) {
442454
#ifdef __SYCL_DEVICE_ONLY__
@@ -450,9 +462,12 @@ select_from_group(Group, T x, typename Group::id_type local_id) {
450462
}
451463

452464
// ---- group_broadcast
465+
// TODO: remove check for detail::is_vec<T> once sycl::vec is trivially
466+
// copyable.
453467
template <typename Group, typename T>
454468
detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
455-
detail::is_scalar_arithmetic<T>::value),
469+
(std::is_trivially_copyable<T>::value ||
470+
detail::is_vec<T>::value)),
456471
T>
457472
group_broadcast(Group, T x, typename Group::id_type local_id) {
458473
#ifdef __SYCL_DEVICE_ONLY__
@@ -467,7 +482,8 @@ group_broadcast(Group, T x, typename Group::id_type local_id) {
467482

468483
template <typename Group, typename T>
469484
detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
470-
detail::is_scalar_arithmetic<T>::value),
485+
(std::is_trivially_copyable<T>::value ||
486+
detail::is_vec<T>::value)),
471487
T>
472488
group_broadcast(Group g, T x, typename Group::linear_id_type linear_local_id) {
473489
#ifdef __SYCL_DEVICE_ONLY__
@@ -485,7 +501,8 @@ group_broadcast(Group g, T x, typename Group::linear_id_type linear_local_id) {
485501

486502
template <typename Group, typename T>
487503
detail::enable_if_t<(is_group_v<std::decay_t<Group>> &&
488-
detail::is_scalar_arithmetic<T>::value),
504+
(std::is_trivially_copyable<T>::value ||
505+
detail::is_vec<T>::value)),
489506
T>
490507
group_broadcast(Group g, T x) {
491508
#ifdef __SYCL_DEVICE_ONLY__

0 commit comments

Comments (0)