@@ -35,29 +35,41 @@ template <> struct group_scope<::cl::sycl::ONEAPI::sub_group> {
   static constexpr __spv::Scope::Flag value = __spv::Scope::Flag::Subgroup;
 };
 
-// Generic shuffles and broadcasts may require multiple calls to SPIR-V
+// Generic shuffles and broadcasts may require multiple calls to
 // intrinsics, and should use the fewest broadcasts possible
-//  - Loop over 64-bit chunks until remaining bytes < 64-bit
+//  - Loop over chunks until remaining bytes < chunk size
 //  - At most one 32-bit, 16-bit and 8-bit chunk left over
+#ifndef __NVPTX__
+using ShuffleChunkT = uint64_t;
+#else
+using ShuffleChunkT = uint32_t;
+#endif
 template <typename T, typename Functor>
 void GenericCall(const Functor &ApplyToBytes) {
-  if (sizeof(T) >= sizeof(uint64_t)) {
+  if (sizeof(T) >= sizeof(ShuffleChunkT)) {
 #pragma unroll
-    for (size_t Offset = 0; Offset < sizeof(T); Offset += sizeof(uint64_t)) {
-      ApplyToBytes(Offset, sizeof(uint64_t));
+    for (size_t Offset = 0; Offset < sizeof(T);
+         Offset += sizeof(ShuffleChunkT)) {
+      ApplyToBytes(Offset, sizeof(ShuffleChunkT));
     }
   }
-  if (sizeof(T) % sizeof(uint64_t) >= sizeof(uint32_t)) {
-    size_t Offset = sizeof(T) / sizeof(uint64_t) * sizeof(uint64_t);
-    ApplyToBytes(Offset, sizeof(uint32_t));
+  if (sizeof(ShuffleChunkT) >= sizeof(uint64_t)) {
+    if (sizeof(T) % sizeof(uint64_t) >= sizeof(uint32_t)) {
+      size_t Offset = sizeof(T) / sizeof(uint64_t) * sizeof(uint64_t);
+      ApplyToBytes(Offset, sizeof(uint32_t));
+    }
   }
-  if (sizeof(T) % sizeof(uint32_t) >= sizeof(uint16_t)) {
-    size_t Offset = sizeof(T) / sizeof(uint32_t) * sizeof(uint32_t);
-    ApplyToBytes(Offset, sizeof(uint16_t));
+  if (sizeof(ShuffleChunkT) >= sizeof(uint32_t)) {
+    if (sizeof(T) % sizeof(uint32_t) >= sizeof(uint16_t)) {
+      size_t Offset = sizeof(T) / sizeof(uint32_t) * sizeof(uint32_t);
+      ApplyToBytes(Offset, sizeof(uint16_t));
+    }
   }
-  if (sizeof(T) % sizeof(uint16_t) >= sizeof(uint8_t)) {
-    size_t Offset = sizeof(T) / sizeof(uint16_t) * sizeof(uint16_t);
-    ApplyToBytes(Offset, sizeof(uint8_t));
+  if (sizeof(ShuffleChunkT) >= sizeof(uint16_t)) {
+    if (sizeof(T) % sizeof(uint16_t) >= sizeof(uint8_t)) {
+      size_t Offset = sizeof(T) / sizeof(uint16_t) * sizeof(uint16_t);
+      ApplyToBytes(Offset, sizeof(uint8_t));
+    }
   }
 }
 
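Note on the chunk decomposition above: following the description in the comment, a type whose size is not a multiple of the chunk size is covered by whole chunks plus at most one 4-, 2- and 1-byte tail. A standalone sketch, illustration only and not part of the patch (ChunkT and DescribeChunks are hypothetical names, with ChunkT standing in for ShuffleChunkT on SPIR-V targets), printing the (Offset, Size) pairs ApplyToBytes would receive:

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <initializer_list>

using ChunkT = uint64_t; // stand-in for ShuffleChunkT on SPIR-V targets

void DescribeChunks(size_t TotalBytes) {
  size_t Offset = 0;
  // Whole chunks first, stopping once the remaining bytes < chunk size.
  for (; Offset + sizeof(ChunkT) <= TotalBytes; Offset += sizeof(ChunkT))
    std::printf("apply(Offset=%zu, Size=%zu)\n", Offset, sizeof(ChunkT));
  // Then at most one 4-, 2- and 1-byte tail chunk.
  for (size_t Size : {sizeof(uint32_t), sizeof(uint16_t), sizeof(uint8_t)}) {
    if (TotalBytes - Offset >= Size) {
      std::printf("apply(Offset=%zu, Size=%zu)\n", Offset, Size);
      Offset += Size;
    }
  }
}

int main() {
  DescribeChunks(15); // -> (0,8), (8,4), (12,2), (14,1)
}

On NVPTX the same walk happens in 32-bit steps, so a 15-byte type becomes three 4-byte chunks plus one 2-byte and one 1-byte tail.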
@@ -423,48 +435,134 @@ AtomicMax(multi_ptr<T, AddressSpace> MPtr, ONEAPI::memory_scope Scope,
   return __spirv_AtomicMax(Ptr, SPIRVScope, SPIRVOrder, Value);
 }
 
-// Native shuffles map directly to a SPIR-V SubgroupShuffle intrinsic
+// Native shuffles map directly to a shuffle intrinsic:
+// - The Intel SPIR-V extension natively supports all arithmetic types
+// - The CUDA shfl intrinsics do not support vectors, and we use the _i32
+//   variants for all scalar types
+#ifndef __NVPTX__
 template <typename T>
 using EnableIfNativeShuffle =
     detail::enable_if_t<detail::is_arithmetic<T>::value, T>;
+#else
+template <typename T>
+using EnableIfNativeShuffle = detail::enable_if_t<
+    std::is_integral<T>::value && (sizeof(T) <= sizeof(int32_t)), T>;
+
+template <typename T>
+using EnableIfVectorShuffle =
+    detail::enable_if_t<detail::is_vector_arithmetic<T>::value, T>;
+#endif
+
+#ifdef __NVPTX__
+inline uint32_t membermask() {
+  uint32_t FULL_MASK = 0xFFFFFFFF;
+  uint32_t max_size = __spirv_SubgroupMaxSize();
+  uint32_t sg_size = __spirv_SubgroupSize();
+  return FULL_MASK >> (max_size - sg_size);
+}
+#endif
 
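A quick check of the mask arithmetic in membermask() above, with hypothetical sizes (the real values come from __spirv_SubgroupMaxSize() and __spirv_SubgroupSize()):

// Standalone sketch of the mask computation; sample sizes are assumptions.
#include <cassert>
#include <cstdint>

int main() {
  uint32_t max_size = 32; // assumed maximum sub-group size
  uint32_t sg_size = 16;  // assumed current sub-group size
  uint32_t mask = 0xFFFFFFFFu >> (max_size - sg_size);
  assert(mask == 0x0000FFFFu); // one bit set per participating lane
}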
 template <typename T>
 EnableIfNativeShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
+#ifndef __NVPTX__
   using OCLT = detail::ConvertToOpenCLType_t<T>;
   return __spirv_SubgroupShuffleINTEL(OCLT(x),
                                       static_cast<uint32_t>(local_id.get(0)));
+#else
+  return __nvvm_shfl_sync_idx_i32(membermask(), x, local_id.get(0), 0x1f);
+#endif
 }
 
 template <typename T>
 EnableIfNativeShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
+#ifndef __NVPTX__
   using OCLT = detail::ConvertToOpenCLType_t<T>;
   return __spirv_SubgroupShuffleXorINTEL(
       OCLT(x), static_cast<uint32_t>(local_id.get(0)));
+#else
+  return __nvvm_shfl_sync_bfly_i32(membermask(), x, local_id.get(0), 0x1f);
+#endif
 }
 
 template <typename T>
 EnableIfNativeShuffle<T> SubgroupShuffleDown(T x, id<1> local_id) {
+#ifndef __NVPTX__
   using OCLT = detail::ConvertToOpenCLType_t<T>;
   return __spirv_SubgroupShuffleDownINTEL(
       OCLT(x), OCLT(x), static_cast<uint32_t>(local_id.get(0)));
+#else
+  return __nvvm_shfl_sync_down_i32(membermask(), x, local_id.get(0), 0x1f);
+#endif
 }
 
 template <typename T>
 EnableIfNativeShuffle<T> SubgroupShuffleUp(T x, id<1> local_id) {
+#ifndef __NVPTX__
   using OCLT = detail::ConvertToOpenCLType_t<T>;
   return __spirv_SubgroupShuffleUpINTEL(OCLT(x), OCLT(x),
                                         static_cast<uint32_t>(local_id.get(0)));
+#else
+  return __nvvm_shfl_sync_up_i32(membermask(), x, local_id.get(0), 0);
+#endif
 }
 
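For orientation, the four flavours differ only in how each lane's source lane is computed: idx broadcasts from an absolute lane, bfly pairs lanes whose indices differ by an XOR mask, and down/up shift by a delta (on CUDA an out-of-range source leaves the lane's own value in place, which is why the idx/bfly/down calls pass a 0x1f clamp operand and the up call passes 0). A standalone sketch of the index arithmetic only, with hypothetical size/delta/mask/idx values and no actual data movement:

#include <cstdio>

int main() {
  const unsigned size = 8, delta = 2, mask = 1, idx = 5; // assumed values
  for (unsigned lane = 0; lane < size; ++lane) {
    unsigned src_idx = idx;         // SubgroupShuffle: all lanes read lane idx
    unsigned src_xor = lane ^ mask; // SubgroupShuffleXor: butterfly partner
    // SubgroupShuffleDown/Up: out-of-range lanes keep their own value
    unsigned src_down = (lane + delta < size) ? lane + delta : lane;
    unsigned src_up = (lane >= delta) ? lane - delta : lane;
    std::printf("lane %u reads idx:%u xor:%u down:%u up:%u\n", lane, src_idx,
                src_xor, src_down, src_up);
  }
}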
-// Bitcast shuffles can be implemented using a single SPIR-V SubgroupShuffle
+#ifdef __NVPTX__
+template <typename T>
+EnableIfVectorShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
+  T result;
+  for (int s = 0; s < x.get_size(); ++s) {
+    result[s] = SubgroupShuffle(x[s], local_id);
+  }
+  return result;
+}
+
+template <typename T>
+EnableIfVectorShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
+  T result;
+  for (int s = 0; s < x.get_size(); ++s) {
+    result[s] = SubgroupShuffleXor(x[s], local_id);
+  }
+  return result;
+}
+
+template <typename T>
+EnableIfVectorShuffle<T> SubgroupShuffleDown(T x, id<1> local_id) {
+  T result;
+  for (int s = 0; s < x.get_size(); ++s) {
+    result[s] = SubgroupShuffleDown(x[s], local_id);
+  }
+  return result;
+}
+
+template <typename T>
+EnableIfVectorShuffle<T> SubgroupShuffleUp(T x, id<1> local_id) {
+  T result;
+  for (int s = 0; s < x.get_size(); ++s) {
+    result[s] = SubgroupShuffleUp(x[s], local_id);
+  }
+  return result;
+}
+#endif
+
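The vector overloads above rely on enable_if overload resolution to recurse into the scalar (native) path element by element. A minimal self-contained sketch of that dispatch pattern; the toy vec, is_vec and Shuffle names are hypothetical stand-ins, not the patch's types:

#include <cstdio>
#include <type_traits>

template <typename T, int N> struct vec { T s[N]; }; // toy stand-in for sycl::vec

template <typename T> struct is_vec : std::false_type {};
template <typename T, int N> struct is_vec<vec<T, N>> : std::true_type {};

// Scalar ("native") overload: enabled only for arithmetic element types.
template <typename T>
std::enable_if_t<std::is_arithmetic<T>::value, T> Shuffle(T x) {
  std::printf("native shuffle of %zu bytes\n", sizeof(T));
  return x; // stand-in for the real intrinsic
}

// Vector overload: recurses into the scalar overload per element.
template <typename T>
std::enable_if_t<is_vec<T>::value, T> Shuffle(T x) {
  for (auto &e : x.s)
    e = Shuffle(e);
  return x;
}

int main() {
  Shuffle(42);              // scalar path
  Shuffle(vec<float, 4>{}); // vector path -> four scalar calls
}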
+// Bitcast shuffles can be implemented using a single SubgroupShuffle
 // intrinsic, but require type-punning via an appropriate integer type
+#ifndef __NVPTX__
 template <typename T>
 using EnableIfBitcastShuffle =
     detail::enable_if_t<!detail::is_arithmetic<T>::value &&
                             (std::is_trivially_copyable<T>::value &&
                              (sizeof(T) == 1 || sizeof(T) == 2 ||
                               sizeof(T) == 4 || sizeof(T) == 8)),
                         T>;
+#else
+template <typename T>
+using EnableIfBitcastShuffle = detail::enable_if_t<
+    !(std::is_integral<T>::value && (sizeof(T) <= sizeof(int32_t))) &&
+    !detail::is_vector_arithmetic<T>::value &&
+        (std::is_trivially_copyable<T>::value &&
+         (sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4)),
+    T>;
+#endif
 
 template <typename T>
 using ConvertToNativeShuffleType_t = select_cl_scalar_integral_unsigned_t<T>;
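The type-punning in miniature: a trivially copyable 4-byte struct round-trips through the unsigned integer type a native shuffle can carry. A hypothetical host-side sketch (RGBA is an invented payload; bit_cast here is a memcpy-based stand-in for detail::bit_cast):

#include <cstdint>
#include <cstring>

struct RGBA { uint8_t r, g, b, a; }; // hypothetical 4-byte payload

// memcpy-based stand-in for detail::bit_cast (pre-C++20 friendly).
template <typename To, typename From> To bit_cast(const From &from) {
  static_assert(sizeof(To) == sizeof(From), "bit_cast requires equal sizes");
  To to;
  std::memcpy(&to, &from, sizeof(To));
  return to;
}

int main() {
  RGBA x{1, 2, 3, 4};
  uint32_t wire = bit_cast<uint32_t>(x); // what the shuffle actually exchanges
  RGBA back = bit_cast<RGBA>(wire);      // receiving lane recovers the struct
  return (back.r == 1 && back.a == 4) ? 0 : 1;
}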
@@ -473,57 +571,87 @@ template <typename T>
 EnableIfBitcastShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
   using ShuffleT = ConvertToNativeShuffleType_t<T>;
   auto ShuffleX = detail::bit_cast<ShuffleT>(x);
+#ifndef __NVPTX__
   ShuffleT Result = __spirv_SubgroupShuffleINTEL(
       ShuffleX, static_cast<uint32_t>(local_id.get(0)));
+#else
+  ShuffleT Result =
+      __nvvm_shfl_sync_idx_i32(membermask(), ShuffleX, local_id.get(0), 0x1f);
+#endif
   return detail::bit_cast<T>(Result);
 }
 
 template <typename T>
 EnableIfBitcastShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
   using ShuffleT = ConvertToNativeShuffleType_t<T>;
   auto ShuffleX = detail::bit_cast<ShuffleT>(x);
+#ifndef __NVPTX__
   ShuffleT Result = __spirv_SubgroupShuffleXorINTEL(
       ShuffleX, static_cast<uint32_t>(local_id.get(0)));
+#else
+  ShuffleT Result =
+      __nvvm_shfl_sync_bfly_i32(membermask(), ShuffleX, local_id.get(0), 0x1f);
+#endif
   return detail::bit_cast<T>(Result);
 }
 
 template <typename T>
 EnableIfBitcastShuffle<T> SubgroupShuffleDown(T x, id<1> local_id) {
   using ShuffleT = ConvertToNativeShuffleType_t<T>;
   auto ShuffleX = detail::bit_cast<ShuffleT>(x);
+#ifndef __NVPTX__
   ShuffleT Result = __spirv_SubgroupShuffleDownINTEL(
       ShuffleX, ShuffleX, static_cast<uint32_t>(local_id.get(0)));
+#else
+  ShuffleT Result =
+      __nvvm_shfl_sync_down_i32(membermask(), ShuffleX, local_id.get(0), 0x1f);
+#endif
   return detail::bit_cast<T>(Result);
 }
 
 template <typename T>
 EnableIfBitcastShuffle<T> SubgroupShuffleUp(T x, id<1> local_id) {
   using ShuffleT = ConvertToNativeShuffleType_t<T>;
   auto ShuffleX = detail::bit_cast<ShuffleT>(x);
+#ifndef __NVPTX__
   ShuffleT Result = __spirv_SubgroupShuffleUpINTEL(
       ShuffleX, ShuffleX, static_cast<uint32_t>(local_id.get(0)));
+#else
+  ShuffleT Result =
+      __nvvm_shfl_sync_up_i32(membermask(), ShuffleX, local_id.get(0), 0);
+#endif
   return detail::bit_cast<T>(Result);
 }
 
-// Generic shuffles may require multiple calls to SPIR-V SubgroupShuffle
+// Generic shuffles may require multiple calls to SubgroupShuffle
 // intrinsics, and should use the fewest shuffles possible:
 //  - Loop over 64-bit chunks until remaining bytes < 64-bit
 //  - At most one 32-bit, 16-bit and 8-bit chunk left over
+#ifndef __NVPTX__
 template <typename T>
 using EnableIfGenericShuffle =
     detail::enable_if_t<!detail::is_arithmetic<T>::value &&
                             !(std::is_trivially_copyable<T>::value &&
                               (sizeof(T) == 1 || sizeof(T) == 2 ||
                                sizeof(T) == 4 || sizeof(T) == 8)),
                         T>;
+#else
+template <typename T>
+using EnableIfGenericShuffle = detail::enable_if_t<
+    !(std::is_integral<T>::value && (sizeof(T) <= sizeof(int32_t))) &&
+    !detail::is_vector_arithmetic<T>::value &&
+    !(std::is_trivially_copyable<T>::value &&
+      (sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4)),
+    T>;
+#endif
 
 template <typename T>
 EnableIfGenericShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
   T Result;
   char *XBytes = reinterpret_cast<char *>(&x);
   char *ResultBytes = reinterpret_cast<char *>(&Result);
   auto ShuffleBytes = [=](size_t Offset, size_t Size) {
-    uint64_t ShuffleX, ShuffleResult;
+    ShuffleChunkT ShuffleX, ShuffleResult;
     detail::memcpy(&ShuffleX, XBytes + Offset, Size);
     ShuffleResult = SubgroupShuffle(ShuffleX, local_id);
     detail::memcpy(ResultBytes + Offset, &ShuffleResult, Size);
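Taken together, the generic path packs each chunk into an integer, shuffles it with the native overload, and unpacks it at the same offset in the result. A standalone sketch of that round trip for a 6-byte struct (Pixel and FakeShuffle are hypothetical; the real code calls SubgroupShuffle across lanes and uses detail::memcpy):

#include <cstddef>
#include <cstdint>
#include <cstring>

struct Pixel { uint16_t x, y, z; }; // hypothetical 6-byte payload

uint32_t FakeShuffle(uint32_t v) { return v; } // stand-in for the intrinsic

int main() {
  Pixel In{1, 2, 3}, Out{};
  char *XBytes = reinterpret_cast<char *>(&In);
  char *ResultBytes = reinterpret_cast<char *>(&Out);
  auto ShuffleBytes = [=](size_t Offset, size_t Size) {
    uint32_t ShuffleX = 0, ShuffleResult;
    std::memcpy(&ShuffleX, XBytes + Offset, Size);           // pack the chunk
    ShuffleResult = FakeShuffle(ShuffleX);                   // exchange lanes
    std::memcpy(ResultBytes + Offset, &ShuffleResult, Size); // unpack it
  };
  ShuffleBytes(0, 4); // one 32-bit chunk (the NVPTX ShuffleChunkT)
  ShuffleBytes(4, 2); // one 16-bit tail
  return (Out.x == 1 && Out.y == 2 && Out.z == 3) ? 0 : 1;
}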
@@ -538,7 +666,7 @@ EnableIfGenericShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
   char *XBytes = reinterpret_cast<char *>(&x);
   char *ResultBytes = reinterpret_cast<char *>(&Result);
   auto ShuffleBytes = [=](size_t Offset, size_t Size) {
-    uint64_t ShuffleX, ShuffleResult;
+    ShuffleChunkT ShuffleX, ShuffleResult;
     detail::memcpy(&ShuffleX, XBytes + Offset, Size);
     ShuffleResult = SubgroupShuffleXor(ShuffleX, local_id);
     detail::memcpy(ResultBytes + Offset, &ShuffleResult, Size);
@@ -553,7 +681,7 @@ EnableIfGenericShuffle<T> SubgroupShuffleDown(T x, id<1> local_id) {
   char *XBytes = reinterpret_cast<char *>(&x);
   char *ResultBytes = reinterpret_cast<char *>(&Result);
   auto ShuffleBytes = [=](size_t Offset, size_t Size) {
-    uint64_t ShuffleX, ShuffleResult;
+    ShuffleChunkT ShuffleX, ShuffleResult;
     detail::memcpy(&ShuffleX, XBytes + Offset, Size);
     ShuffleResult = SubgroupShuffleDown(ShuffleX, local_id);
     detail::memcpy(ResultBytes + Offset, &ShuffleResult, Size);
@@ -568,7 +696,7 @@ EnableIfGenericShuffle<T> SubgroupShuffleUp(T x, id<1> local_id) {
   char *XBytes = reinterpret_cast<char *>(&x);
   char *ResultBytes = reinterpret_cast<char *>(&Result);
   auto ShuffleBytes = [=](size_t Offset, size_t Size) {
-    uint64_t ShuffleX, ShuffleResult;
+    ShuffleChunkT ShuffleX, ShuffleResult;
     detail::memcpy(&ShuffleX, XBytes + Offset, Size);
     ShuffleResult = SubgroupShuffleUp(ShuffleX, local_id);
     detail::memcpy(ResultBytes + Offset, &ShuffleResult, Size);