@@ -498,22 +498,44 @@ AtomicMax(multi_ptr<T, AddressSpace, IsDecorated> MPtr, memory_scope Scope,
498
498
}
499
499
500
500
// Native shuffles map directly to a shuffle intrinsic:
501
- // - The Intel SPIR-V extension natively supports all arithmetic types
501
+ // - The Intel SPIR-V extension natively supports all arithmetic types.
502
+ // However, OpenCL extension natively supports float vectors,
503
+ // integer vectors, half scalar and double scalar.
504
+ // For double vectors we perform emulation with scalar version.
502
505
// - The CUDA shfl intrinsics do not support vectors, and we use the _i32
503
506
// variants for all scalar types
504
507
#ifndef __NVPTX__
508
+
509
+ template <typename T>
510
+ struct TypeIsProhibitedForShuffleEmulation
511
+ : bool_constant<std::is_same_v<vector_element_t <T>, double >> {};
512
+
513
+ template <typename T>
514
+ struct VecTypeIsProhibitedForShuffleEmulation
515
+ : bool_constant<
516
+ (detail::get_vec_size<T>::size > 1 ) &&
517
+ TypeIsProhibitedForShuffleEmulation<vector_element_t <T>>::value> {};
518
+
505
519
template <typename T>
506
520
using EnableIfNativeShuffle =
507
- detail::enable_if_t <detail::is_arithmetic<T>::value, T>;
508
- #else
521
+ std::enable_if_t <detail::is_arithmetic<T>::value &&
522
+ !VecTypeIsProhibitedForShuffleEmulation<T>::value,
523
+ T>;
524
+
509
525
template <typename T>
510
- using EnableIfNativeShuffle = detail::enable_if_t <
526
+ using EnableIfVectorShuffle =
527
+ std::enable_if_t <VecTypeIsProhibitedForShuffleEmulation<T>::value, T>;
528
+
529
+ #else // ifndef __NVPTX__
530
+
531
+ template <typename T>
532
+ using EnableIfNativeShuffle = std::enable_if_t <
511
533
std::is_integral<T>::value && (sizeof (T) <= sizeof (int32_t )), T>;
512
534
513
535
template <typename T>
514
536
using EnableIfVectorShuffle =
515
- detail ::enable_if_t <detail::is_vector_arithmetic<T>::value, T>;
516
- #endif
537
+ std ::enable_if_t <detail::is_vector_arithmetic<T>::value, T>;
538
+ #endif // ifndef __NVPTX__
517
539
518
540
#ifdef __NVPTX__
519
541
inline uint32_t membermask () {
@@ -565,7 +587,6 @@ EnableIfNativeShuffle<T> SubgroupShuffleUp(T x, uint32_t delta) {
565
587
#endif
566
588
}
567
589
568
- #ifdef __NVPTX__
569
590
template <typename T>
570
591
EnableIfVectorShuffle<T> SubgroupShuffle (T x, id<1 > local_id) {
571
592
T result;
@@ -601,7 +622,6 @@ EnableIfVectorShuffle<T> SubgroupShuffleUp(T x, uint32_t delta) {
601
622
}
602
623
return result;
603
624
}
604
- #endif
605
625
606
626
// Bitcast shuffles can be implemented using a single SubgroupShuffle
607
627
// intrinsic, but require type-punning via an appropriate integer type
0 commit comments