Skip to content

Commit d99e957

Browse files
authored
[SYCL][AMDGCN] Fix up and down shuffles and reductions (#5359)
This patch fixes the group collective implementation for AMDGCN, which had two main issues. In one place it was calling a regular `shuffle` instead of a `shuffleUp`, which ended up breaking the reduction algorithm. In addition, it was not using the correct interface for the SPIR-V `shuffleUp` function. This leads to the second part of this patch, which fixes the `shuffleUp` and `shuffleDown` functions, mostly for the AMDGCN built-ins but also in the SYCL header, as the SYCL built-ins were not implemented properly on top of the SPIR-V built-ins. At the SYCL level, the `shuffleUp` and `shuffleDown` built-ins take a value to participate in the shuffle and a delta. The delta is used to compute which thread to take the value from during the shuffle operation. For `shuffleUp` it will be subtracted from the thread id, and for `shuffleDown` it will be added. Consequently, in SYCL this delta must be defined such that `subgroup_local_id - delta` falls within `[0, subgroup_local_size[` for `shuffleUp`, and `subgroup_local_id + delta` falls within `[0, subgroup_local_size[` for `shuffleDown`. However, in SPIR-V these built-ins are a bit more complicated: they take two values to participate in the shuffle and support twice the delta range of the SYCL built-ins. For example, for `shuffleUp` the valid range for `subgroup_local_id - delta` is `[-subgroup_local_size, subgroup_local_size[`; if it falls within `[-subgroup_local_size, 0[` the first value will be used to participate in the shuffle, and if it falls within `[0, subgroup_local_size[` the second value will be used to participate in the shuffle. It works in a similar way for `shuffleDown`. As a result, when implementing the SYCL built-ins using the SPIR-V built-ins, only half of the range can be used in a properly defined way, which means only one of the value parameters of the SPIR-V built-ins actually matters. 
Therefore, the SYCL built-ins are implemented by passing the same value to both value parameters of the SPIR-V built-ins. The complete definition of the SPIR-V built-ins can be found here: * https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_subgroups.asciidoc#instructions Using preprocessor defines to figure out the wavefront size in libclc is incorrect, because libclc is not built for a specific amdgcn version, so the size would always default to `64`. Instead, use the `__oclc_wavefrontsize64` global variable provided by ROCm, which will be set to a different value depending on the architecture.
1 parent c6b7b8e commit d99e957

File tree

4 files changed

+106
-84
lines changed

4 files changed

+106
-84
lines changed

libclc/amdgcn-amdhsa/libspirv/group/collectives.cl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ __clc__get_group_scratch_double() __asm("__clc__get_group_scratch_double");
4242
_CLC_DECL TYPE _Z28__spirv_SubgroupShuffleINTELI##TYPE_MANGLED##ET_S0_j( \
4343
TYPE, int); \
4444
_CLC_DECL TYPE \
45-
_Z30__spirv_SubgroupShuffleUpINTELI##TYPE_MANGLED##ET_S0_S0_j(TYPE, \
46-
int);
45+
_Z30__spirv_SubgroupShuffleUpINTELI##TYPE_MANGLED##ET_S0_S0_j( \
46+
TYPE, TYPE, unsigned int);
4747

4848
__CLC_DECLARE_SHUFFLES(char, a);
4949
__CLC_DECLARE_SHUFFLES(unsigned char, h);
@@ -72,7 +72,8 @@ __CLC_DECLARE_SHUFFLES(double, d);
7272
/* Can't use XOR/butterfly shuffles; some lanes may be inactive */ \
7373
for (int o = 1; o < __spirv_SubgroupMaxSize(); o *= 2) { \
7474
TYPE contribution = \
75-
_Z28__spirv_SubgroupShuffleINTELI##TYPE_MANGLED##ET_S0_j(x, o); \
75+
_Z30__spirv_SubgroupShuffleUpINTELI##TYPE_MANGLED##ET_S0_S0_j(x, x, \
76+
o); \
7677
bool inactive = (sg_lid < o); \
7778
contribution = (inactive) ? IDENTITY : contribution; \
7879
x = OP(x, contribution); \
@@ -90,8 +91,8 @@ __CLC_DECLARE_SHUFFLES(double, d);
9091
} /* For ExclusiveScan, shift and prepend identity */ \
9192
else if (op == ExclusiveScan) { \
9293
*carry = x; \
93-
result = \
94-
_Z30__spirv_SubgroupShuffleUpINTELI##TYPE_MANGLED##ET_S0_S0_j(x, 1); \
94+
result = _Z30__spirv_SubgroupShuffleUpINTELI##TYPE_MANGLED##ET_S0_S0_j( \
95+
x, x, 1); \
9596
if (sg_lid == 0) { \
9697
result = IDENTITY; \
9798
} \

libclc/amdgcn-amdhsa/libspirv/misc/sub_group_shuffle.cl

Lines changed: 71 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -256,12 +256,26 @@ __AMDGCN_CLC_SUBGROUP_XOR_TO_VEC(double16, d, 16)
256256
// Shuffle Up
257257
// int __spirv_SubgroupShuffleUpINTEL<int>(int, int, unsigned int)
258258
_CLC_DEF int
259-
_Z30__spirv_SubgroupShuffleUpINTELIiET_S0_S0_j(int var, int lane_delta,
260-
unsigned int width) {
259+
_Z30__spirv_SubgroupShuffleUpINTELIiET_S0_S0_j(int previous, int current,
260+
unsigned int delta) {
261261
int self = SELF;
262-
int index = self - lane_delta;
263-
index = (index < (self & ~(width - 1))) ? index : self;
264-
return __builtin_amdgcn_ds_bpermute(index << 2, var);
262+
int size = SUBGROUP_SIZE;
263+
264+
int index = self - delta;
265+
266+
int val;
267+
if (index >= 0 && index < size) {
268+
val = current;
269+
} else if (index < 0 && index > -size) {
270+
val = previous;
271+
index = index + size;
272+
} else {
273+
// index out of bounds so return arbitrary data
274+
val = current;
275+
index = self;
276+
}
277+
278+
return __builtin_amdgcn_ds_bpermute(index << 2, val);
265279
}
266280

267281
// Sub 32-bit types.
@@ -272,9 +286,9 @@ _Z30__spirv_SubgroupShuffleUpINTELIiET_S0_S0_j(int var, int lane_delta,
272286
#define __AMDGCN_CLC_SUBGROUP_UP_SUB_I32(TYPE, MANGLED_TYPE_NAME) \
273287
_CLC_DEF TYPE \
274288
_Z30__spirv_SubgroupShuffleUpINTELI##MANGLED_TYPE_NAME##ET_S0_S0_j( \
275-
TYPE var, TYPE lane_delta, unsigned int width) { \
276-
return _Z30__spirv_SubgroupShuffleUpINTELIiET_S0_S0_j(var, lane_delta, \
277-
width); \
289+
TYPE previous, TYPE current, unsigned int delta) { \
290+
return _Z30__spirv_SubgroupShuffleUpINTELIiET_S0_S0_j(previous, current, \
291+
delta); \
278292
}
279293
__AMDGCN_CLC_SUBGROUP_UP_SUB_I32(char, a);
280294
__AMDGCN_CLC_SUBGROUP_UP_SUB_I32(unsigned char, h);
@@ -288,9 +302,9 @@ __AMDGCN_CLC_SUBGROUP_UP_SUB_I32(unsigned short, t);
288302
#define __AMDGCN_CLC_SUBGROUP_UP_I32(TYPE, CAST_TYPE, MANGLED_TYPE_NAME) \
289303
_CLC_DEF TYPE \
290304
_Z30__spirv_SubgroupShuffleUpINTELI##MANGLED_TYPE_NAME##ET_S0_S0_j( \
291-
TYPE var, TYPE lane_delta, unsigned int width) { \
305+
TYPE previous, TYPE current, unsigned int delta) { \
292306
return __builtin_astype(_Z30__spirv_SubgroupShuffleUpINTELIiET_S0_S0_j( \
293-
as_int(var), as_int(lane_delta), width), \
307+
as_int(previous), as_int(current), delta), \
294308
CAST_TYPE); \
295309
}
296310
__AMDGCN_CLC_SUBGROUP_UP_I32(unsigned int, uint, j);
@@ -304,13 +318,15 @@ __AMDGCN_CLC_SUBGROUP_UP_I32(float, float, f);
304318
#define __AMDGCN_CLC_SUBGROUP_UP_I64(TYPE, CAST_TYPE, MANGLED_TYPE_NAME) \
305319
_CLC_DEF TYPE \
306320
_Z30__spirv_SubgroupShuffleUpINTELI##MANGLED_TYPE_NAME##ET_S0_S0_j( \
307-
TYPE var, TYPE lane_delta, unsigned int width) { \
308-
int2 tmp = as_int2(var); \
309-
tmp.lo = _Z30__spirv_SubgroupShuffleUpINTELIiET_S0_S0_j( \
310-
tmp.lo, (int)lane_delta, width); \
311-
tmp.hi = _Z30__spirv_SubgroupShuffleUpINTELIiET_S0_S0_j( \
312-
tmp.hi, (int)lane_delta, width); \
313-
return __builtin_astype(tmp, CAST_TYPE); \
321+
TYPE previous, TYPE current, unsigned int delta) { \
322+
int2 tmp_previous = as_int2(previous); \
323+
int2 tmp_current = as_int2(current); \
324+
int2 ret; \
325+
ret.lo = _Z30__spirv_SubgroupShuffleUpINTELIiET_S0_S0_j( \
326+
tmp_previous.lo, tmp_current.lo, delta); \
327+
ret.hi = _Z30__spirv_SubgroupShuffleUpINTELIiET_S0_S0_j( \
328+
tmp_previous.hi, tmp_current.hi, delta); \
329+
return __builtin_astype(ret, CAST_TYPE); \
314330
}
315331
__AMDGCN_CLC_SUBGROUP_UP_I64(long, long, l);
316332
__AMDGCN_CLC_SUBGROUP_UP_I64(unsigned long, ulong, m);
@@ -321,12 +337,12 @@ __AMDGCN_CLC_SUBGROUP_UP_I64(double, double, d);
321337
#define __AMDGCN_CLC_SUBGROUP_UP_TO_VEC(TYPE, MANGLED_SCALAR_TY, NUM_ELEMS) \
322338
_CLC_DEF TYPE \
323339
_Z30__spirv_SubgroupShuffleUpINTELIDv##NUM_ELEMS##_##MANGLED_SCALAR_TY##ET_S1_S1_j( \
324-
TYPE var, TYPE lane_delta, unsigned int width) { \
340+
TYPE previous, TYPE current, unsigned int delta) { \
325341
TYPE res; \
326342
for (int i = 0; i < NUM_ELEMS; ++i) { \
327343
res[i] = \
328344
_Z30__spirv_SubgroupShuffleUpINTELI##MANGLED_SCALAR_TY##ET_S0_S0_j( \
329-
var[i], (unsigned int)lane_delta[0], width); \
345+
previous[i], current[i], delta); \
330346
} \
331347
return res; \
332348
}
@@ -381,12 +397,26 @@ __AMDGCN_CLC_SUBGROUP_UP_TO_VEC(double16, d, 16)
381397
// Shuffle Down
382398
// int __spirv_SubgroupShuffleDownINTEL<int>(int, int, unsigned int)
383399
_CLC_DEF int
384-
_Z32__spirv_SubgroupShuffleDownINTELIiET_S0_S0_j(int var, int lane_delta,
385-
unsigned int width) {
386-
unsigned int self = SELF;
387-
unsigned int index = self + lane_delta;
388-
index = as_uint(((self & (width - 1)) + lane_delta)) >= width ? self : index;
389-
return __builtin_amdgcn_ds_bpermute(index << 2, var);
400+
_Z32__spirv_SubgroupShuffleDownINTELIiET_S0_S0_j(int current, int next,
401+
unsigned int delta) {
402+
int self = SELF;
403+
int size = SUBGROUP_SIZE;
404+
405+
int index = self + delta;
406+
407+
int val;
408+
if (index < size) {
409+
val = current;
410+
} else if (index < 2 * size) {
411+
val = next;
412+
index = index - size;
413+
} else {
414+
// index out of bounds so return arbitrary data
415+
val = current;
416+
index = self;
417+
}
418+
419+
return __builtin_amdgcn_ds_bpermute(index << 2, val);
390420
}
391421

392422
// Sub 32-bit types.
@@ -397,9 +427,9 @@ _Z32__spirv_SubgroupShuffleDownINTELIiET_S0_S0_j(int var, int lane_delta,
397427
#define __AMDGCN_CLC_SUBGROUP_DOWN_TO_I32(TYPE, MANGLED_TYPE_NAME) \
398428
_CLC_DEF TYPE \
399429
_Z32__spirv_SubgroupShuffleDownINTELI##MANGLED_TYPE_NAME##ET_S0_S0_j( \
400-
TYPE var, TYPE lane_delta, unsigned int width) { \
401-
return _Z32__spirv_SubgroupShuffleDownINTELIiET_S0_S0_j(var, lane_delta, \
402-
width); \
430+
TYPE current, TYPE next, unsigned int delta) { \
431+
return _Z32__spirv_SubgroupShuffleDownINTELIiET_S0_S0_j(current, next, \
432+
delta); \
403433
}
404434
__AMDGCN_CLC_SUBGROUP_DOWN_TO_I32(char, a);
405435
__AMDGCN_CLC_SUBGROUP_DOWN_TO_I32(unsigned char, h);
@@ -413,9 +443,9 @@ __AMDGCN_CLC_SUBGROUP_DOWN_TO_I32(unsigned short, t);
413443
#define __AMDGCN_CLC_SUBGROUP_DOWN_I32(TYPE, CAST_TYPE, MANGLED_TYPE_NAME) \
414444
_CLC_DEF TYPE \
415445
_Z32__spirv_SubgroupShuffleDownINTELI##MANGLED_TYPE_NAME##ET_S0_S0_j( \
416-
TYPE var, TYPE lane_delta, unsigned int width) { \
446+
TYPE current, TYPE next, unsigned int delta) { \
417447
return __builtin_astype(_Z32__spirv_SubgroupShuffleDownINTELIiET_S0_S0_j( \
418-
as_int(var), as_int(lane_delta), width), \
448+
as_int(current), as_int(next), delta), \
419449
CAST_TYPE); \
420450
}
421451
__AMDGCN_CLC_SUBGROUP_DOWN_I32(unsigned int, uint, j);
@@ -427,13 +457,15 @@ __AMDGCN_CLC_SUBGROUP_DOWN_I32(float, float, f);
427457
#define __AMDGCN_CLC_SUBGROUP_DOWN_I64(TYPE, CAST_TYPE, MANGLED_TYPE_NAME) \
428458
_CLC_DEF TYPE \
429459
_Z32__spirv_SubgroupShuffleDownINTELI##MANGLED_TYPE_NAME##ET_S0_S0_j( \
430-
TYPE var, TYPE lane_delta, unsigned int width) { \
431-
int2 tmp = as_int2(var); \
432-
tmp.lo = _Z32__spirv_SubgroupShuffleDownINTELIiET_S0_S0_j( \
433-
tmp.lo, (int)lane_delta, width); \
434-
tmp.hi = _Z32__spirv_SubgroupShuffleDownINTELIiET_S0_S0_j( \
435-
tmp.hi, (int)lane_delta, width); \
436-
return __builtin_astype(tmp, CAST_TYPE); \
460+
TYPE current, TYPE next, unsigned int delta) { \
461+
int2 tmp_current = as_int2(current); \
462+
int2 tmp_next = as_int2(next); \
463+
int2 ret; \
464+
ret.lo = _Z32__spirv_SubgroupShuffleDownINTELIiET_S0_S0_j( \
465+
tmp_current.lo, tmp_next.lo, delta); \
466+
ret.hi = _Z32__spirv_SubgroupShuffleDownINTELIiET_S0_S0_j( \
467+
tmp_current.hi, tmp_next.hi, delta); \
468+
return __builtin_astype(ret, CAST_TYPE); \
437469
}
438470
__AMDGCN_CLC_SUBGROUP_DOWN_I64(long, long, l);
439471
__AMDGCN_CLC_SUBGROUP_DOWN_I64(unsigned long, ulong, m);
@@ -444,12 +476,12 @@ __AMDGCN_CLC_SUBGROUP_DOWN_I64(double, double, d);
444476
#define __AMDGCN_CLC_SUBGROUP_DOWN_TO_VEC(TYPE, MANGLED_SCALAR_TY, NUM_ELEMS) \
445477
_CLC_DEF TYPE \
446478
_Z32__spirv_SubgroupShuffleDownINTELIDv##NUM_ELEMS##_##MANGLED_SCALAR_TY##ET_S1_S1_j( \
447-
TYPE var, TYPE lane_delta, unsigned int width) { \
479+
TYPE current, TYPE next, unsigned int delta) { \
448480
TYPE res; \
449481
for (int i = 0; i < NUM_ELEMS; ++i) { \
450482
res[i] = \
451483
_Z32__spirv_SubgroupShuffleDownINTELI##MANGLED_SCALAR_TY##ET_S0_S0_j( \
452-
var[i], (unsigned int)lane_delta[0], width); \
484+
current[i], next[i], delta); \
453485
} \
454486
return res; \
455487
}

libclc/amdgcn-amdhsa/libspirv/workitem/get_max_sub_group_size.cl

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,15 @@
88

99
#include <spirv/spirv.h>
1010

11-
// FIXME: Remove the following workaround once the clang change is released.
12-
// This is for backward compatibility with older clang which does not define
13-
// __AMDGCN_WAVEFRONT_SIZE. It does not consider -mwavefrontsize64.
14-
// See:
15-
// https://github.com/intel/llvm/blob/sycl/clang/lib/Basic/Targets/AMDGPU.h#L414
16-
// and:
17-
// https://github.com/intel/llvm/blob/sycl/clang/lib/Basic/Targets/AMDGPU.cpp#L421
18-
#ifndef __AMDGCN_WAVEFRONT_SIZE
19-
#if __gfx1010__ || __gfx1011__ || __gfx1012__ || __gfx1030__ || __gfx1031__
20-
#define __AMDGCN_WAVEFRONT_SIZE 32
21-
#else
22-
#define __AMDGCN_WAVEFRONT_SIZE 64
23-
#endif
24-
#endif
11+
// The clang driver will define this variable depending on the architecture and
12+
// compile flags by linking in ROCm bitcode defining it to true or false. If
13+
// it's 1 the wavefront size used is 64, if it's 0 the wavefront size used is
14+
// 32.
15+
extern constant unsigned char __oclc_wavefrontsize64;
2516

2617
_CLC_DEF _CLC_OVERLOAD uint __spirv_SubgroupMaxSize() {
27-
return __AMDGCN_WAVEFRONT_SIZE;
18+
if (__oclc_wavefrontsize64 == 1) {
19+
return 64;
20+
}
21+
return 32;
2822
}

sycl/include/CL/sycl/detail/spirv.hpp

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -515,24 +515,22 @@ EnableIfNativeShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
515515
}
516516

517517
template <typename T>
518-
EnableIfNativeShuffle<T> SubgroupShuffleDown(T x, id<1> local_id) {
518+
EnableIfNativeShuffle<T> SubgroupShuffleDown(T x, uint32_t delta) {
519519
#ifndef __NVPTX__
520520
using OCLT = detail::ConvertToOpenCLType_t<T>;
521-
return __spirv_SubgroupShuffleDownINTEL(
522-
OCLT(x), OCLT(x), static_cast<uint32_t>(local_id.get(0)));
521+
return __spirv_SubgroupShuffleDownINTEL(OCLT(x), OCLT(x), delta);
523522
#else
524-
return __nvvm_shfl_sync_down_i32(membermask(), x, local_id.get(0), 0x1f);
523+
return __nvvm_shfl_sync_down_i32(membermask(), x, delta, 0x1f);
525524
#endif
526525
}
527526

528527
template <typename T>
529-
EnableIfNativeShuffle<T> SubgroupShuffleUp(T x, id<1> local_id) {
528+
EnableIfNativeShuffle<T> SubgroupShuffleUp(T x, uint32_t delta) {
530529
#ifndef __NVPTX__
531530
using OCLT = detail::ConvertToOpenCLType_t<T>;
532-
return __spirv_SubgroupShuffleUpINTEL(OCLT(x), OCLT(x),
533-
static_cast<uint32_t>(local_id.get(0)));
531+
return __spirv_SubgroupShuffleUpINTEL(OCLT(x), OCLT(x), delta);
534532
#else
535-
return __nvvm_shfl_sync_up_i32(membermask(), x, local_id.get(0), 0);
533+
return __nvvm_shfl_sync_up_i32(membermask(), x, delta, 0);
536534
#endif
537535
}
538536

@@ -556,19 +554,19 @@ EnableIfVectorShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
556554
}
557555

558556
template <typename T>
559-
EnableIfVectorShuffle<T> SubgroupShuffleDown(T x, id<1> local_id) {
557+
EnableIfVectorShuffle<T> SubgroupShuffleDown(T x, uint32_t delta) {
560558
T result;
561559
for (int s = 0; s < x.get_size(); ++s) {
562-
result[s] = SubgroupShuffleDown(x[s], local_id);
560+
result[s] = SubgroupShuffleDown(x[s], delta);
563561
}
564562
return result;
565563
}
566564

567565
template <typename T>
568-
EnableIfVectorShuffle<T> SubgroupShuffleUp(T x, id<1> local_id) {
566+
EnableIfVectorShuffle<T> SubgroupShuffleUp(T x, uint32_t delta) {
569567
T result;
570568
for (int s = 0; s < x.get_size(); ++s) {
571-
result[s] = SubgroupShuffleUp(x[s], local_id);
569+
result[s] = SubgroupShuffleUp(x[s], delta);
572570
}
573571
return result;
574572
}
@@ -626,29 +624,26 @@ EnableIfBitcastShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
626624
}
627625

628626
template <typename T>
629-
EnableIfBitcastShuffle<T> SubgroupShuffleDown(T x, id<1> local_id) {
627+
EnableIfBitcastShuffle<T> SubgroupShuffleDown(T x, uint32_t delta) {
630628
using ShuffleT = ConvertToNativeShuffleType_t<T>;
631629
auto ShuffleX = bit_cast<ShuffleT>(x);
632630
#ifndef __NVPTX__
633-
ShuffleT Result = __spirv_SubgroupShuffleDownINTEL(
634-
ShuffleX, ShuffleX, static_cast<uint32_t>(local_id.get(0)));
631+
ShuffleT Result = __spirv_SubgroupShuffleDownINTEL(ShuffleX, ShuffleX, delta);
635632
#else
636633
ShuffleT Result =
637-
__nvvm_shfl_sync_down_i32(membermask(), ShuffleX, local_id.get(0), 0x1f);
634+
__nvvm_shfl_sync_down_i32(membermask(), ShuffleX, delta, 0x1f);
638635
#endif
639636
return bit_cast<T>(Result);
640637
}
641638

642639
template <typename T>
643-
EnableIfBitcastShuffle<T> SubgroupShuffleUp(T x, id<1> local_id) {
640+
EnableIfBitcastShuffle<T> SubgroupShuffleUp(T x, uint32_t delta) {
644641
using ShuffleT = ConvertToNativeShuffleType_t<T>;
645642
auto ShuffleX = bit_cast<ShuffleT>(x);
646643
#ifndef __NVPTX__
647-
ShuffleT Result = __spirv_SubgroupShuffleUpINTEL(
648-
ShuffleX, ShuffleX, static_cast<uint32_t>(local_id.get(0)));
644+
ShuffleT Result = __spirv_SubgroupShuffleUpINTEL(ShuffleX, ShuffleX, delta);
649645
#else
650-
ShuffleT Result =
651-
__nvvm_shfl_sync_up_i32(membermask(), ShuffleX, local_id.get(0), 0);
646+
ShuffleT Result = __nvvm_shfl_sync_up_i32(membermask(), ShuffleX, delta, 0);
652647
#endif
653648
return bit_cast<T>(Result);
654649
}
@@ -706,29 +701,29 @@ EnableIfGenericShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
706701
}
707702

708703
template <typename T>
709-
EnableIfGenericShuffle<T> SubgroupShuffleDown(T x, id<1> local_id) {
704+
EnableIfGenericShuffle<T> SubgroupShuffleDown(T x, uint32_t delta) {
710705
T Result;
711706
char *XBytes = reinterpret_cast<char *>(&x);
712707
char *ResultBytes = reinterpret_cast<char *>(&Result);
713708
auto ShuffleBytes = [=](size_t Offset, size_t Size) {
714709
ShuffleChunkT ShuffleX, ShuffleResult;
715710
std::memcpy(&ShuffleX, XBytes + Offset, Size);
716-
ShuffleResult = SubgroupShuffleDown(ShuffleX, local_id);
711+
ShuffleResult = SubgroupShuffleDown(ShuffleX, delta);
717712
std::memcpy(ResultBytes + Offset, &ShuffleResult, Size);
718713
};
719714
GenericCall<T>(ShuffleBytes);
720715
return Result;
721716
}
722717

723718
template <typename T>
724-
EnableIfGenericShuffle<T> SubgroupShuffleUp(T x, id<1> local_id) {
719+
EnableIfGenericShuffle<T> SubgroupShuffleUp(T x, uint32_t delta) {
725720
T Result;
726721
char *XBytes = reinterpret_cast<char *>(&x);
727722
char *ResultBytes = reinterpret_cast<char *>(&Result);
728723
auto ShuffleBytes = [=](size_t Offset, size_t Size) {
729724
ShuffleChunkT ShuffleX, ShuffleResult;
730725
std::memcpy(&ShuffleX, XBytes + Offset, Size);
731-
ShuffleResult = SubgroupShuffleUp(ShuffleX, local_id);
726+
ShuffleResult = SubgroupShuffleUp(ShuffleX, delta);
732727
std::memcpy(ResultBytes + Offset, &ShuffleResult, Size);
733728
};
734729
GenericCall<T>(ShuffleBytes);

0 commit comments

Comments
 (0)