 #pragma once
 
 #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+#include "masked_shuffles.hpp"
 
 namespace sycl {
 inline namespace _V1 {
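The hunks below replace the open-coded, per-type shuffle logic with cuda_shfl_sync_* helpers, presumably provided by the newly included masked_shuffles.hpp. Judging from the call sites, the helpers would be declared roughly as follows (hypothetical signatures inferred from usage only; the real header may use different names, a generating macro, or SFINAE):

// Hypothetical declarations inferred from the call sites below.
template <typename T>
T cuda_shfl_sync_bfly_i32(unsigned int mask, T val, unsigned int shfl_param,
                          int c); // butterfly (XOR) shuffle
template <typename T>
T cuda_shfl_sync_up_i32(unsigned int mask, T val, unsigned int shfl_param,
                        int c); // shuffle from a lower lane
template <typename T>
T cuda_shfl_sync_idx_i32(unsigned int mask, T val, unsigned int shfl_param,
                         int c); // shuffle from an absolute lane index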
@@ -100,87 +101,12 @@ inline __SYCL_ALWAYS_INLINE std::enable_if_t<is_fixed_size_group_v<Group>, T>
 masked_reduction_cuda_shfls(Group g, T x, BinaryOperation binary_op,
                             const uint32_t MemberMask) {
   for (int i = g.get_local_range()[0] / 2; i > 0; i /= 2) {
-    T tmp;
-    if constexpr (std::is_same_v<T, double>) {
-      int x_a, x_b;
-      asm volatile("mov.b64 {%0,%1},%2;" : "=r"(x_a), "=r"(x_b) : "d"(x));
-      auto tmp_a = __nvvm_shfl_sync_bfly_i32(MemberMask, x_a, -1, i);
-      auto tmp_b = __nvvm_shfl_sync_bfly_i32(MemberMask, x_b, -1, i);
-      asm volatile("mov.b64 %0,{%1,%2};" : "=d"(tmp) : "r"(tmp_a), "r"(tmp_b));
-    } else if constexpr (std::is_same_v<T, long> ||
-                         std::is_same_v<T, unsigned long>) {
-      int x_a, x_b;
-      asm volatile("mov.b64 {%0,%1},%2;" : "=r"(x_a), "=r"(x_b) : "l"(x));
-      auto tmp_a = __nvvm_shfl_sync_bfly_i32(MemberMask, x_a, -1, i);
-      auto tmp_b = __nvvm_shfl_sync_bfly_i32(MemberMask, x_b, -1, i);
-      asm volatile("mov.b64 %0,{%1,%2};" : "=l"(tmp) : "r"(tmp_a), "r"(tmp_b));
-    } else if constexpr (std::is_same_v<T, half>) {
-      short tmp_b16;
-      asm volatile("mov.b16 %0,%1;" : "=h"(tmp_b16) : "h"(x));
-      auto tmp_b32 = __nvvm_shfl_sync_bfly_i32(
-          MemberMask, static_cast<int>(tmp_b16), -1, i);
-      asm volatile("mov.b16 %0,%1;"
-                   : "=h"(tmp)
-                   : "h"(static_cast<short>(tmp_b32)));
-    } else if constexpr (std::is_same_v<T, float>) {
-      auto tmp_b32 =
-          __nvvm_shfl_sync_bfly_i32(MemberMask, __nvvm_bitcast_f2i(x), -1, i);
-      tmp = __nvvm_bitcast_i2f(tmp_b32);
-    } else {
-      tmp = __nvvm_shfl_sync_bfly_i32(MemberMask, x, -1, i);
-    }
+    T tmp = cuda_shfl_sync_bfly_i32(MemberMask, x, i, 0x1f);
     x = binary_op(x, tmp);
   }
   return x;
 }
 
-template <typename Group, typename T>
-inline __SYCL_ALWAYS_INLINE std::enable_if_t<
-    ext::oneapi::experimental::is_user_constructed_group_v<Group>, T>
-non_uniform_shfl_T(const uint32_t MemberMask, T x, int shfl_param) {
-  if constexpr (is_fixed_size_group_v<Group>) {
-    return __nvvm_shfl_sync_up_i32(MemberMask, x, shfl_param, 0);
-  } else {
-    return __nvvm_shfl_sync_idx_i32(MemberMask, x, shfl_param, 31);
-  }
-}
-
-template <typename Group, typename T>
-inline __SYCL_ALWAYS_INLINE std::enable_if_t<
-    ext::oneapi::experimental::is_user_constructed_group_v<Group>, T>
-non_uniform_shfl(Group g, const uint32_t MemberMask, T x, int shfl_param) {
-  T res;
-  if constexpr (std::is_same_v<T, double>) {
-    int x_a, x_b;
-    asm volatile("mov.b64 {%0,%1},%2;" : "=r"(x_a), "=r"(x_b) : "d"(x));
-    auto tmp_a = non_uniform_shfl_T<Group>(MemberMask, x_a, shfl_param);
-    auto tmp_b = non_uniform_shfl_T<Group>(MemberMask, x_b, shfl_param);
-    asm volatile("mov.b64 %0,{%1,%2};" : "=d"(res) : "r"(tmp_a), "r"(tmp_b));
-  } else if constexpr (std::is_same_v<T, long> ||
-                       std::is_same_v<T, unsigned long>) {
-    int x_a, x_b;
-    asm volatile("mov.b64 {%0,%1},%2;" : "=r"(x_a), "=r"(x_b) : "l"(x));
-    auto tmp_a = non_uniform_shfl_T<Group>(MemberMask, x_a, shfl_param);
-    auto tmp_b = non_uniform_shfl_T<Group>(MemberMask, x_b, shfl_param);
-    asm volatile("mov.b64 %0,{%1,%2};" : "=l"(res) : "r"(tmp_a), "r"(tmp_b));
-  } else if constexpr (std::is_same_v<T, half>) {
-    short tmp_b16;
-    asm volatile("mov.b16 %0,%1;" : "=h"(tmp_b16) : "h"(x));
-    auto tmp_b32 = non_uniform_shfl_T<Group>(
-        MemberMask, static_cast<int>(tmp_b16), shfl_param);
-    asm volatile("mov.b16 %0,%1;"
-                 : "=h"(res)
-                 : "h"(static_cast<short>(tmp_b32)));
-  } else if constexpr (std::is_same_v<T, float>) {
-    auto tmp_b32 = non_uniform_shfl_T<Group>(MemberMask, __nvvm_bitcast_f2i(x),
-                                             shfl_param);
-    res = __nvvm_bitcast_i2f(tmp_b32);
-  } else {
-    res = non_uniform_shfl_T<Group>(MemberMask, x, shfl_param);
-  }
-  return res;
-}
-
 // Opportunistic/Ballot group reduction using shfls
 template <typename Group, typename T, class BinaryOperation>
 inline __SYCL_ALWAYS_INLINE std::enable_if_t<
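For reference, the removed branches above show how each type was lowered onto the 32-bit __nvvm_shfl_sync_bfly_i32 builtin. Below is a minimal sketch of what a generic cuda_shfl_sync_bfly_i32 helper could look like, assuming it follows the same per-width strategy (split 64-bit values, widen 16-bit ones, bitcast float); the actual implementation in masked_shuffles.hpp may well differ:

#include <cstring>     // std::memcpy for the byte-level splits
#include <type_traits> // std::is_same_v

// Sketch only: generic butterfly shuffle built on the 32-bit NVVM builtin,
// reconstructed from the removed per-type branches. Device-only (__NVPTX__).
template <typename T>
inline T cuda_shfl_sync_bfly_i32_sketch(unsigned int mask, T val,
                                        unsigned int shfl_param, int c) {
  T res;
  if constexpr (sizeof(T) == 8) {
    // 64-bit types (double, long): shuffle the two 32-bit halves separately.
    int lo, hi;
    std::memcpy(&lo, &val, 4);
    std::memcpy(&hi, reinterpret_cast<const char *>(&val) + 4, 4);
    int lo_s = __nvvm_shfl_sync_bfly_i32(mask, lo, shfl_param, c);
    int hi_s = __nvvm_shfl_sync_bfly_i32(mask, hi, shfl_param, c);
    std::memcpy(&res, &lo_s, 4);
    std::memcpy(reinterpret_cast<char *>(&res) + 4, &hi_s, 4);
  } else if constexpr (sizeof(T) == 2) {
    // 16-bit types (half): widen to 32 bits around the shuffle.
    short v16;
    std::memcpy(&v16, &val, 2);
    int v32 =
        __nvvm_shfl_sync_bfly_i32(mask, static_cast<int>(v16), shfl_param, c);
    short r16 = static_cast<short>(v32);
    std::memcpy(&res, &r16, 2);
  } else if constexpr (std::is_same_v<T, float>) {
    // float: bitcast through a 32-bit integer.
    res = __nvvm_bitcast_i2f(
        __nvvm_shfl_sync_bfly_i32(mask, __nvvm_bitcast_f2i(val), shfl_param, c));
  } else {
    // 32-bit integral types shuffle directly.
    res = __nvvm_shfl_sync_bfly_i32(mask, val, shfl_param, c);
  }
  return res;
}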
@@ -207,8 +133,8 @@ masked_reduction_cuda_shfls(Group g, T x, BinaryOperation binary_op,
 
     // __nvvm_fns automatically wraps around to the correct bit position.
     // There is no performance impact on src_set_bit position wrt localSetBit
-    auto tmp = non_uniform_shfl(g, MemberMask, x,
-                                __nvvm_fns(MemberMask, 0, unfoldedSrcSetBit));
+    T tmp = cuda_shfl_sync_idx_i32(
+        MemberMask, x, __nvvm_fns(MemberMask, 0, unfoldedSrcSetBit), 31);
 
     if (!(localSetBit == 1 && remainder != 0)) {
       x = binary_op(x, tmp);
@@ -224,7 +150,8 @@ masked_reduction_cuda_shfls(Group g, T x, BinaryOperation binary_op,
                : "=r"(broadID)
                : "r"(MemberMask));
 
-  return non_uniform_shfl(g, MemberMask, x, broadID);
+  x = cuda_shfl_sync_idx_i32(MemberMask, x, broadID, 31);
+  return x;
 }
 
 // Non Redux types must fall back to shfl based implementations.
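The trailing argument of these helpers presumably maps onto the c operand of the PTX shfl.sync instruction, which packs a clamp lane and a segment mask: 31 (0x1f) addresses the whole warp for the idx and bfly variants, while 0 pins the segment start at lane 0 for the up variant used in the scans below. A thin, illustrative 32-bit wrapper under that assumption:

// Illustrative only: for plain 32-bit values the idx helper is presumably a
// direct veneer over the NVVM builtin. c == 31 means no segmentation and a
// clamp at lane 31, i.e. an ordinary indexed shuffle across the warp, with
// mask naming the lanes that participate.
inline int cuda_shfl_sync_idx_i32_sketch(unsigned int mask, int val,
                                         unsigned int src_lane, int c) {
  return __nvvm_shfl_sync_idx_i32(mask, val, src_lane, c);
}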
@@ -265,18 +192,19 @@ inline __SYCL_ALWAYS_INLINE
   return ~0;
 }
 
-#define GET_ID(OP_CHECK, OP)                                                   \
-  template <typename T, class BinaryOperation>                                 \
-  inline __SYCL_ALWAYS_INLINE                                                  \
-      std::enable_if_t<OP_CHECK<T, BinaryOperation>::value, T>                 \
-      get_identity() {                                                         \
-    return std::numeric_limits<T>::OP();                                       \
-  }
-
-GET_ID(IsMinimum, max)
-GET_ID(IsMaximum, min)
+template <typename T, class BinaryOperation>
+inline __SYCL_ALWAYS_INLINE
+    std::enable_if_t<IsMinimum<T, BinaryOperation>::value, T>
+    get_identity() {
+  return std::numeric_limits<T>::min();
+}
 
-#undef GET_ID
+template <typename T, class BinaryOperation>
+inline __SYCL_ALWAYS_INLINE
+    std::enable_if_t<IsMaximum<T, BinaryOperation>::value, T>
+    get_identity() {
+  return std::numeric_limits<T>::max();
+}
 
 // // Shuffle based masked reduction impls
@@ -288,13 +216,12 @@ masked_scan_cuda_shfls(Group g, T x, BinaryOperation binary_op,
                        const uint32_t MemberMask) {
   unsigned localIdVal = g.get_local_id()[0];
   for (int i = 1; i < g.get_local_range()[0]; i *= 2) {
-    auto tmp = non_uniform_shfl(g, MemberMask, x, i);
+    T tmp = cuda_shfl_sync_up_i32(MemberMask, x, i, 0);
     if (localIdVal >= i)
       x = binary_op(x, tmp);
   }
   if constexpr (Op == __spv::GroupOperation::ExclusiveScan) {
-
-    x = non_uniform_shfl(g, MemberMask, x, 1);
+    x = cuda_shfl_sync_up_i32(MemberMask, x, 1, 0);
     if (localIdVal == 0) {
       return get_identity<T, BinaryOperation>();
     }
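The loop above is the classic Hillis-Steele inclusive scan: in round i every member combines its value with the value held i members below it, so after ceil(log2(range)) rounds each member holds the prefix over all lower-indexed members and itself; the exclusive variant then shifts the result up by one member and substitutes the identity at member 0. A small host-side illustration of the same recurrence with op = +:

#include <cstdio>
#include <vector>

// Host-side illustration of the Hillis-Steele recurrence used above: in each
// round, "lane" j reads the pre-round value of lane j - i (the shfl_up) and
// combines it when j >= i. After log2(n) rounds every lane holds the
// inclusive prefix.
int main() {
  std::vector<int> x = {3, 1, 4, 1, 5, 9, 2, 6}; // one value per lane
  const int n = static_cast<int>(x.size());
  for (int i = 1; i < n; i *= 2) {
    std::vector<int> prev = x; // a shuffle reads pre-round values
    for (int j = 0; j < n; ++j)
      if (j >= i)
        x[j] = x[j] + prev[j - i]; // x = binary_op(x, tmp) with op = +
  }
  for (int v : x)
    std::printf("%d ", v); // prints: 3 4 8 9 14 23 25 31 (inclusive sums)
  std::printf("\n");
  return 0;
}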
@@ -316,14 +243,15 @@ masked_scan_cuda_shfls(Group g, T x, BinaryOperation binary_op,
   for (int i = 1; i < g.get_local_range()[0]; i *= 2) {
     int unfoldedSrcSetBit = localSetBit - i;
 
-    auto tmp = non_uniform_shfl(g, MemberMask, x,
-                                __nvvm_fns(MemberMask, 0, unfoldedSrcSetBit));
+    T tmp = cuda_shfl_sync_idx_i32(
+        MemberMask, x, __nvvm_fns(MemberMask, 0, unfoldedSrcSetBit), 31);
+
     if (localIdVal >= i)
       x = binary_op(x, tmp);
   }
   if constexpr (Op == __spv::GroupOperation::ExclusiveScan) {
-    x = non_uniform_shfl(g, MemberMask, x,
-                         __nvvm_fns(MemberMask, 0, localSetBit - 1));
+    x = cuda_shfl_sync_idx_i32(MemberMask, x,
+                               __nvvm_fns(MemberMask, 0, localSetBit - 1), 31);
     if (localIdVal == 0) {
       return get_identity<T, BinaryOperation>();
     }