Commit a0c3b32
[CUDA] Fix cuda group/non-uniform group shuffles. (#13230)
This follows on from the discussion in #12705 (comment) to implement/fix non-uniform group shuffles on CUDA:

- Fix the non-uniform group algorithm implementations for permute/left/right shuffles.
- Generalize group shuffles to support double/half/long/short correctly for both uniform and non-uniform groups.
- Make the fixed_size_group test fail if the group member "local id" mapping is incorrect or removed.
- Update ballot_group_algorithms.cpp to test previously failing cases on the CUDA backend.

The shuffle implementations in ::detail match those in SYCLomatic for the masked shuffle builtins (which don't exist in oneAPI outside SYCLomatic).

Signed-off-by: JackAKirk <[email protected]>
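For context, a minimal sketch (mine, not part of the commit) of the pattern this fixes: a group algorithm called on a user-constructed ballot_group with a 64-bit element type, which previously took the broken non-uniform shuffle path on CUDA. Kernel and variable names are illustrative, and the snippet assumes a device supporting the sycl_ext_oneapi_non_uniform_groups extension.

#include <sycl/sycl.hpp>
namespace syclex = sycl::ext::oneapi::experimental;

int main() {
  sycl::queue Q;
  double *Out = sycl::malloc_shared<double>(32, Q);
  Q.parallel_for(sycl::nd_range<1>{32, 32}, [=](sycl::nd_item<1> It) {
     auto SG = It.get_sub_group();
     // Split the sub-group into two non-uniform groups by a predicate.
     auto BG = syclex::get_ballot_group(SG, SG.get_local_id()[0] % 2 == 0);
     double Val = static_cast<double>(BG.get_local_id()[0]);
     // A double shuffle over a non-uniform group: the combination of
     // masked shuffle + 64-bit type that this commit makes correct on CUDA.
     Out[It.get_global_id()] = sycl::shift_group_left(BG, Val, 1);
   }).wait();
  sycl::free(Out, Q);
}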
1 parent 0939f39 commit a0c3b32

5 files changed: +136 -121 lines
sycl/include/sycl/detail/spirv.hpp

Lines changed: 41 additions & 14 deletions
@@ -12,6 +12,10 @@
 
 #include <sycl/ext/oneapi/experimental/non_uniform_groups.hpp> // for IdToMaskPosition
 
+#if defined(__NVPTX__)
+#include <sycl/ext/oneapi/experimental/cuda/masked_shuffles.hpp>
+#endif
+
 #include <sycl/detail/memcpy.hpp> // sycl::detail::memcpy
 
 namespace sycl {
@@ -870,10 +874,10 @@ EnableIfNativeShuffle<T> Shuffle(GroupT g, T x, id<1> local_id) {
 #else
   if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<
                     GroupT>) {
-    return __nvvm_shfl_sync_idx_i32(detail::ExtractMask(detail::GetMask(g))[0],
-                                    x, LocalId, 0x1f);
+    return cuda_shfl_sync_idx_i32(detail::ExtractMask(detail::GetMask(g))[0], x,
+                                  LocalId, 31);
   } else {
-    return __nvvm_shfl_sync_idx_i32(membermask(), x, LocalId, 0x1f);
+    return cuda_shfl_sync_idx_i32(membermask(), x, LocalId, 31);
   }
 #endif
 }
@@ -908,12 +912,20 @@ EnableIfNativeShuffle<T> ShuffleXor(GroupT g, T x, id<1> mask) {
 #else
   if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<
                     GroupT>) {
-    return __nvvm_shfl_sync_bfly_i32(detail::ExtractMask(detail::GetMask(g))[0],
-                                     x, static_cast<uint32_t>(mask.get(0)),
-                                     0x1f);
-  } else {
-    return __nvvm_shfl_sync_bfly_i32(membermask(), x,
+    auto MemberMask = detail::ExtractMask(detail::GetMask(g))[0];
+    if constexpr (is_fixed_size_group_v<GroupT>) {
+      return cuda_shfl_sync_bfly_i32(MemberMask, x,
                                      static_cast<uint32_t>(mask.get(0)), 0x1f);
+
+    } else {
+      int unfoldedSrcSetBit =
+          (g.get_local_id()[0] ^ static_cast<uint32_t>(mask.get(0))) + 1;
+      return cuda_shfl_sync_idx_i32(
+          MemberMask, x, __nvvm_fns(MemberMask, 0, unfoldedSrcSetBit), 31);
+    }
+  } else {
+    return cuda_shfl_sync_bfly_i32(membermask(), x,
+                                   static_cast<uint32_t>(mask.get(0)), 0x1f);
   }
 #endif
 }
@@ -948,10 +960,17 @@ EnableIfNativeShuffle<T> ShuffleDown(GroupT g, T x, uint32_t delta) {
 #else
   if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<
                     GroupT>) {
-    return __nvvm_shfl_sync_down_i32(detail::ExtractMask(detail::GetMask(g))[0],
-                                     x, delta, 0x1f);
+    auto MemberMask = detail::ExtractMask(detail::GetMask(g))[0];
+    if constexpr (is_fixed_size_group_v<GroupT>) {
+      return cuda_shfl_sync_down_i32(MemberMask, x, delta, 31);
+    } else {
+      unsigned localSetBit = g.get_local_id()[0] + 1;
+      int unfoldedSrcSetBit = localSetBit + delta;
+      return cuda_shfl_sync_idx_i32(
+          MemberMask, x, __nvvm_fns(MemberMask, 0, unfoldedSrcSetBit), 31);
+    }
   } else {
-    return __nvvm_shfl_sync_down_i32(membermask(), x, delta, 0x1f);
+    return cuda_shfl_sync_down_i32(membermask(), x, delta, 31);
   }
 #endif
 }
@@ -985,10 +1004,18 @@ EnableIfNativeShuffle<T> ShuffleUp(GroupT g, T x, uint32_t delta) {
 #else
   if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<
                     GroupT>) {
-    return __nvvm_shfl_sync_up_i32(detail::ExtractMask(detail::GetMask(g))[0],
-                                   x, delta, 0);
+    auto MemberMask = detail::ExtractMask(detail::GetMask(g))[0];
+    if constexpr (is_fixed_size_group_v<GroupT>) {
+      return cuda_shfl_sync_up_i32(MemberMask, x, delta, 0);
+    } else {
+      unsigned localSetBit = g.get_local_id()[0] + 1;
+      int unfoldedSrcSetBit = localSetBit - delta;
+
+      return cuda_shfl_sync_idx_i32(
+          MemberMask, x, __nvvm_fns(MemberMask, 0, unfoldedSrcSetBit), 31);
+    }
   } else {
-    return __nvvm_shfl_sync_up_i32(membermask(), x, delta, 0);
+    return cuda_shfl_sync_up_i32(membermask(), x, delta, 0);
   }
 #endif
 }
sycl/include/sycl/ext/oneapi/experimental/cuda/masked_shuffles.hpp

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+//==--------- masked_shuffles.hpp - cuda masked shuffle algorithms ---------==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+
+namespace sycl {
+inline namespace _V1 {
+namespace detail {
+
+#define CUDA_SHFL_SYNC(SHUFFLE_INSTR)                                          \
+  template <typename T>                                                        \
+  inline __SYCL_ALWAYS_INLINE T cuda_shfl_sync_##SHUFFLE_INSTR(                \
+      unsigned int mask, T val, unsigned int shfl_param, int c) {              \
+    T res;                                                                     \
+    if constexpr (std::is_same_v<T, double>) {                                 \
+      int x_a, x_b;                                                            \
+      asm("mov.b64 {%0,%1},%2;" : "=r"(x_a), "=r"(x_b) : "d"(val));            \
+      auto tmp_a = __nvvm_shfl_sync_##SHUFFLE_INSTR(mask, x_a, shfl_param, c); \
+      auto tmp_b = __nvvm_shfl_sync_##SHUFFLE_INSTR(mask, x_b, shfl_param, c); \
+      asm("mov.b64 %0,{%1,%2};" : "=d"(res) : "r"(tmp_a), "r"(tmp_b));         \
+    } else if constexpr (std::is_same_v<T, long> ||                            \
+                         std::is_same_v<T, unsigned long>) {                   \
+      int x_a, x_b;                                                            \
+      asm("mov.b64 {%0,%1},%2;" : "=r"(x_a), "=r"(x_b) : "l"(val));            \
+      auto tmp_a = __nvvm_shfl_sync_##SHUFFLE_INSTR(mask, x_a, shfl_param, c); \
+      auto tmp_b = __nvvm_shfl_sync_##SHUFFLE_INSTR(mask, x_b, shfl_param, c); \
+      asm("mov.b64 %0,{%1,%2};" : "=l"(res) : "r"(tmp_a), "r"(tmp_b));         \
+    } else if constexpr (std::is_same_v<T, half>) {                            \
+      short tmp_b16;                                                           \
+      asm("mov.b16 %0,%1;" : "=h"(tmp_b16) : "h"(val));                        \
+      auto tmp_b32 = __nvvm_shfl_sync_##SHUFFLE_INSTR(                         \
+          mask, static_cast<int>(tmp_b16), shfl_param, c);                     \
+      asm("mov.b16 %0,%1;" : "=h"(res) : "h"(static_cast<short>(tmp_b32)));    \
+    } else if constexpr (std::is_same_v<T, float>) {                           \
+      auto tmp_b32 = __nvvm_shfl_sync_##SHUFFLE_INSTR(                         \
+          mask, __nvvm_bitcast_f2i(val), shfl_param, c);                       \
+      res = __nvvm_bitcast_i2f(tmp_b32);                                       \
+    } else {                                                                   \
+      res = __nvvm_shfl_sync_##SHUFFLE_INSTR(mask, val, shfl_param, c);        \
+    }                                                                          \
+    return res;                                                                \
+  }
+
+CUDA_SHFL_SYNC(bfly_i32)
+CUDA_SHFL_SYNC(up_i32)
+CUDA_SHFL_SYNC(down_i32)
+CUDA_SHFL_SYNC(idx_i32)
+
+#undef CUDA_SHFL_SYNC
+
+} // namespace detail
+} // namespace _V1
+} // namespace sycl
+
+#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
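Each CUDA_SHFL_SYNC expansion stamps out one templated cuda_shfl_sync_* helper that forwards 32-bit types straight to the builtin and splits/reassembles 64-bit and 16-bit types around it. An illustrative device-side use (my sketch, assuming the header above is included; it only compiles in a SYCL device compilation for NVPTX, where these builtins exist):

#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
// Shuffle a double down by one lane across a full warp. The helper moves the
// value as two 32-bit halves and recombines them, so call sites no longer
// hand-roll the mov.b64 split.
inline double shift_double_down_one(double Val) {
  return sycl::detail::cuda_shfl_sync_down_i32(
      0xFFFFFFFF /*member mask: full warp*/, Val, 1 /*delta*/, 31 /*c*/);
}
#endif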

sycl/include/sycl/ext/oneapi/experimental/cuda/non_uniform_algorithms.hpp

Lines changed: 25 additions & 97 deletions
@@ -9,6 +9,7 @@
 #pragma once
 
 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+#include "masked_shuffles.hpp"
 
 namespace sycl {
 inline namespace _V1 {
@@ -100,87 +101,12 @@ inline __SYCL_ALWAYS_INLINE std::enable_if_t<is_fixed_size_group_v<Group>, T>
 masked_reduction_cuda_shfls(Group g, T x, BinaryOperation binary_op,
                             const uint32_t MemberMask) {
   for (int i = g.get_local_range()[0] / 2; i > 0; i /= 2) {
-    T tmp;
-    if constexpr (std::is_same_v<T, double>) {
-      int x_a, x_b;
-      asm volatile("mov.b64 {%0,%1},%2;" : "=r"(x_a), "=r"(x_b) : "d"(x));
-      auto tmp_a = __nvvm_shfl_sync_bfly_i32(MemberMask, x_a, -1, i);
-      auto tmp_b = __nvvm_shfl_sync_bfly_i32(MemberMask, x_b, -1, i);
-      asm volatile("mov.b64 %0,{%1,%2};" : "=d"(tmp) : "r"(tmp_a), "r"(tmp_b));
-    } else if constexpr (std::is_same_v<T, long> ||
-                         std::is_same_v<T, unsigned long>) {
-      int x_a, x_b;
-      asm volatile("mov.b64 {%0,%1},%2;" : "=r"(x_a), "=r"(x_b) : "l"(x));
-      auto tmp_a = __nvvm_shfl_sync_bfly_i32(MemberMask, x_a, -1, i);
-      auto tmp_b = __nvvm_shfl_sync_bfly_i32(MemberMask, x_b, -1, i);
-      asm volatile("mov.b64 %0,{%1,%2};" : "=l"(tmp) : "r"(tmp_a), "r"(tmp_b));
-    } else if constexpr (std::is_same_v<T, half>) {
-      short tmp_b16;
-      asm volatile("mov.b16 %0,%1;" : "=h"(tmp_b16) : "h"(x));
-      auto tmp_b32 = __nvvm_shfl_sync_bfly_i32(
-          MemberMask, static_cast<int>(tmp_b16), -1, i);
-      asm volatile("mov.b16 %0,%1;"
-                   : "=h"(tmp)
-                   : "h"(static_cast<short>(tmp_b32)));
-    } else if constexpr (std::is_same_v<T, float>) {
-      auto tmp_b32 =
-          __nvvm_shfl_sync_bfly_i32(MemberMask, __nvvm_bitcast_f2i(x), -1, i);
-      tmp = __nvvm_bitcast_i2f(tmp_b32);
-    } else {
-      tmp = __nvvm_shfl_sync_bfly_i32(MemberMask, x, -1, i);
-    }
+    T tmp = cuda_shfl_sync_bfly_i32(MemberMask, x, i, 0x1f);
     x = binary_op(x, tmp);
   }
   return x;
 }
 
-template <typename Group, typename T>
-inline __SYCL_ALWAYS_INLINE std::enable_if_t<
-    ext::oneapi::experimental::is_user_constructed_group_v<Group>, T>
-non_uniform_shfl_T(const uint32_t MemberMask, T x, int shfl_param) {
-  if constexpr (is_fixed_size_group_v<Group>) {
-    return __nvvm_shfl_sync_up_i32(MemberMask, x, shfl_param, 0);
-  } else {
-    return __nvvm_shfl_sync_idx_i32(MemberMask, x, shfl_param, 31);
-  }
-}
-
-template <typename Group, typename T>
-inline __SYCL_ALWAYS_INLINE std::enable_if_t<
-    ext::oneapi::experimental::is_user_constructed_group_v<Group>, T>
-non_uniform_shfl(Group g, const uint32_t MemberMask, T x, int shfl_param) {
-  T res;
-  if constexpr (std::is_same_v<T, double>) {
-    int x_a, x_b;
-    asm volatile("mov.b64 {%0,%1},%2;" : "=r"(x_a), "=r"(x_b) : "d"(x));
-    auto tmp_a = non_uniform_shfl_T<Group>(MemberMask, x_a, shfl_param);
-    auto tmp_b = non_uniform_shfl_T<Group>(MemberMask, x_b, shfl_param);
-    asm volatile("mov.b64 %0,{%1,%2};" : "=d"(res) : "r"(tmp_a), "r"(tmp_b));
-  } else if constexpr (std::is_same_v<T, long> ||
-                       std::is_same_v<T, unsigned long>) {
-    int x_a, x_b;
-    asm volatile("mov.b64 {%0,%1},%2;" : "=r"(x_a), "=r"(x_b) : "l"(x));
-    auto tmp_a = non_uniform_shfl_T<Group>(MemberMask, x_a, shfl_param);
-    auto tmp_b = non_uniform_shfl_T<Group>(MemberMask, x_b, shfl_param);
-    asm volatile("mov.b64 %0,{%1,%2};" : "=l"(res) : "r"(tmp_a), "r"(tmp_b));
-  } else if constexpr (std::is_same_v<T, half>) {
-    short tmp_b16;
-    asm volatile("mov.b16 %0,%1;" : "=h"(tmp_b16) : "h"(x));
-    auto tmp_b32 = non_uniform_shfl_T<Group>(
-        MemberMask, static_cast<int>(tmp_b16), shfl_param);
-    asm volatile("mov.b16 %0,%1;"
-                 : "=h"(res)
-                 : "h"(static_cast<short>(tmp_b32)));
-  } else if constexpr (std::is_same_v<T, float>) {
-    auto tmp_b32 = non_uniform_shfl_T<Group>(MemberMask, __nvvm_bitcast_f2i(x),
-                                             shfl_param);
-    res = __nvvm_bitcast_i2f(tmp_b32);
-  } else {
-    res = non_uniform_shfl_T<Group>(MemberMask, x, shfl_param);
-  }
-  return res;
-}
-
 // Opportunistic/Ballot group reduction using shfls
 template <typename Group, typename T, class BinaryOperation>
 inline __SYCL_ALWAYS_INLINE std::enable_if_t<
@@ -207,8 +133,8 @@ masked_reduction_cuda_shfls(Group g, T x, BinaryOperation binary_op,
 
   // __nvvm_fns automatically wraps around to the correct bit position.
   // There is no performance impact on src_set_bit position wrt localSetBit
-  auto tmp = non_uniform_shfl(g, MemberMask, x,
-                              __nvvm_fns(MemberMask, 0, unfoldedSrcSetBit));
+  T tmp = cuda_shfl_sync_idx_i32(
+      MemberMask, x, __nvvm_fns(MemberMask, 0, unfoldedSrcSetBit), 31);
 
   if (!(localSetBit == 1 && remainder != 0)) {
     x = binary_op(x, tmp);
@@ -224,7 +150,8 @@ masked_reduction_cuda_shfls(Group g, T x, BinaryOperation binary_op,
                : "=r"(broadID)
                : "r"(MemberMask));
 
-  return non_uniform_shfl(g, MemberMask, x, broadID);
+  x = cuda_shfl_sync_idx_i32(MemberMask, x, broadID, 31);
+  return x;
 }
 
 // Non Redux types must fall back to shfl based implementations.
@@ -265,18 +192,19 @@ inline __SYCL_ALWAYS_INLINE
     return ~0;
   }
 
-#define GET_ID(OP_CHECK, OP)                                                   \
-  template <typename T, class BinaryOperation>                                 \
-  inline __SYCL_ALWAYS_INLINE                                                  \
-      std::enable_if_t<OP_CHECK<T, BinaryOperation>::value, T>                 \
-      get_identity() {                                                         \
-    return std::numeric_limits<T>::OP();                                       \
-  }
-
-GET_ID(IsMinimum, max)
-GET_ID(IsMaximum, min)
+template <typename T, class BinaryOperation>
+inline __SYCL_ALWAYS_INLINE
+    std::enable_if_t<IsMinimum<T, BinaryOperation>::value, T>
+    get_identity() {
+  return std::numeric_limits<T>::min();
+}
 
-#undef GET_ID
+template <typename T, class BinaryOperation>
+inline __SYCL_ALWAYS_INLINE
+    std::enable_if_t<IsMaximum<T, BinaryOperation>::value, T>
+    get_identity() {
+  return std::numeric_limits<T>::max();
+}
 
 //// Shuffle based masked reduction impls
 
@@ -288,13 +216,12 @@ masked_scan_cuda_shfls(Group g, T x, BinaryOperation binary_op,
                        const uint32_t MemberMask) {
   unsigned localIdVal = g.get_local_id()[0];
   for (int i = 1; i < g.get_local_range()[0]; i *= 2) {
-    auto tmp = non_uniform_shfl(g, MemberMask, x, i);
+    T tmp = cuda_shfl_sync_up_i32(MemberMask, x, i, 0);
     if (localIdVal >= i)
       x = binary_op(x, tmp);
   }
   if constexpr (Op == __spv::GroupOperation::ExclusiveScan) {
-
-    x = non_uniform_shfl(g, MemberMask, x, 1);
+    x = cuda_shfl_sync_up_i32(MemberMask, x, 1, 0);
     if (localIdVal == 0) {
       return get_identity<T, BinaryOperation>();
     }
@@ -316,14 +243,15 @@ masked_scan_cuda_shfls(Group g, T x, BinaryOperation binary_op,
   for (int i = 1; i < g.get_local_range()[0]; i *= 2) {
     int unfoldedSrcSetBit = localSetBit - i;
 
-    auto tmp = non_uniform_shfl(g, MemberMask, x,
-                                __nvvm_fns(MemberMask, 0, unfoldedSrcSetBit));
+    T tmp = cuda_shfl_sync_idx_i32(
+        MemberMask, x, __nvvm_fns(MemberMask, 0, unfoldedSrcSetBit), 31);
+
     if (localIdVal >= i)
       x = binary_op(x, tmp);
   }
   if constexpr (Op == __spv::GroupOperation::ExclusiveScan) {
-    x = non_uniform_shfl(g, MemberMask, x,
-                         __nvvm_fns(MemberMask, 0, localSetBit - 1));
+    x = cuda_shfl_sync_idx_i32(MemberMask, x,
+                               __nvvm_fns(MemberMask, 0, localSetBit - 1), 31);
     if (localIdVal == 0) {
       return get_identity<T, BinaryOperation>();
     }
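The scan loops above implement a Hillis-Steele scan: on step i each member combines its value with the value held i positions below it in the group, with i doubling each round, and the exclusive variant shifts the result up by one at the end. A host-side sketch of the recurrence (my illustration, not part of the commit):

#include <vector>
#include <functional>

template <typename T, typename Op>
std::vector<T> inclusive_scan_steps(std::vector<T> x, Op op) {
  const int n = static_cast<int>(x.size());
  for (int i = 1; i < n; i *= 2) {
    std::vector<T> prev = x; // snapshot: "shuffle up by i" reads pre-update values
    for (int lid = 0; lid < n; ++lid)
      if (lid >= i)
        x[lid] = op(prev[lid - i], x[lid]);
  }
  return x;
}

// Example: inclusive_scan_steps<int>({1, 1, 1, 1}, std::plus<int>{})
// yields {1, 2, 3, 4} after log2(4) == 2 combining rounds.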

sycl/test-e2e/NonUniformGroups/ballot_group_algorithms.cpp

Lines changed: 4 additions & 8 deletions
@@ -147,14 +147,10 @@ int main() {
       assert(ReduceAcc[WI] == true);
       assert(ExScanAcc[WI] == true);
       assert(IncScanAcc[WI] == true);
-      // TODO: Enable for CUDA devices when issue with shuffles have been
-      // addressed.
-      if (Q.get_backend() != sycl::backend::ext_oneapi_cuda) {
-        assert(ShiftLeftAcc[WI] == true);
-        assert(ShiftRightAcc[WI] == true);
-        assert(SelectAcc[WI] == true);
-        assert(PermuteXorAcc[WI] == true);
-      }
+      assert(ShiftLeftAcc[WI] == true);
+      assert(ShiftRightAcc[WI] == true);
+      assert(SelectAcc[WI] == true);
+      assert(PermuteXorAcc[WI] == true);
   }
   return 0;
 }

sycl/test-e2e/NonUniformGroups/fixed_size_group_algorithms.cpp

Lines changed: 4 additions & 2 deletions
@@ -113,8 +113,10 @@ template <size_t PartitionSize> void test() {
         ShiftRightAcc[WI] = (LID < 2 || ShiftRightResult == LID - 2);
 
         uint32_t SelectResult = sycl::select_from_group(
-            Partition, LID, (Partition.get_local_id() + 2) % PartitionSize);
-        SelectAcc[WI] = (SelectResult == (LID + 2) % PartitionSize);
+            Partition, OriginalLID,
+            (Partition.get_local_id() + 2) % PartitionSize);
+        SelectAcc[WI] =
+            SelectResult == OriginalLID - LID + ((LID + 2) % PartitionSize);
 
         uint32_t Mask = PartitionSize <= 2 ? 0 : 2;
         uint32_t PermuteXorResult =
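The new expected value follows from the local-id mapping this test now pins down: within a fixed_size_group, OriginalLID - LID is the partition's base offset in the sub-group, so the member with partition-local id (LID + 2) % PartitionSize has sub-group id OriginalLID - LID + ((LID + 2) % PartitionSize). For example, with PartitionSize == 4 the work-item with OriginalLID 6 has LID 2; it selects from partition-local id (2 + 2) % 4 == 0, whose OriginalLID is 6 - 2 + 0 == 4. Selecting OriginalLID instead of LID makes the assertion fail if the mapping is wrong or removed.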
