|
13 | 13 | #include <cstring>
|
14 | 14 | #include <sycl/detail/generic_type_traits.hpp>
|
15 | 15 | #include <sycl/detail/helpers.hpp>
|
| 16 | +#include <sycl/detail/type_list.hpp> |
16 | 17 | #include <sycl/detail/type_traits.hpp>
|
17 | 18 | #include <sycl/ext/oneapi/experimental/non_uniform_groups.hpp>
|
18 | 19 | #include <sycl/id.hpp>
|
@@ -696,14 +697,19 @@ AtomicMax(multi_ptr<T, AddressSpace, IsDecorated> MPtr, memory_scope Scope,
|
696 | 697 | // - The Intel SPIR-V extension natively supports all arithmetic types.
|
697 | 698 | // However, OpenCL extension natively supports float vectors,
|
698 | 699 | // integer vectors, half scalar and double scalar.
|
699 |
| -// For double vectors we perform emulation with scalar version. |
| 700 | +// For double, long, long long, unsigned long, unsigned long long |
| 701 | +// and half vectors we perform emulation with scalar version. |
700 | 702 | // - The CUDA shfl intrinsics do not support vectors, and we use the _i32
|
701 | 703 | // variants for all scalar types
|
702 | 704 | #ifndef __NVPTX__
|
703 | 705 |
|
| 706 | +using ProhibitedTypesForShuffleEmulation = |
| 707 | + type_list<double, long, long long, unsigned long, unsigned long long, half>; |
| 708 | + |
704 | 709 | template <typename T>
|
705 | 710 | struct TypeIsProhibitedForShuffleEmulation
|
706 |
| - : std::bool_constant<std::is_same_v<vector_element_t<T>, double>> {}; |
| 711 | + : std::bool_constant<is_contained< |
| 712 | + vector_element_t<T>, ProhibitedTypesForShuffleEmulation>::value> {}; |
707 | 713 |
|
708 | 714 | template <typename T>
|
709 | 715 | struct VecTypeIsProhibitedForShuffleEmulation
|
@@ -790,6 +796,12 @@ EnableIfBitcastShuffle<T> SubgroupShuffle(T x, id<1> local_id);
|
790 | 796 | template <typename T>
|
791 | 797 | EnableIfBitcastShuffle<T> SubgroupShuffleXor(T x, id<1> local_id);
|
792 | 798 |
|
| 799 | +template <typename T> |
| 800 | +EnableIfBitcastShuffle<T> SubgroupShuffleDown(T x, uint32_t delta); |
| 801 | + |
| 802 | +template <typename T> |
| 803 | +EnableIfBitcastShuffle<T> SubgroupShuffleUp(T x, uint32_t delta); |
| 804 | + |
793 | 805 | template <typename T>
|
794 | 806 | EnableIfGenericShuffle<T> SubgroupShuffle(T x, id<1> local_id);
|
795 | 807 |
|
|
0 commit comments