@@ -33,6 +33,32 @@ template <> struct group_scope<::cl::sycl::intel::sub_group> {
   static constexpr __spv::Scope::Flag value = __spv::Scope::Flag::Subgroup;
 };
 
+// Generic shuffles and broadcasts may require multiple calls to SPIR-V
+// intrinsics, and should use the fewest broadcasts possible
+// - Loop over 64-bit chunks until remaining bytes < 64-bit
+// - At most one 32-bit, 16-bit and 8-bit chunk left over
+template <typename T, typename Functor>
+void GenericCall(const Functor &ApplyToBytes) {
+  if (sizeof(T) >= sizeof(uint64_t)) {
+#pragma unroll
+    for (size_t Offset = 0; Offset < sizeof(T); Offset += sizeof(uint64_t)) {
+      ApplyToBytes(Offset, sizeof(uint64_t));
+    }
+  }
+  if (sizeof(T) % sizeof(uint64_t) >= sizeof(uint32_t)) {
+    size_t Offset = sizeof(T) / sizeof(uint64_t) * sizeof(uint64_t);
+    ApplyToBytes(Offset, sizeof(uint32_t));
+  }
+  if (sizeof(T) % sizeof(uint32_t) >= sizeof(uint16_t)) {
+    size_t Offset = sizeof(T) / sizeof(uint32_t) * sizeof(uint32_t);
+    ApplyToBytes(Offset, sizeof(uint16_t));
+  }
+  if (sizeof(T) % sizeof(uint16_t) >= sizeof(uint8_t)) {
+    size_t Offset = sizeof(T) / sizeof(uint16_t) * sizeof(uint16_t);
+    ApplyToBytes(Offset, sizeof(uint8_t));
+  }
+}
+
 template <typename Group> bool GroupAll(bool pred) {
   return __spirv_GroupAll(group_scope<Group>::value, pred);
 }
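
To make the chunking above concrete, here is a small host-side trace harness; it mirrors the GenericCall logic on a hypothetical 7-byte payload (GenericCallTrace and Payload7 are illustrative names, not part of the patch), showing that the tail is covered by one 32-bit, one 16-bit and one 8-bit call:

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Host-side re-statement of the chunking logic in GenericCall, for tracing only.
template <typename T, typename Functor>
void GenericCallTrace(const Functor &ApplyToBytes) {
  if (sizeof(T) >= sizeof(std::uint64_t)) {
    for (std::size_t Offset = 0; Offset < sizeof(T); Offset += sizeof(std::uint64_t))
      ApplyToBytes(Offset, sizeof(std::uint64_t));
  }
  if (sizeof(T) % sizeof(std::uint64_t) >= sizeof(std::uint32_t))
    ApplyToBytes(sizeof(T) / sizeof(std::uint64_t) * sizeof(std::uint64_t),
                 sizeof(std::uint32_t));
  if (sizeof(T) % sizeof(std::uint32_t) >= sizeof(std::uint16_t))
    ApplyToBytes(sizeof(T) / sizeof(std::uint32_t) * sizeof(std::uint32_t),
                 sizeof(std::uint16_t));
  if (sizeof(T) % sizeof(std::uint16_t) >= sizeof(std::uint8_t))
    ApplyToBytes(sizeof(T) / sizeof(std::uint16_t) * sizeof(std::uint16_t),
                 sizeof(std::uint8_t));
}

struct Payload7 { char Bytes[7]; }; // hypothetical 7-byte trivially copyable type

int main() {
  GenericCallTrace<Payload7>([](std::size_t Offset, std::size_t Size) {
    std::printf("chunk at offset %zu, size %zu\n", Offset, Size);
  });
  // Prints (0,4), (4,2), (6,1): the 64-bit loop is skipped, and the remainder
  // is covered by one 32-bit, one 16-bit and one 8-bit chunk.
  return 0;
}

For sizes that are exact multiples of 8 bytes, only the 64-bit loop fires (e.g. a 16-byte type becomes two 64-bit chunks).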
@@ -41,47 +67,137 @@ template <typename Group> bool GroupAny(bool pred) {
   return __spirv_GroupAny(group_scope<Group>::value, pred);
 }
 
+// Native broadcasts map directly to a SPIR-V GroupBroadcast intrinsic
+template <typename T>
+using is_native_broadcast = bool_constant<detail::is_arithmetic<T>::value>;
+
+template <typename T, typename IdT = size_t>
+using EnableIfNativeBroadcast = detail::enable_if_t<
+    is_native_broadcast<T>::value && std::is_integral<IdT>::value, T>;
+
+// Bitcast broadcasts can be implemented using a single SPIR-V GroupBroadcast
+// intrinsic, but require type-punning via an appropriate integer type
+template <typename T>
+using is_bitcast_broadcast = bool_constant<
+    !is_native_broadcast<T>::value && std::is_trivially_copyable<T>::value &&
+    (sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4 || sizeof(T) == 8)>;
+
+template <typename T, typename IdT = size_t>
+using EnableIfBitcastBroadcast = detail::enable_if_t<
+    is_bitcast_broadcast<T>::value && std::is_integral<IdT>::value, T>;
+
+template <typename T>
+using ConvertToNativeBroadcastType_t = select_cl_scalar_integral_unsigned_t<T>;
+
+// Generic broadcasts may require multiple calls to SPIR-V GroupBroadcast
+// intrinsics, and should use the fewest broadcasts possible
+// - Loop over 64-bit chunks until remaining bytes < 64-bit
+// - At most one 32-bit, 16-bit and 8-bit chunk left over
+template <typename T>
+using is_generic_broadcast =
+    bool_constant<!is_native_broadcast<T>::value &&
+                  !is_bitcast_broadcast<T>::value &&
+                  std::is_trivially_copyable<T>::value>;
+
+template <typename T, typename IdT = size_t>
+using EnableIfGenericBroadcast = detail::enable_if_t<
+    is_generic_broadcast<T>::value && std::is_integral<IdT>::value, T>;
+
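
As a rough, standalone sketch of how these traits partition types, the snippet below uses plain std:: traits in place of the detail:: helpers (the real is_native_broadcast uses detail::is_arithmetic, which may accept additional types such as half; TwoFloats and ThreeFloats are hypothetical examples):

#include <type_traits>

template <typename T>
using is_native_like = std::is_arithmetic<T>;

template <typename T>
using is_bitcast_like =
    std::integral_constant<bool, !is_native_like<T>::value &&
                                     std::is_trivially_copyable<T>::value &&
                                     (sizeof(T) == 1 || sizeof(T) == 2 ||
                                      sizeof(T) == 4 || sizeof(T) == 8)>;

template <typename T>
using is_generic_like =
    std::integral_constant<bool, !is_native_like<T>::value &&
                                     !is_bitcast_like<T>::value &&
                                     std::is_trivially_copyable<T>::value>;

struct TwoFloats { float A, B; };      // 8 bytes: one broadcast via type-punning
struct ThreeFloats { float A, B, C; }; // 12 bytes: chunked generic broadcast

static_assert(is_native_like<int>::value, "arithmetic types take the native path");
static_assert(is_bitcast_like<TwoFloats>::value, "power-of-two PODs are bitcast");
static_assert(is_generic_like<ThreeFloats>::value, "other trivially copyable types are chunked");

int main() { return 0; }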
 // Broadcast with scalar local index
 // Work-group supports any integral type
 // Sub-group currently supports only uint32_t
+template <typename Group> struct GroupId { using type = size_t; };
+template <> struct GroupId<::cl::sycl::intel::sub_group> {
+  using type = uint32_t;
+};
 template <typename Group, typename T, typename IdT>
-detail::enable_if_t<is_group<Group>::value && std::is_integral<IdT>::value, T>
-GroupBroadcast(T x, IdT local_id) {
+EnableIfNativeBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
+  using GroupIdT = typename GroupId<Group>::type;
+  GroupIdT GroupLocalId = static_cast<GroupIdT>(local_id);
   using OCLT = detail::ConvertToOpenCLType_t<T>;
-  using OCLIdT = detail::ConvertToOpenCLType_t<IdT>;
-  OCLT ocl_x = detail::convertDataToType<T, OCLT>(x);
-  OCLIdT ocl_id = detail::convertDataToType<IdT, OCLIdT>(local_id);
-  return __spirv_GroupBroadcast(group_scope<Group>::value, ocl_x, ocl_id);
+  using OCLIdT = detail::ConvertToOpenCLType_t<GroupIdT>;
+  OCLT OCLX = detail::convertDataToType<T, OCLT>(x);
+  OCLIdT OCLId = detail::convertDataToType<GroupIdT, OCLIdT>(GroupLocalId);
+  return __spirv_GroupBroadcast(group_scope<Group>::value, OCLX, OCLId);
 }
 template <typename Group, typename T, typename IdT>
-detail::enable_if_t<is_sub_group<Group>::value && std::is_integral<IdT>::value,
-                    T>
-GroupBroadcast(T x, IdT local_id) {
-  using SGIdT = uint32_t;
-  SGIdT sg_local_id = static_cast<SGIdT>(local_id);
-  using OCLT = detail::ConvertToOpenCLType_t<T>;
-  using OCLIdT = detail::ConvertToOpenCLType_t<SGIdT>;
-  OCLT ocl_x = detail::convertDataToType<T, OCLT>(x);
-  OCLIdT ocl_id = detail::convertDataToType<SGIdT, OCLIdT>(sg_local_id);
-  return __spirv_GroupBroadcast(group_scope<Group>::value, ocl_x, ocl_id);
+EnableIfBitcastBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
+  using GroupIdT = typename GroupId<Group>::type;
+  GroupIdT GroupLocalId = static_cast<GroupIdT>(local_id);
+  using BroadcastT = ConvertToNativeBroadcastType_t<T>;
+  using OCLIdT = detail::ConvertToOpenCLType_t<GroupIdT>;
+  auto BroadcastX = detail::bit_cast<BroadcastT>(x);
+  OCLIdT OCLId = detail::convertDataToType<GroupIdT, OCLIdT>(GroupLocalId);
+  BroadcastT Result =
+      __spirv_GroupBroadcast(group_scope<Group>::value, BroadcastX, OCLId);
+  return detail::bit_cast<T>(Result);
+}
+template <typename Group, typename T, typename IdT>
+EnableIfGenericBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
+  T Result;
+  char *XBytes = reinterpret_cast<char *>(&x);
+  char *ResultBytes = reinterpret_cast<char *>(&Result);
+  auto BroadcastBytes = [=](size_t Offset, size_t Size) {
+    uint64_t BroadcastX, BroadcastResult;
+    detail::memcpy(&BroadcastX, XBytes + Offset, Size);
+    BroadcastResult = GroupBroadcast<Group>(BroadcastX, local_id);
+    detail::memcpy(ResultBytes + Offset, &BroadcastResult, Size);
+  };
+  GenericCall<T>(BroadcastBytes);
+  return Result;
 }
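
The byte plumbing used by the generic overload can be checked on the host with the stand-in below, where the per-chunk __spirv_GroupBroadcast is replaced by an identity function; this is a sketch under that assumption (FakeChunkBroadcast, Vec3 and the hand-picked chunk offsets are illustrative, not part of the patch), showing that memcpy-ing chunks through a uint64_t reassembles the original value:

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

// Stand-in for broadcasting one chunk; on the device this is a native
// __spirv_GroupBroadcast of a 64-bit integer from the chosen work-item.
static std::uint64_t FakeChunkBroadcast(std::uint64_t Chunk) { return Chunk; }

struct Vec3 { float X, Y, Z; }; // hypothetical 12-byte payload

int main() {
  Vec3 In{1.0f, 2.0f, 3.0f};
  Vec3 Out{};
  char *XBytes = reinterpret_cast<char *>(&In);
  char *ResultBytes = reinterpret_cast<char *>(&Out);
  // Same pattern as the BroadcastBytes lambda above, applied here to two
  // hand-picked chunks that cover the 12-byte value:
  auto BroadcastBytes = [&](std::size_t Offset, std::size_t Size) {
    std::uint64_t BroadcastX = 0, BroadcastResult = 0;
    std::memcpy(&BroadcastX, XBytes + Offset, Size);
    BroadcastResult = FakeChunkBroadcast(BroadcastX);
    std::memcpy(ResultBytes + Offset, &BroadcastResult, Size);
  };
  BroadcastBytes(0, sizeof(std::uint64_t)); // bytes 0-7
  BroadcastBytes(8, sizeof(std::uint32_t)); // bytes 8-11
  assert(Out.X == 1.0f && Out.Y == 2.0f && Out.Z == 3.0f);
  return 0;
}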
 
 // Broadcast with vector local index
 template <typename Group, typename T, int Dimensions>
-T GroupBroadcast(T x, id<Dimensions> local_id) {
+EnableIfNativeBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
   if (Dimensions == 1) {
     return GroupBroadcast<Group>(x, local_id[0]);
   }
   using IdT = vec<size_t, Dimensions>;
   using OCLT = detail::ConvertToOpenCLType_t<T>;
   using OCLIdT = detail::ConvertToOpenCLType_t<IdT>;
-  IdT vec_id;
+  IdT VecId;
+  for (int i = 0; i < Dimensions; ++i) {
+    VecId[i] = local_id[Dimensions - i - 1];
+  }
+  OCLT OCLX = detail::convertDataToType<T, OCLT>(x);
+  OCLIdT OCLId = detail::convertDataToType<IdT, OCLIdT>(VecId);
+  return __spirv_GroupBroadcast(group_scope<Group>::value, OCLX, OCLId);
+}
+template <typename Group, typename T, int Dimensions>
+EnableIfBitcastBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
+  if (Dimensions == 1) {
+    return GroupBroadcast<Group>(x, local_id[0]);
+  }
+  using IdT = vec<size_t, Dimensions>;
+  using BroadcastT = ConvertToNativeBroadcastType_t<T>;
+  using OCLIdT = detail::ConvertToOpenCLType_t<IdT>;
+  IdT VecId;
   for (int i = 0; i < Dimensions; ++i) {
-    vec_id[i] = local_id[Dimensions - i - 1];
+    VecId[i] = local_id[Dimensions - i - 1];
   }
-  OCLT ocl_x = detail::convertDataToType<T, OCLT>(x);
-  OCLIdT ocl_id = detail::convertDataToType<IdT, OCLIdT>(vec_id);
-  return __spirv_GroupBroadcast(group_scope<Group>::value, ocl_x, ocl_id);
+  auto BroadcastX = detail::bit_cast<BroadcastT>(x);
+  OCLIdT OCLId = detail::convertDataToType<IdT, OCLIdT>(VecId);
+  BroadcastT Result =
+      __spirv_GroupBroadcast(group_scope<Group>::value, BroadcastX, OCLId);
+  return detail::bit_cast<T>(Result);
+}
+template <typename Group, typename T, int Dimensions>
+EnableIfGenericBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
+  if (Dimensions == 1) {
+    return GroupBroadcast<Group>(x, local_id[0]);
+  }
+  T Result;
+  char *XBytes = reinterpret_cast<char *>(&x);
+  char *ResultBytes = reinterpret_cast<char *>(&Result);
+  auto BroadcastBytes = [=](size_t Offset, size_t Size) {
+    uint64_t BroadcastX, BroadcastResult;
+    detail::memcpy(&BroadcastX, XBytes + Offset, Size);
+    BroadcastResult = GroupBroadcast<Group>(BroadcastX, local_id);
+    detail::memcpy(ResultBytes + Offset, &BroadcastResult, Size);
+  };
+  GenericCall<T>(BroadcastBytes);
+  return Result;
 }
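
A minimal host-side sketch of the index reversal performed by the vector-index overloads above; the only assumption made here is what the loop itself encodes, namely that VecId holds the components of the SYCL id in reverse order:

#include <array>
#include <cassert>
#include <cstddef>

int main() {
  constexpr int Dimensions = 3;
  // Stand-in for a sycl::id<3>{1, 2, 3}:
  std::array<std::size_t, Dimensions> LocalId = {1, 2, 3};
  std::array<std::size_t, Dimensions> VecId{};
  for (int i = 0; i < Dimensions; ++i) {
    VecId[i] = LocalId[Dimensions - i - 1]; // same reversal as above
  }
  assert(VecId[0] == 3 && VecId[1] == 2 && VecId[2] == 1);
  return 0;
}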
 
 // Single happens-before means semantics should always apply to all spaces
@@ -400,28 +516,6 @@ using EnableIfGenericShuffle =
                           sizeof(T) == 4 || sizeof(T) == 8)),
                        T>;
 
-template <typename T, typename ShuffleFunctor>
-void GenericShuffle(const ShuffleFunctor &ShuffleBytes) {
-  if (sizeof(T) >= sizeof(uint64_t)) {
-#pragma unroll
-    for (size_t Offset = 0; Offset < sizeof(T); Offset += sizeof(uint64_t)) {
-      ShuffleBytes(Offset, sizeof(uint64_t));
-    }
-  }
-  if (sizeof(T) % sizeof(uint64_t) >= sizeof(uint32_t)) {
-    size_t Offset = sizeof(T) / sizeof(uint64_t) * sizeof(uint64_t);
-    ShuffleBytes(Offset, sizeof(uint32_t));
-  }
-  if (sizeof(T) % sizeof(uint32_t) >= sizeof(uint16_t)) {
-    size_t Offset = sizeof(T) / sizeof(uint32_t) * sizeof(uint32_t);
-    ShuffleBytes(Offset, sizeof(uint16_t));
-  }
-  if (sizeof(T) % sizeof(uint16_t) >= sizeof(uint8_t)) {
-    size_t Offset = sizeof(T) / sizeof(uint16_t) * sizeof(uint16_t);
-    ShuffleBytes(Offset, sizeof(uint8_t));
-  }
-}
-
 template <typename T>
 EnableIfGenericShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
   T Result;
@@ -433,7 +527,7 @@ EnableIfGenericShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
     ShuffleResult = SubgroupShuffle(ShuffleX, local_id);
     detail::memcpy(ResultBytes + Offset, &ShuffleResult, Size);
   };
-  GenericShuffle<T>(ShuffleBytes);
+  GenericCall<T>(ShuffleBytes);
   return Result;
 }
 
@@ -448,7 +542,7 @@ EnableIfGenericShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
     ShuffleResult = SubgroupShuffleXor(ShuffleX, local_id);
     detail::memcpy(ResultBytes + Offset, &ShuffleResult, Size);
   };
-  GenericShuffle<T>(ShuffleBytes);
+  GenericCall<T>(ShuffleBytes);
   return Result;
 }
 
@@ -465,7 +559,7 @@ EnableIfGenericShuffle<T> SubgroupShuffleDown(T x, T y, id<1> local_id) {
     ShuffleResult = SubgroupShuffleDown(ShuffleX, ShuffleY, local_id);
     detail::memcpy(ResultBytes + Offset, &ShuffleResult, Size);
   };
-  GenericShuffle<T>(ShuffleBytes);
+  GenericCall<T>(ShuffleBytes);
   return Result;
 }
 
@@ -482,7 +576,7 @@ EnableIfGenericShuffle<T> SubgroupShuffleUp(T x, T y, id<1> local_id) {
     ShuffleResult = SubgroupShuffleUp(ShuffleX, ShuffleY, local_id);
     detail::memcpy(ResultBytes + Offset, &ShuffleResult, Size);
   };
-  GenericShuffle<T>(ShuffleBytes);
+  GenericCall<T>(ShuffleBytes);
   return Result;
 }
 