@@ -537,6 +537,49 @@ using EnableIfVectorShuffle =
537
537
std::enable_if_t <detail::is_vector_arithmetic<T>::value, T>;
538
538
#endif // ifndef __NVPTX__
539
539
540
+ // Bitcast shuffles can be implemented using a single SubgroupShuffle
541
+ // intrinsic, but require type-punning via an appropriate integer type
542
+ #ifndef __NVPTX__
543
+ template <typename T>
544
+ using EnableIfBitcastShuffle =
545
+ std::enable_if_t <!detail::is_arithmetic<T>::value &&
546
+ (std::is_trivially_copyable_v<T> &&
547
+ (sizeof (T) == 1 || sizeof (T) == 2 || sizeof (T) == 4 ||
548
+ sizeof (T) == 8 )),
549
+ T>;
550
+ #else
551
+ template <typename T>
552
+ using EnableIfBitcastShuffle =
553
+ std::enable_if_t <!(std::is_integral_v<T> &&
554
+ (sizeof (T) <= sizeof (int32_t ))) &&
555
+ !detail::is_vector_arithmetic<T>::value &&
556
+ (std::is_trivially_copyable_v<T> &&
557
+ (sizeof (T) == 1 || sizeof (T) == 2 || sizeof (T) == 4 )),
558
+ T>;
559
+ #endif // ifndef __NVPTX__
560
+
561
+ // Generic shuffles may require multiple calls to SubgroupShuffle
562
+ // intrinsics, and should use the fewest shuffles possible:
563
+ // - Loop over 64-bit chunks until remaining bytes < 64-bit
564
+ // - At most one 32-bit, 16-bit and 8-bit chunk left over
565
+ #ifndef __NVPTX__
566
+ template <typename T>
567
+ using EnableIfGenericShuffle =
568
+ std::enable_if_t <!detail::is_arithmetic<T>::value &&
569
+ !(std::is_trivially_copyable_v<T> &&
570
+ (sizeof (T) == 1 || sizeof (T) == 2 ||
571
+ sizeof (T) == 4 || sizeof (T) == 8 )),
572
+ T>;
573
+ #else
574
+ template <typename T>
575
+ using EnableIfGenericShuffle = std::enable_if_t <
576
+ !(std::is_integral<T>::value && (sizeof (T) <= sizeof (int32_t ))) &&
577
+ !detail::is_vector_arithmetic<T>::value &&
578
+ !(std::is_trivially_copyable_v<T> &&
579
+ (sizeof (T) == 1 || sizeof (T) == 2 || sizeof (T) == 4 )),
580
+ T>;
581
+ #endif
582
+
540
583
#ifdef __NVPTX__
541
584
inline uint32_t membermask () {
542
585
// use a full mask as sync operations are required to be convergent and exited
@@ -545,6 +588,31 @@ inline uint32_t membermask() {
545
588
}
546
589
#endif
547
590
591
+ // Forward declarations for template overloadings
592
+ template <typename T>
593
+ EnableIfBitcastShuffle<T> SubgroupShuffle (T x, id<1 > local_id);
594
+
595
+ template <typename T>
596
+ EnableIfBitcastShuffle<T> SubgroupShuffleXor (T x, id<1 > local_id);
597
+
598
+ template <typename T>
599
+ EnableIfBitcastShuffle<T> SubgroupShuffleDown (T x, id<1 > local_id);
600
+
601
+ template <typename T>
602
+ EnableIfBitcastShuffle<T> SubgroupShuffleUp (T x, id<1 > local_id);
603
+
604
+ template <typename T>
605
+ EnableIfGenericShuffle<T> SubgroupShuffle (T x, id<1 > local_id);
606
+
607
+ template <typename T>
608
+ EnableIfGenericShuffle<T> SubgroupShuffleXor (T x, id<1 > local_id);
609
+
610
+ template <typename T>
611
+ EnableIfGenericShuffle<T> SubgroupShuffleDown (T x, id<1 > local_id);
612
+
613
+ template <typename T>
614
+ EnableIfGenericShuffle<T> SubgroupShuffleUp (T x, id<1 > local_id);
615
+
548
616
template <typename T>
549
617
EnableIfNativeShuffle<T> SubgroupShuffle (T x, id<1 > local_id) {
550
618
#ifndef __NVPTX__
@@ -623,26 +691,6 @@ EnableIfVectorShuffle<T> SubgroupShuffleUp(T x, uint32_t delta) {
623
691
return result;
624
692
}
625
693
626
- // Bitcast shuffles can be implemented using a single SubgroupShuffle
627
- // intrinsic, but require type-punning via an appropriate integer type
628
- #ifndef __NVPTX__
629
- template <typename T>
630
- using EnableIfBitcastShuffle =
631
- detail::enable_if_t <!detail::is_arithmetic<T>::value &&
632
- (std::is_trivially_copyable<T>::value &&
633
- (sizeof (T) == 1 || sizeof (T) == 2 ||
634
- sizeof (T) == 4 || sizeof (T) == 8 )),
635
- T>;
636
- #else
637
- template <typename T>
638
- using EnableIfBitcastShuffle = detail::enable_if_t <
639
- !(std::is_integral<T>::value && (sizeof (T) <= sizeof (int32_t ))) &&
640
- !detail::is_vector_arithmetic<T>::value &&
641
- (std::is_trivially_copyable<T>::value &&
642
- (sizeof (T) == 1 || sizeof (T) == 2 || sizeof (T) == 4 )),
643
- T>;
644
- #endif
645
-
646
694
template <typename T>
647
695
using ConvertToNativeShuffleType_t = select_cl_scalar_integral_unsigned_t <T>;
648
696
@@ -699,28 +747,6 @@ EnableIfBitcastShuffle<T> SubgroupShuffleUp(T x, uint32_t delta) {
699
747
return bit_cast<T>(Result);
700
748
}
701
749
702
- // Generic shuffles may require multiple calls to SubgroupShuffle
703
- // intrinsics, and should use the fewest shuffles possible:
704
- // - Loop over 64-bit chunks until remaining bytes < 64-bit
705
- // - At most one 32-bit, 16-bit and 8-bit chunk left over
706
- #ifndef __NVPTX__
707
- template <typename T>
708
- using EnableIfGenericShuffle =
709
- detail::enable_if_t <!detail::is_arithmetic<T>::value &&
710
- !(std::is_trivially_copyable<T>::value &&
711
- (sizeof (T) == 1 || sizeof (T) == 2 ||
712
- sizeof (T) == 4 || sizeof (T) == 8 )),
713
- T>;
714
- #else
715
- template <typename T>
716
- using EnableIfGenericShuffle = detail::enable_if_t <
717
- !(std::is_integral<T>::value && (sizeof (T) <= sizeof (int32_t ))) &&
718
- !detail::is_vector_arithmetic<T>::value &&
719
- !(std::is_trivially_copyable<T>::value &&
720
- (sizeof (T) == 1 || sizeof (T) == 2 || sizeof (T) == 4 )),
721
- T>;
722
- #endif
723
-
724
750
template <typename T>
725
751
EnableIfGenericShuffle<T> SubgroupShuffle (T x, id<1 > local_id) {
726
752
T Result;
0 commit comments