Skip to content

Commit 8382e58

Browse files
authored
[SYCL] Add forward declarations for Subgroup Shuffle functions (#8450)
The main reason for the change is being able to call Generic overloading of Shuffle from Vector overloading of Shuffle.
1 parent 7c9bd09 commit 8382e58

File tree

1 file changed

+68
-42
lines changed

1 file changed

+68
-42
lines changed

sycl/include/sycl/detail/spirv.hpp

Lines changed: 68 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,49 @@ using EnableIfVectorShuffle =
537537
std::enable_if_t<detail::is_vector_arithmetic<T>::value, T>;
538538
#endif // ifndef __NVPTX__
539539

540+
// Bitcast shuffles can be implemented using a single SubgroupShuffle
541+
// intrinsic, but require type-punning via an appropriate integer type
542+
#ifndef __NVPTX__
543+
template <typename T>
544+
using EnableIfBitcastShuffle =
545+
std::enable_if_t<!detail::is_arithmetic<T>::value &&
546+
(std::is_trivially_copyable_v<T> &&
547+
(sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4 ||
548+
sizeof(T) == 8)),
549+
T>;
550+
#else
551+
template <typename T>
552+
using EnableIfBitcastShuffle =
553+
std::enable_if_t<!(std::is_integral_v<T> &&
554+
(sizeof(T) <= sizeof(int32_t))) &&
555+
!detail::is_vector_arithmetic<T>::value &&
556+
(std::is_trivially_copyable_v<T> &&
557+
(sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4)),
558+
T>;
559+
#endif // ifndef __NVPTX__
560+
561+
// Generic shuffles may require multiple calls to SubgroupShuffle
562+
// intrinsics, and should use the fewest shuffles possible:
563+
// - Loop over 64-bit chunks until remaining bytes < 64-bit
564+
// - At most one 32-bit, 16-bit and 8-bit chunk left over
565+
#ifndef __NVPTX__
566+
template <typename T>
567+
using EnableIfGenericShuffle =
568+
std::enable_if_t<!detail::is_arithmetic<T>::value &&
569+
!(std::is_trivially_copyable_v<T> &&
570+
(sizeof(T) == 1 || sizeof(T) == 2 ||
571+
sizeof(T) == 4 || sizeof(T) == 8)),
572+
T>;
573+
#else
574+
template <typename T>
575+
using EnableIfGenericShuffle = std::enable_if_t<
576+
!(std::is_integral<T>::value && (sizeof(T) <= sizeof(int32_t))) &&
577+
!detail::is_vector_arithmetic<T>::value &&
578+
!(std::is_trivially_copyable_v<T> &&
579+
(sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4)),
580+
T>;
581+
#endif
582+
540583
#ifdef __NVPTX__
541584
inline uint32_t membermask() {
542585
// use a full mask as sync operations are required to be convergent and exited
@@ -545,6 +588,31 @@ inline uint32_t membermask() {
545588
}
546589
#endif
547590

591+
// Forward declarations for template overloadings
592+
template <typename T>
593+
EnableIfBitcastShuffle<T> SubgroupShuffle(T x, id<1> local_id);
594+
595+
template <typename T>
596+
EnableIfBitcastShuffle<T> SubgroupShuffleXor(T x, id<1> local_id);
597+
598+
template <typename T>
599+
EnableIfBitcastShuffle<T> SubgroupShuffleDown(T x, id<1> local_id);
600+
601+
template <typename T>
602+
EnableIfBitcastShuffle<T> SubgroupShuffleUp(T x, id<1> local_id);
603+
604+
template <typename T>
605+
EnableIfGenericShuffle<T> SubgroupShuffle(T x, id<1> local_id);
606+
607+
template <typename T>
608+
EnableIfGenericShuffle<T> SubgroupShuffleXor(T x, id<1> local_id);
609+
610+
template <typename T>
611+
EnableIfGenericShuffle<T> SubgroupShuffleDown(T x, id<1> local_id);
612+
613+
template <typename T>
614+
EnableIfGenericShuffle<T> SubgroupShuffleUp(T x, id<1> local_id);
615+
548616
template <typename T>
549617
EnableIfNativeShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
550618
#ifndef __NVPTX__
@@ -623,26 +691,6 @@ EnableIfVectorShuffle<T> SubgroupShuffleUp(T x, uint32_t delta) {
623691
return result;
624692
}
625693

626-
// Bitcast shuffles can be implemented using a single SubgroupShuffle
627-
// intrinsic, but require type-punning via an appropriate integer type
628-
#ifndef __NVPTX__
629-
template <typename T>
630-
using EnableIfBitcastShuffle =
631-
detail::enable_if_t<!detail::is_arithmetic<T>::value &&
632-
(std::is_trivially_copyable<T>::value &&
633-
(sizeof(T) == 1 || sizeof(T) == 2 ||
634-
sizeof(T) == 4 || sizeof(T) == 8)),
635-
T>;
636-
#else
637-
template <typename T>
638-
using EnableIfBitcastShuffle = detail::enable_if_t<
639-
!(std::is_integral<T>::value && (sizeof(T) <= sizeof(int32_t))) &&
640-
!detail::is_vector_arithmetic<T>::value &&
641-
(std::is_trivially_copyable<T>::value &&
642-
(sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4)),
643-
T>;
644-
#endif
645-
646694
template <typename T>
647695
using ConvertToNativeShuffleType_t = select_cl_scalar_integral_unsigned_t<T>;
648696

@@ -699,28 +747,6 @@ EnableIfBitcastShuffle<T> SubgroupShuffleUp(T x, uint32_t delta) {
699747
return bit_cast<T>(Result);
700748
}
701749

702-
// Generic shuffles may require multiple calls to SubgroupShuffle
703-
// intrinsics, and should use the fewest shuffles possible:
704-
// - Loop over 64-bit chunks until remaining bytes < 64-bit
705-
// - At most one 32-bit, 16-bit and 8-bit chunk left over
706-
#ifndef __NVPTX__
707-
template <typename T>
708-
using EnableIfGenericShuffle =
709-
detail::enable_if_t<!detail::is_arithmetic<T>::value &&
710-
!(std::is_trivially_copyable<T>::value &&
711-
(sizeof(T) == 1 || sizeof(T) == 2 ||
712-
sizeof(T) == 4 || sizeof(T) == 8)),
713-
T>;
714-
#else
715-
template <typename T>
716-
using EnableIfGenericShuffle = detail::enable_if_t<
717-
!(std::is_integral<T>::value && (sizeof(T) <= sizeof(int32_t))) &&
718-
!detail::is_vector_arithmetic<T>::value &&
719-
!(std::is_trivially_copyable<T>::value &&
720-
(sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4)),
721-
T>;
722-
#endif
723-
724750
template <typename T>
725751
EnableIfGenericShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
726752
T Result;

0 commit comments

Comments
 (0)