Skip to content

Commit eeba879

Browse files
authored
[SYCL] Perform SubgroupShuffle emulation for vectors of long long and half (#9102)
1 parent 379a094 commit eeba879

File tree

3 files changed

+38
-2
lines changed

3 files changed

+38
-2
lines changed

sycl/include/sycl/detail/spirv.hpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <cstring>
1414
#include <sycl/detail/generic_type_traits.hpp>
1515
#include <sycl/detail/helpers.hpp>
16+
#include <sycl/detail/type_list.hpp>
1617
#include <sycl/detail/type_traits.hpp>
1718
#include <sycl/ext/oneapi/experimental/non_uniform_groups.hpp>
1819
#include <sycl/id.hpp>
@@ -696,14 +697,19 @@ AtomicMax(multi_ptr<T, AddressSpace, IsDecorated> MPtr, memory_scope Scope,
696697
// - The Intel SPIR-V extension natively supports all arithmetic types.
697698
// However, OpenCL extension natively supports float vectors,
698699
// integer vectors, half scalar and double scalar.
699-
// For double vectors we perform emulation with scalar version.
700+
// For double, long, long long, unsigned long, unsigned long long
701+
// and half vectors we perform emulation with scalar version.
700702
// - The CUDA shfl intrinsics do not support vectors, and we use the _i32
701703
// variants for all scalar types
702704
#ifndef __NVPTX__
703705

706+
using ProhibitedTypesForShuffleEmulation =
707+
type_list<double, long, long long, unsigned long, unsigned long long, half>;
708+
704709
template <typename T>
705710
struct TypeIsProhibitedForShuffleEmulation
706-
: std::bool_constant<std::is_same_v<vector_element_t<T>, double>> {};
711+
: std::bool_constant<is_contained<
712+
vector_element_t<T>, ProhibitedTypesForShuffleEmulation>::value> {};
707713

708714
template <typename T>
709715
struct VecTypeIsProhibitedForShuffleEmulation
@@ -790,6 +796,12 @@ EnableIfBitcastShuffle<T> SubgroupShuffle(T x, id<1> local_id);
790796
template <typename T>
791797
EnableIfBitcastShuffle<T> SubgroupShuffleXor(T x, id<1> local_id);
792798

799+
template <typename T>
800+
EnableIfBitcastShuffle<T> SubgroupShuffleDown(T x, uint32_t delta);
801+
802+
template <typename T>
803+
EnableIfBitcastShuffle<T> SubgroupShuffleUp(T x, uint32_t delta);
804+
793805
template <typename T>
794806
EnableIfGenericShuffle<T> SubgroupShuffle(T x, id<1> local_id);
795807

sycl/test-e2e/SubGroup/shuffle.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,28 @@ int main() {
3333
check<unsigned int, 8>(Queue);
3434
check<unsigned int, 16>(Queue);
3535
check<long>(Queue);
36+
check<long, 2>(Queue);
37+
check<long, 4>(Queue);
38+
check<long, 8>(Queue);
39+
check<long, 16>(Queue);
3640
check<unsigned long>(Queue);
41+
check<unsigned long, 2>(Queue);
42+
check<unsigned long, 4>(Queue);
43+
check<unsigned long, 8>(Queue);
44+
check<unsigned long, 16>(Queue);
3745
check<float>(Queue);
46+
check<float, 2>(Queue);
47+
check<float, 4>(Queue);
48+
check<float, 8>(Queue);
49+
check<float, 16>(Queue);
50+
51+
// Check long long and unsigned long long because they differ from
52+
// long and unsigned long according to C++ rules even if they have the same
53+
// size at some system.
54+
check<long long>(Queue);
55+
check<long long, 16>(Queue);
56+
check<unsigned long long>(Queue);
57+
check<unsigned long long, 16>(Queue);
3858
std::cout << "Test passed." << std::endl;
3959
return 0;
4060
}

sycl/test-e2e/SubGroup/shuffle_fp16.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ int main() {
2424
queue Queue;
2525
if (Queue.get_device().has(sycl::aspect::fp16)) {
2626
check<half>(Queue);
27+
check<half, 2>(Queue);
28+
check<half, 4>(Queue);
29+
check<half, 8>(Queue);
30+
check<half, 16>(Queue);
2731
std::cout << "Test passed." << std::endl;
2832
} else {
2933
std::cout << "Test skipped because device doesn't support aspect::fp16"

0 commit comments

Comments
 (0)