Skip to content

Commit f189e41

Browse files
authored
[SYCL][CUDA] Add sub-group shuffles (#2623)
Sub-group shuffles map to one of the following intrinsics: - __nvvm_shfl_sync_idx_i32 - __nvvm_shfl_sync_up_i32 - __nvvm_shfl_sync_down_i32 - __nvvm_shfl_sync_xor_i32 Implemented in the SYCL headers instead of libclc for two reasons: 1) The SPIR-V implementation uses an extension (__spirv_SubgroupShuffleINTEL) 2) We currently need to use enable_if to generate different instruction sequences for some types, and these cases differ between SPIR-V/PTX. Signed-off-by: John Pennycook <[email protected]>
1 parent 8471e7a commit f189e41

File tree

5 files changed

+155
-41
lines changed

5 files changed

+155
-41
lines changed

sycl/include/CL/sycl/detail/spirv.hpp

Lines changed: 149 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -35,29 +35,41 @@ template <> struct group_scope<::cl::sycl::ONEAPI::sub_group> {
3535
static constexpr __spv::Scope::Flag value = __spv::Scope::Flag::Subgroup;
3636
};
3737

38-
// Generic shuffles and broadcasts may require multiple calls to SPIR-V
38+
// Generic shuffles and broadcasts may require multiple calls to
3939
// intrinsics, and should use the fewest broadcasts possible
40-
// - Loop over 64-bit chunks until remaining bytes < 64-bit
40+
// - Loop over chunks until remaining bytes < chunk size
4141
// - At most one 32-bit, 16-bit and 8-bit chunk left over
42+
#ifndef __NVPTX__
43+
using ShuffleChunkT = uint64_t;
44+
#else
45+
using ShuffleChunkT = uint32_t;
46+
#endif
4247
template <typename T, typename Functor>
4348
void GenericCall(const Functor &ApplyToBytes) {
44-
if (sizeof(T) >= sizeof(uint64_t)) {
49+
if (sizeof(T) >= sizeof(ShuffleChunkT)) {
4550
#pragma unroll
46-
for (size_t Offset = 0; Offset < sizeof(T); Offset += sizeof(uint64_t)) {
47-
ApplyToBytes(Offset, sizeof(uint64_t));
51+
for (size_t Offset = 0; Offset < sizeof(T);
52+
Offset += sizeof(ShuffleChunkT)) {
53+
ApplyToBytes(Offset, sizeof(ShuffleChunkT));
4854
}
4955
}
50-
if (sizeof(T) % sizeof(uint64_t) >= sizeof(uint32_t)) {
51-
size_t Offset = sizeof(T) / sizeof(uint64_t) * sizeof(uint64_t);
52-
ApplyToBytes(Offset, sizeof(uint32_t));
56+
if (sizeof(ShuffleChunkT) >= sizeof(uint64_t)) {
57+
if (sizeof(T) % sizeof(uint64_t) >= sizeof(uint32_t)) {
58+
size_t Offset = sizeof(T) / sizeof(uint64_t) * sizeof(uint64_t);
59+
ApplyToBytes(Offset, sizeof(uint32_t));
60+
}
5361
}
54-
if (sizeof(T) % sizeof(uint32_t) >= sizeof(uint16_t)) {
55-
size_t Offset = sizeof(T) / sizeof(uint32_t) * sizeof(uint32_t);
56-
ApplyToBytes(Offset, sizeof(uint16_t));
62+
if (sizeof(ShuffleChunkT) >= sizeof(uint32_t)) {
63+
if (sizeof(T) % sizeof(uint32_t) >= sizeof(uint16_t)) {
64+
size_t Offset = sizeof(T) / sizeof(uint32_t) * sizeof(uint32_t);
65+
ApplyToBytes(Offset, sizeof(uint16_t));
66+
}
5767
}
58-
if (sizeof(T) % sizeof(uint16_t) >= sizeof(uint8_t)) {
59-
size_t Offset = sizeof(T) / sizeof(uint16_t) * sizeof(uint16_t);
60-
ApplyToBytes(Offset, sizeof(uint8_t));
68+
if (sizeof(ShuffleChunkT) >= sizeof(uint16_t)) {
69+
if (sizeof(T) % sizeof(uint16_t) >= sizeof(uint8_t)) {
70+
size_t Offset = sizeof(T) / sizeof(uint16_t) * sizeof(uint16_t);
71+
ApplyToBytes(Offset, sizeof(uint8_t));
72+
}
6173
}
6274
}
6375

@@ -423,48 +435,134 @@ AtomicMax(multi_ptr<T, AddressSpace> MPtr, ONEAPI::memory_scope Scope,
423435
return __spirv_AtomicMax(Ptr, SPIRVScope, SPIRVOrder, Value);
424436
}
425437

426-
// Native shuffles map directly to a SPIR-V SubgroupShuffle intrinsic
438+
// Native shuffles map directly to a shuffle intrinsic:
439+
// - The Intel SPIR-V extension natively supports all arithmetic types
440+
// - The CUDA shfl intrinsics do not support vectors, and we use the _i32
441+
// variants for all scalar types
442+
#ifndef __NVPTX__
427443
template <typename T>
428444
using EnableIfNativeShuffle =
429445
detail::enable_if_t<detail::is_arithmetic<T>::value, T>;
446+
#else
447+
template <typename T>
448+
using EnableIfNativeShuffle = detail::enable_if_t<
449+
std::is_integral<T>::value && (sizeof(T) <= sizeof(int32_t)), T>;
450+
451+
template <typename T>
452+
using EnableIfVectorShuffle =
453+
detail::enable_if_t<detail::is_vector_arithmetic<T>::value, T>;
454+
#endif
455+
456+
#ifdef __NVPTX__
457+
inline uint32_t membermask() {
458+
uint32_t FULL_MASK = 0xFFFFFFFF;
459+
uint32_t max_size = __spirv_SubgroupMaxSize();
460+
uint32_t sg_size = __spirv_SubgroupSize();
461+
return FULL_MASK >> (max_size - sg_size);
462+
}
463+
#endif
430464

431465
template <typename T>
432466
EnableIfNativeShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
467+
#ifndef __NVPTX__
433468
using OCLT = detail::ConvertToOpenCLType_t<T>;
434469
return __spirv_SubgroupShuffleINTEL(OCLT(x),
435470
static_cast<uint32_t>(local_id.get(0)));
471+
#else
472+
return __nvvm_shfl_sync_idx_i32(membermask(), x, local_id.get(0), 0x1f);
473+
#endif
436474
}
437475

438476
template <typename T>
439477
EnableIfNativeShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
478+
#ifndef __NVPTX__
440479
using OCLT = detail::ConvertToOpenCLType_t<T>;
441480
return __spirv_SubgroupShuffleXorINTEL(
442481
OCLT(x), static_cast<uint32_t>(local_id.get(0)));
482+
#else
483+
return __nvvm_shfl_sync_bfly_i32(membermask(), x, local_id.get(0), 0x1f);
484+
#endif
443485
}
444486

445487
template <typename T>
446488
EnableIfNativeShuffle<T> SubgroupShuffleDown(T x, id<1> local_id) {
489+
#ifndef __NVPTX__
447490
using OCLT = detail::ConvertToOpenCLType_t<T>;
448491
return __spirv_SubgroupShuffleDownINTEL(
449492
OCLT(x), OCLT(x), static_cast<uint32_t>(local_id.get(0)));
493+
#else
494+
return __nvvm_shfl_sync_down_i32(membermask(), x, local_id.get(0), 0x1f);
495+
#endif
450496
}
451497

452498
template <typename T>
453499
EnableIfNativeShuffle<T> SubgroupShuffleUp(T x, id<1> local_id) {
500+
#ifndef __NVPTX__
454501
using OCLT = detail::ConvertToOpenCLType_t<T>;
455502
return __spirv_SubgroupShuffleUpINTEL(OCLT(x), OCLT(x),
456503
static_cast<uint32_t>(local_id.get(0)));
504+
#else
505+
return __nvvm_shfl_sync_up_i32(membermask(), x, local_id.get(0), 0);
506+
#endif
457507
}
458508

459-
// Bitcast shuffles can be implemented using a single SPIR-V SubgroupShuffle
509+
#ifdef __NVPTX__
510+
template <typename T>
511+
EnableIfVectorShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
512+
T result;
513+
for (int s = 0; s < x.get_size(); ++s) {
514+
result[s] = SubgroupShuffle(x[s], local_id);
515+
}
516+
return result;
517+
}
518+
519+
template <typename T>
520+
EnableIfVectorShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
521+
T result;
522+
for (int s = 0; s < x.get_size(); ++s) {
523+
result[s] = SubgroupShuffleXor(x[s], local_id);
524+
}
525+
return result;
526+
}
527+
528+
template <typename T>
529+
EnableIfVectorShuffle<T> SubgroupShuffleDown(T x, id<1> local_id) {
530+
T result;
531+
for (int s = 0; s < x.get_size(); ++s) {
532+
result[s] = SubgroupShuffleDown(x[s], local_id);
533+
}
534+
return result;
535+
}
536+
537+
template <typename T>
538+
EnableIfVectorShuffle<T> SubgroupShuffleUp(T x, id<1> local_id) {
539+
T result;
540+
for (int s = 0; s < x.get_size(); ++s) {
541+
result[s] = SubgroupShuffleUp(x[s], local_id);
542+
}
543+
return result;
544+
}
545+
#endif
546+
547+
// Bitcast shuffles can be implemented using a single SubgroupShuffle
460548
// intrinsic, but require type-punning via an appropriate integer type
549+
#ifndef __NVPTX__
461550
template <typename T>
462551
using EnableIfBitcastShuffle =
463552
detail::enable_if_t<!detail::is_arithmetic<T>::value &&
464553
(std::is_trivially_copyable<T>::value &&
465554
(sizeof(T) == 1 || sizeof(T) == 2 ||
466555
sizeof(T) == 4 || sizeof(T) == 8)),
467556
T>;
557+
#else
558+
template <typename T>
559+
using EnableIfBitcastShuffle = detail::enable_if_t<
560+
!(std::is_integral<T>::value && (sizeof(T) <= sizeof(int32_t))) &&
561+
!detail::is_vector_arithmetic<T>::value &&
562+
(std::is_trivially_copyable<T>::value &&
563+
(sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4)),
564+
T>;
565+
#endif
468566

469567
template <typename T>
470568
using ConvertToNativeShuffleType_t = select_cl_scalar_integral_unsigned_t<T>;
@@ -473,57 +571,87 @@ template <typename T>
473571
EnableIfBitcastShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
474572
using ShuffleT = ConvertToNativeShuffleType_t<T>;
475573
auto ShuffleX = detail::bit_cast<ShuffleT>(x);
574+
#ifndef __NVPTX__
476575
ShuffleT Result = __spirv_SubgroupShuffleINTEL(
477576
ShuffleX, static_cast<uint32_t>(local_id.get(0)));
577+
#else
578+
ShuffleT Result =
579+
__nvvm_shfl_sync_idx_i32(membermask(), ShuffleX, local_id.get(0), 0x1f);
580+
#endif
478581
return detail::bit_cast<T>(Result);
479582
}
480583

481584
template <typename T>
482585
EnableIfBitcastShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
483586
using ShuffleT = ConvertToNativeShuffleType_t<T>;
484587
auto ShuffleX = detail::bit_cast<ShuffleT>(x);
588+
#ifndef __NVPTX__
485589
ShuffleT Result = __spirv_SubgroupShuffleXorINTEL(
486590
ShuffleX, static_cast<uint32_t>(local_id.get(0)));
591+
#else
592+
ShuffleT Result =
593+
__nvvm_shfl_sync_bfly_i32(membermask(), ShuffleX, local_id.get(0), 0x1f);
594+
#endif
487595
return detail::bit_cast<T>(Result);
488596
}
489597

490598
template <typename T>
491599
EnableIfBitcastShuffle<T> SubgroupShuffleDown(T x, id<1> local_id) {
492600
using ShuffleT = ConvertToNativeShuffleType_t<T>;
493601
auto ShuffleX = detail::bit_cast<ShuffleT>(x);
602+
#ifndef __NVPTX__
494603
ShuffleT Result = __spirv_SubgroupShuffleDownINTEL(
495604
ShuffleX, ShuffleX, static_cast<uint32_t>(local_id.get(0)));
605+
#else
606+
ShuffleT Result =
607+
__nvvm_shfl_sync_down_i32(membermask(), ShuffleX, local_id.get(0), 0x1f);
608+
#endif
496609
return detail::bit_cast<T>(Result);
497610
}
498611

499612
template <typename T>
500613
EnableIfBitcastShuffle<T> SubgroupShuffleUp(T x, id<1> local_id) {
501614
using ShuffleT = ConvertToNativeShuffleType_t<T>;
502615
auto ShuffleX = detail::bit_cast<ShuffleT>(x);
616+
#ifndef __NVPTX__
503617
ShuffleT Result = __spirv_SubgroupShuffleUpINTEL(
504618
ShuffleX, ShuffleX, static_cast<uint32_t>(local_id.get(0)));
619+
#else
620+
ShuffleT Result =
621+
__nvvm_shfl_sync_up_i32(membermask(), ShuffleX, local_id.get(0), 0);
622+
#endif
505623
return detail::bit_cast<T>(Result);
506624
}
507625

508-
// Generic shuffles may require multiple calls to SPIR-V SubgroupShuffle
626+
// Generic shuffles may require multiple calls to SubgroupShuffle
509627
// intrinsics, and should use the fewest shuffles possible:
510628
// - Loop over 64-bit chunks until remaining bytes < 64-bit
511629
// - At most one 32-bit, 16-bit and 8-bit chunk left over
630+
#ifndef __NVPTX__
512631
template <typename T>
513632
using EnableIfGenericShuffle =
514633
detail::enable_if_t<!detail::is_arithmetic<T>::value &&
515634
!(std::is_trivially_copyable<T>::value &&
516635
(sizeof(T) == 1 || sizeof(T) == 2 ||
517636
sizeof(T) == 4 || sizeof(T) == 8)),
518637
T>;
638+
#else
639+
template <typename T>
640+
using EnableIfGenericShuffle = detail::enable_if_t<
641+
!(std::is_integral<T>::value && (sizeof(T) <= sizeof(int32_t))) &&
642+
!detail::is_vector_arithmetic<T>::value &&
643+
!(std::is_trivially_copyable<T>::value &&
644+
(sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4)),
645+
T>;
646+
#endif
519647

520648
template <typename T>
521649
EnableIfGenericShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
522650
T Result;
523651
char *XBytes = reinterpret_cast<char *>(&x);
524652
char *ResultBytes = reinterpret_cast<char *>(&Result);
525653
auto ShuffleBytes = [=](size_t Offset, size_t Size) {
526-
uint64_t ShuffleX, ShuffleResult;
654+
ShuffleChunkT ShuffleX, ShuffleResult;
527655
detail::memcpy(&ShuffleX, XBytes + Offset, Size);
528656
ShuffleResult = SubgroupShuffle(ShuffleX, local_id);
529657
detail::memcpy(ResultBytes + Offset, &ShuffleResult, Size);
@@ -538,7 +666,7 @@ EnableIfGenericShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
538666
char *XBytes = reinterpret_cast<char *>(&x);
539667
char *ResultBytes = reinterpret_cast<char *>(&Result);
540668
auto ShuffleBytes = [=](size_t Offset, size_t Size) {
541-
uint64_t ShuffleX, ShuffleResult;
669+
ShuffleChunkT ShuffleX, ShuffleResult;
542670
detail::memcpy(&ShuffleX, XBytes + Offset, Size);
543671
ShuffleResult = SubgroupShuffleXor(ShuffleX, local_id);
544672
detail::memcpy(ResultBytes + Offset, &ShuffleResult, Size);
@@ -553,7 +681,7 @@ EnableIfGenericShuffle<T> SubgroupShuffleDown(T x, id<1> local_id) {
553681
char *XBytes = reinterpret_cast<char *>(&x);
554682
char *ResultBytes = reinterpret_cast<char *>(&Result);
555683
auto ShuffleBytes = [=](size_t Offset, size_t Size) {
556-
uint64_t ShuffleX, ShuffleResult;
684+
ShuffleChunkT ShuffleX, ShuffleResult;
557685
detail::memcpy(&ShuffleX, XBytes + Offset, Size);
558686
ShuffleResult = SubgroupShuffleDown(ShuffleX, local_id);
559687
detail::memcpy(ResultBytes + Offset, &ShuffleResult, Size);
@@ -568,7 +696,7 @@ EnableIfGenericShuffle<T> SubgroupShuffleUp(T x, id<1> local_id) {
568696
char *XBytes = reinterpret_cast<char *>(&x);
569697
char *ResultBytes = reinterpret_cast<char *>(&Result);
570698
auto ShuffleBytes = [=](size_t Offset, size_t Size) {
571-
uint64_t ShuffleX, ShuffleResult;
699+
ShuffleChunkT ShuffleX, ShuffleResult;
572700
detail::memcpy(&ShuffleX, XBytes + Offset, Size);
573701
ShuffleResult = SubgroupShuffleUp(ShuffleX, local_id);
574702
detail::memcpy(ResultBytes + Offset, &ShuffleResult, Size);

sycl/test/sub_group/generic-shuffle.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
// UNSUPPORTED: cuda
2-
// CUDA compilation and runtime do not yet support sub-groups.
3-
//
41
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out
52
// RUN: env SYCL_DEVICE_TYPE=HOST %t.out
63
// RUN: %CPU_RUN_PLACEHOLDER %t.out
@@ -216,7 +213,7 @@ void check_struct(queue &Queue, Generator &Gen, size_t G = 256, size_t L = 64) {
216213

217214
int main() {
218215
queue Queue;
219-
if (!Queue.get_device().has_extension("cl_intel_subgroups")) {
216+
if (Queue.get_device().is_host()) {
220217
std::cout << "Skipping test\n";
221218
return 0;
222219
}

sycl/test/sub_group/shuffle.cpp

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
// UNSUPPORTED: cuda
2-
// CUDA compilation and runtime do not yet support sub-groups.
3-
//
41
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out
52
// RUN: env SYCL_DEVICE_TYPE=HOST %t.out
63
// RUN: %CPU_RUN_PLACEHOLDER %t.out
@@ -19,14 +16,12 @@
1916

2017
int main() {
2118
queue Queue;
22-
if (!Queue.get_device().has_extension("cl_intel_subgroups")) {
19+
if (Queue.get_device().is_host()) {
2320
std::cout << "Skipping test\n";
2421
return 0;
2522
}
26-
if (Queue.get_device().has_extension("cl_intel_subgroups_short")) {
27-
check<short>(Queue);
28-
check<unsigned short>(Queue);
29-
}
23+
check<short>(Queue);
24+
check<unsigned short>(Queue);
3025
check<int>(Queue);
3126
check<int, 2>(Queue);
3227
check<int, 4>(Queue);

sycl/test/sub_group/shuffle_fp16.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
// UNSUPPORTED: cuda
2-
// CUDA compilation and runtime do not yet support sub-groups.
3-
//
41
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out
52
// RUN: %GPU_RUN_PLACEHOLDER %t.out
63
//
@@ -16,7 +13,7 @@
1613

1714
int main() {
1815
queue Queue;
19-
if (!Queue.get_device().has_extension("cl_intel_subgroups")) {
16+
if (Queue.get_device().is_host()) {
2017
std::cout << "Skipping test\n";
2118
return 0;
2219
}

sycl/test/sub_group/shuffle_fp64.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
// UNSUPPORTED: cuda
2-
// CUDA compilation and runtime do not yet support sub-groups.
3-
//
41
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out
52
// RUN: env SYCL_DEVICE_TYPE=HOST %t.out
63
// RUN: %CPU_RUN_PLACEHOLDER %t.out
@@ -19,7 +16,7 @@
1916

2017
int main() {
2118
queue Queue;
22-
if (!Queue.get_device().has_extension("cl_intel_subgroups")) {
19+
if (Queue.get_device().is_host()) {
2320
std::cout << "Skipping test\n";
2421
return 0;
2522
}

0 commit comments

Comments
 (0)