Skip to content

Commit d3c7b20

Browse files
authored
[SYCL] Extend shuffles to TriviallyCopyable types (#2017)
Any 8-bit, 16-bit, 32-bit or 64-bit type that is TriviallyCopyable can be implemented by existing shuffles and bit_cast. Any TriviallyCopyable type can be shuffled by breaking it into smaller chunks and shuffling each chunk in turn. The current approach divides types into 64-bit chunks, then falls back to 32-, 16- or 8-bit chunks to handle the remainder. Signed-off-by: John Pennycook <[email protected]>
1 parent 011c2ca commit d3c7b20

File tree

4 files changed

+422
-29
lines changed

4 files changed

+422
-29
lines changed

sycl/include/CL/sycl/detail/spirv.hpp

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,188 @@ AtomicMax(multi_ptr<T, AddressSpace> MPtr, intel::memory_scope Scope,
290290
return __spirv_AtomicMax(Ptr, SPIRVScope, SPIRVOrder, Value);
291291
}
292292

293+
// Native shuffles map directly to a SPIR-V SubgroupShuffle intrinsic
294+
template <typename T>
295+
using EnableIfNativeShuffle =
296+
detail::enable_if_t<detail::is_arithmetic<T>::value, T>;
297+
298+
template <typename T>
299+
EnableIfNativeShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
300+
using OCLT = detail::ConvertToOpenCLType_t<T>;
301+
return __spirv_SubgroupShuffleINTEL(OCLT(x),
302+
static_cast<uint32_t>(local_id.get(0)));
303+
}
304+
305+
template <typename T>
306+
EnableIfNativeShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
307+
using OCLT = detail::ConvertToOpenCLType_t<T>;
308+
return __spirv_SubgroupShuffleXorINTEL(
309+
OCLT(x), static_cast<uint32_t>(local_id.get(0)));
310+
}
311+
312+
template <typename T>
313+
EnableIfNativeShuffle<T> SubgroupShuffleDown(T x, T y, id<1> local_id) {
314+
using OCLT = detail::ConvertToOpenCLType_t<T>;
315+
return __spirv_SubgroupShuffleDownINTEL(
316+
OCLT(x), OCLT(y), static_cast<uint32_t>(local_id.get(0)));
317+
}
318+
319+
template <typename T>
320+
EnableIfNativeShuffle<T> SubgroupShuffleUp(T x, T y, id<1> local_id) {
321+
using OCLT = detail::ConvertToOpenCLType_t<T>;
322+
return __spirv_SubgroupShuffleUpINTEL(OCLT(x), OCLT(y),
323+
static_cast<uint32_t>(local_id.get(0)));
324+
}
325+
326+
// Bitcast shuffles can be implemented using a single SPIR-V SubgroupShuffle
327+
// intrinsic, but require type-punning via an appropriate integer type
328+
template <typename T>
329+
using EnableIfBitcastShuffle =
330+
detail::enable_if_t<!detail::is_arithmetic<T>::value &&
331+
(std::is_trivially_copyable<T>::value &&
332+
(sizeof(T) == 1 || sizeof(T) == 2 ||
333+
sizeof(T) == 4 || sizeof(T) == 8)),
334+
T>;
335+
336+
template <typename T>
337+
using ConvertToNativeShuffleType_t = select_cl_scalar_integral_unsigned_t<T>;
338+
339+
template <typename T>
340+
EnableIfBitcastShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
341+
using ShuffleT = ConvertToNativeShuffleType_t<T>;
342+
auto ShuffleX = detail::bit_cast<ShuffleT>(x);
343+
ShuffleT Result = __spirv_SubgroupShuffleINTEL(
344+
ShuffleX, static_cast<uint32_t>(local_id.get(0)));
345+
return detail::bit_cast<T>(Result);
346+
}
347+
348+
template <typename T>
349+
EnableIfBitcastShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
350+
using ShuffleT = ConvertToNativeShuffleType_t<T>;
351+
auto ShuffleX = detail::bit_cast<ShuffleT>(x);
352+
ShuffleT Result = __spirv_SubgroupShuffleXorINTEL(
353+
ShuffleX, static_cast<uint32_t>(local_id.get(0)));
354+
return detail::bit_cast<T>(Result);
355+
}
356+
357+
template <typename T>
358+
EnableIfBitcastShuffle<T> SubgroupShuffleDown(T x, T y, id<1> local_id) {
359+
using ShuffleT = ConvertToNativeShuffleType_t<T>;
360+
auto ShuffleX = detail::bit_cast<ShuffleT>(x);
361+
auto ShuffleY = detail::bit_cast<ShuffleT>(y);
362+
ShuffleT Result = __spirv_SubgroupShuffleDownINTEL(
363+
ShuffleX, ShuffleY, static_cast<uint32_t>(local_id.get(0)));
364+
return detail::bit_cast<T>(Result);
365+
}
366+
367+
template <typename T>
368+
EnableIfBitcastShuffle<T> SubgroupShuffleUp(T x, T y, id<1> local_id) {
369+
using ShuffleT = ConvertToNativeShuffleType_t<T>;
370+
auto ShuffleX = detail::bit_cast<ShuffleT>(x);
371+
auto ShuffleY = detail::bit_cast<ShuffleT>(y);
372+
ShuffleT Result = __spirv_SubgroupShuffleUpINTEL(
373+
ShuffleX, ShuffleY, static_cast<uint32_t>(local_id.get(0)));
374+
return detail::bit_cast<T>(Result);
375+
}
376+
377+
// Generic shuffles may require multiple calls to SPIR-V SubgroupShuffle
378+
// intrinsics, and should use the fewest shuffles possible:
379+
// - Loop over 64-bit chunks until remaining bytes < 64-bit
380+
// - At most one 32-bit, 16-bit and 8-bit chunk left over
381+
template <typename T>
382+
using EnableIfGenericShuffle =
383+
detail::enable_if_t<!detail::is_arithmetic<T>::value &&
384+
!(std::is_trivially_copyable<T>::value &&
385+
(sizeof(T) == 1 || sizeof(T) == 2 ||
386+
sizeof(T) == 4 || sizeof(T) == 8)),
387+
T>;
388+
389+
template <typename T, typename ShuffleFunctor>
390+
void GenericShuffle(const ShuffleFunctor &ShuffleBytes) {
391+
if (sizeof(T) >= sizeof(uint64_t)) {
392+
#pragma unroll
393+
for (size_t Offset = 0; Offset < sizeof(T); Offset += sizeof(uint64_t)) {
394+
ShuffleBytes(Offset, sizeof(uint64_t));
395+
}
396+
}
397+
if (sizeof(T) % sizeof(uint64_t) >= sizeof(uint32_t)) {
398+
size_t Offset = sizeof(T) / sizeof(uint64_t) * sizeof(uint64_t);
399+
ShuffleBytes(Offset, sizeof(uint32_t));
400+
}
401+
if (sizeof(T) % sizeof(uint32_t) >= sizeof(uint16_t)) {
402+
size_t Offset = sizeof(T) / sizeof(uint32_t) * sizeof(uint32_t);
403+
ShuffleBytes(Offset, sizeof(uint16_t));
404+
}
405+
if (sizeof(T) % sizeof(uint16_t) >= sizeof(uint8_t)) {
406+
size_t Offset = sizeof(T) / sizeof(uint16_t) * sizeof(uint16_t);
407+
ShuffleBytes(Offset, sizeof(uint8_t));
408+
}
409+
}
410+
411+
template <typename T>
412+
EnableIfGenericShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
413+
T Result;
414+
char *XBytes = reinterpret_cast<char *>(&x);
415+
char *ResultBytes = reinterpret_cast<char *>(&Result);
416+
auto ShuffleBytes = [=](size_t Offset, size_t Size) {
417+
uint64_t ShuffleX, ShuffleResult;
418+
detail::memcpy(&ShuffleX, XBytes + Offset, Size);
419+
ShuffleResult = SubgroupShuffle(ShuffleX, local_id);
420+
detail::memcpy(ResultBytes + Offset, &ShuffleResult, Size);
421+
};
422+
GenericShuffle<T>(ShuffleBytes);
423+
return Result;
424+
}
425+
426+
template <typename T>
427+
EnableIfGenericShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
428+
T Result;
429+
char *XBytes = reinterpret_cast<char *>(&x);
430+
char *ResultBytes = reinterpret_cast<char *>(&Result);
431+
auto ShuffleBytes = [=](size_t Offset, size_t Size) {
432+
uint64_t ShuffleX, ShuffleResult;
433+
detail::memcpy(&ShuffleX, XBytes + Offset, Size);
434+
ShuffleResult = SubgroupShuffleXor(ShuffleX, local_id);
435+
detail::memcpy(ResultBytes + Offset, &ShuffleResult, Size);
436+
};
437+
GenericShuffle<T>(ShuffleBytes);
438+
return Result;
439+
}
440+
441+
template <typename T>
442+
EnableIfGenericShuffle<T> SubgroupShuffleDown(T x, T y, id<1> local_id) {
443+
T Result;
444+
char *XBytes = reinterpret_cast<char *>(&x);
445+
char *YBytes = reinterpret_cast<char *>(&y);
446+
char *ResultBytes = reinterpret_cast<char *>(&Result);
447+
auto ShuffleBytes = [=](size_t Offset, size_t Size) {
448+
uint64_t ShuffleX, ShuffleY, ShuffleResult;
449+
detail::memcpy(&ShuffleX, XBytes + Offset, Size);
450+
detail::memcpy(&ShuffleY, YBytes + Offset, Size);
451+
ShuffleResult = SubgroupShuffleDown(ShuffleX, ShuffleY, local_id);
452+
detail::memcpy(ResultBytes + Offset, &ShuffleResult, Size);
453+
};
454+
GenericShuffle<T>(ShuffleBytes);
455+
return Result;
456+
}
457+
458+
template <typename T>
459+
EnableIfGenericShuffle<T> SubgroupShuffleUp(T x, T y, id<1> local_id) {
460+
T Result;
461+
char *XBytes = reinterpret_cast<char *>(&x);
462+
char *YBytes = reinterpret_cast<char *>(&y);
463+
char *ResultBytes = reinterpret_cast<char *>(&Result);
464+
auto ShuffleBytes = [=](size_t Offset, size_t Size) {
465+
uint64_t ShuffleX, ShuffleY, ShuffleResult;
466+
detail::memcpy(&ShuffleX, XBytes + Offset, Size);
467+
detail::memcpy(&ShuffleY, YBytes + Offset, Size);
468+
ShuffleResult = SubgroupShuffleUp(ShuffleX, ShuffleY, local_id);
469+
detail::memcpy(ResultBytes + Offset, &ShuffleResult, Size);
470+
};
471+
GenericShuffle<T>(ShuffleBytes);
472+
return Result;
473+
}
474+
293475
} // namespace spirv
294476
} // namespace detail
295477
} // namespace sycl

sycl/include/CL/sycl/intel/sub_group.hpp

Lines changed: 7 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -46,28 +46,6 @@ using AcceptableForLocalLoadStore =
4646
Space == access::address_space::local_space>;
4747

4848
#ifdef __SYCL_DEVICE_ONLY__
49-
#define __SYCL_SG_GENERATE_BODY_1ARG(name, SPIRVOperation) \
50-
template <typename T> T name(T x, id<1> local_id) { \
51-
using OCLT = sycl::detail::ConvertToOpenCLType_t<T>; \
52-
return __spirv_##SPIRVOperation(OCLT(x), local_id.get(0)); \
53-
}
54-
55-
__SYCL_SG_GENERATE_BODY_1ARG(shuffle, SubgroupShuffleINTEL)
56-
__SYCL_SG_GENERATE_BODY_1ARG(shuffle_xor, SubgroupShuffleXorINTEL)
57-
58-
#undef __SYCL_SG_GENERATE_BODY_1ARG
59-
60-
#define __SYCL_SG_GENERATE_BODY_2ARG(name, SPIRVOperation) \
61-
template <typename T> T name(T A, T B, uint32_t Delta) { \
62-
using OCLT = sycl::detail::ConvertToOpenCLType_t<T>; \
63-
return __spirv_##SPIRVOperation(OCLT(A), OCLT(B), Delta); \
64-
}
65-
66-
__SYCL_SG_GENERATE_BODY_2ARG(shuffle_down, SubgroupShuffleDownINTEL)
67-
__SYCL_SG_GENERATE_BODY_2ARG(shuffle_up, SubgroupShuffleUpINTEL)
68-
69-
#undef __SYCL_SG_GENERATE_BODY_2ARG
70-
7149
template <typename T, access::address_space Space>
7250
T load(const multi_ptr<T, Space> src) {
7351
using BlockT = SelectBlockT<T>;
@@ -202,7 +180,7 @@ struct sub_group {
202180

203181
template <typename T> T shuffle(T x, id_type local_id) const {
204182
#ifdef __SYCL_DEVICE_ONLY__
205-
return sycl::detail::sub_group::shuffle(x, local_id);
183+
return sycl::detail::spirv::SubgroupShuffle(x, local_id);
206184
#else
207185
(void)x;
208186
(void)local_id;
@@ -213,7 +191,7 @@ struct sub_group {
213191

214192
template <typename T> T shuffle_down(T x, uint32_t delta) const {
215193
#ifdef __SYCL_DEVICE_ONLY__
216-
return sycl::detail::sub_group::shuffle_down(x, x, delta);
194+
return sycl::detail::spirv::SubgroupShuffleDown(x, x, delta);
217195
#else
218196
(void)x;
219197
(void)delta;
@@ -224,7 +202,7 @@ struct sub_group {
224202

225203
template <typename T> T shuffle_up(T x, uint32_t delta) const {
226204
#ifdef __SYCL_DEVICE_ONLY__
227-
return sycl::detail::sub_group::shuffle_up(x, x, delta);
205+
return sycl::detail::spirv::SubgroupShuffleUp(x, x, delta);
228206
#else
229207
(void)x;
230208
(void)delta;
@@ -235,7 +213,7 @@ struct sub_group {
235213

236214
template <typename T> T shuffle_xor(T x, id_type value) const {
237215
#ifdef __SYCL_DEVICE_ONLY__
238-
return sycl::detail::sub_group::shuffle_xor(x, value);
216+
return sycl::detail::spirv::SubgroupShuffleXor(x, value);
239217
#else
240218
(void)x;
241219
(void)value;
@@ -251,7 +229,7 @@ struct sub_group {
251229
__SYCL_DEPRECATED("Two-input sub-group shuffles are deprecated.")
252230
T shuffle(T x, T y, id_type local_id) const {
253231
#ifdef __SYCL_DEVICE_ONLY__
254-
return sycl::detail::sub_group::shuffle_down(
232+
return sycl::detail::spirv::SubgroupShuffleDown(
255233
x, y, (local_id - get_local_id()).get(0));
256234
#else
257235
(void)x;
@@ -266,7 +244,7 @@ struct sub_group {
266244
__SYCL_DEPRECATED("Two-input sub-group shuffles are deprecated.")
267245
T shuffle_down(T current, T next, uint32_t delta) const {
268246
#ifdef __SYCL_DEVICE_ONLY__
269-
return sycl::detail::sub_group::shuffle_down(current, next, delta);
247+
return sycl::detail::spirv::SubgroupShuffleDown(current, next, delta);
270248
#else
271249
(void)current;
272250
(void)next;
@@ -280,7 +258,7 @@ struct sub_group {
280258
__SYCL_DEPRECATED("Two-input sub-group shuffles are deprecated.")
281259
T shuffle_up(T previous, T current, uint32_t delta) const {
282260
#ifdef __SYCL_DEVICE_ONLY__
283-
return sycl::detail::sub_group::shuffle_up(previous, current, delta);
261+
return sycl::detail::spirv::SubgroupShuffleUp(previous, current, delta);
284262
#else
285263
(void)previous;
286264
(void)current;

0 commit comments

Comments
 (0)