Skip to content

Commit 84fe658

Browse files
authored
[SYCL] Add sub-group functions emulation for vector of doubles. (#8252)
intel/llvm-test-suite#1603
1 parent 039b538 commit 84fe658

File tree

2 files changed

+36
-8
lines changed

2 files changed

+36
-8
lines changed

sycl/include/sycl/detail/spirv.hpp

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -498,22 +498,44 @@ AtomicMax(multi_ptr<T, AddressSpace, IsDecorated> MPtr, memory_scope Scope,
498498
}
499499

500500
// Native shuffles map directly to a shuffle intrinsic:
501-
// - The Intel SPIR-V extension natively supports all arithmetic types
501+
// - The Intel SPIR-V extension natively supports all arithmetic types.
502+
// However, OpenCL extension natively supports float vectors,
503+
// integer vectors, half scalar and double scalar.
504+
// For double vectors we perform emulation with scalar version.
502505
// - The CUDA shfl intrinsics do not support vectors, and we use the _i32
503506
// variants for all scalar types
504507
#ifndef __NVPTX__
508+
509+
template <typename T>
510+
struct TypeIsProhibitedForShuffleEmulation
511+
: bool_constant<std::is_same_v<vector_element_t<T>, double>> {};
512+
513+
template <typename T>
514+
struct VecTypeIsProhibitedForShuffleEmulation
515+
: bool_constant<
516+
(detail::get_vec_size<T>::size > 1) &&
517+
TypeIsProhibitedForShuffleEmulation<vector_element_t<T>>::value> {};
518+
505519
template <typename T>
506520
using EnableIfNativeShuffle =
507-
detail::enable_if_t<detail::is_arithmetic<T>::value, T>;
508-
#else
521+
std::enable_if_t<detail::is_arithmetic<T>::value &&
522+
!VecTypeIsProhibitedForShuffleEmulation<T>::value,
523+
T>;
524+
509525
template <typename T>
510-
using EnableIfNativeShuffle = detail::enable_if_t<
526+
using EnableIfVectorShuffle =
527+
std::enable_if_t<VecTypeIsProhibitedForShuffleEmulation<T>::value, T>;
528+
529+
#else // ifndef __NVPTX__
530+
531+
template <typename T>
532+
using EnableIfNativeShuffle = std::enable_if_t<
511533
std::is_integral<T>::value && (sizeof(T) <= sizeof(int32_t)), T>;
512534

513535
template <typename T>
514536
using EnableIfVectorShuffle =
515-
detail::enable_if_t<detail::is_vector_arithmetic<T>::value, T>;
516-
#endif
537+
std::enable_if_t<detail::is_vector_arithmetic<T>::value, T>;
538+
#endif // ifndef __NVPTX__
517539

518540
#ifdef __NVPTX__
519541
inline uint32_t membermask() {
@@ -565,7 +587,6 @@ EnableIfNativeShuffle<T> SubgroupShuffleUp(T x, uint32_t delta) {
565587
#endif
566588
}
567589

568-
#ifdef __NVPTX__
569590
template <typename T>
570591
EnableIfVectorShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
571592
T result;
@@ -601,7 +622,6 @@ EnableIfVectorShuffle<T> SubgroupShuffleUp(T x, uint32_t delta) {
601622
}
602623
return result;
603624
}
604-
#endif
605625

606626
// Bitcast shuffles can be implemented using a single SubgroupShuffle
607627
// intrinsic, but require type-punning via an appropriate integer type

sycl/include/sycl/detail/type_traits.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,14 @@ template <typename> struct is_vec : std::false_type {};
231231
template <typename T, std::size_t N>
232232
struct is_vec<sycl::vec<T, N>> : std::true_type {};
233233

234+
template <typename> struct get_vec_size {
235+
static constexpr std::size_t size = 1;
236+
};
237+
238+
template <typename T, std::size_t N> struct get_vec_size<sycl::vec<T, N>> {
239+
static constexpr std::size_t size = N;
240+
};
241+
234242
// is_integral
235243
template <typename T>
236244
struct is_integral : std::is_integral<vector_element_t<T>> {};

0 commit comments

Comments
 (0)