@@ -23,9 +23,19 @@ class Builder;
23
23
namespace ext {
24
24
namespace oneapi {
25
25
26
+ #if defined(__SYCL_DEVICE_ONLY__) && defined(__AMDGCN__) && \
27
+ (__AMDGCN_WAVEFRONT_SIZE == 64 )
28
+ #define BITS_TYPE uint64_t
29
+ #else
30
+ #define BITS_TYPE uint32_t
31
+ #endif
32
+
26
33
struct sub_group_mask {
27
34
friend class detail ::Builder;
28
- static constexpr size_t max_bits = 32 /* implementation-defined */ ;
35
+ using BitsType = BITS_TYPE;
36
+
37
+ static constexpr size_t max_bits =
38
+ sizeof (BitsType) * CHAR_BIT /* implementation-defined */ ;
29
39
static constexpr size_t word_size = sizeof (uint32_t ) * CHAR_BIT;
30
40
31
41
// enable reference to individual bit
@@ -55,9 +65,9 @@ struct sub_group_mask {
55
65
56
66
private:
57
67
// Reference to the word containing the bit
58
- uint32_t &Ref;
68
+ BitsType &Ref;
59
69
// Bit mask where only referenced bit is set
60
- uint32_t RefBit;
70
+ BitsType RefBit;
61
71
};
62
72
63
73
bool operator [](id<1 > id) const {
@@ -96,9 +106,9 @@ struct sub_group_mask {
96
106
typename = sycl::detail::enable_if_t <std::is_integral<Type>::value>>
97
107
void insert_bits (Type bits, id<1 > pos = 0 ) {
98
108
size_t insert_size = sizeof (Type) * CHAR_BIT;
99
- uint32_t insert_data = (uint32_t )bits;
109
+ BitsType insert_data = (BitsType )bits;
100
110
insert_data <<= pos.get (0 );
101
- uint32_t mask = 0 ;
111
+ BitsType mask = 0 ;
102
112
if (pos.get (0 ) + insert_size < size ())
103
113
mask |= (valuable_bits (bits_num) << (pos.get (0 ) + insert_size));
104
114
if (pos.get (0 ) < size () && pos.get (0 ))
@@ -108,8 +118,8 @@ struct sub_group_mask {
108
118
}
109
119
110
120
/* The bits are stored in the memory in the following way:
111
- marray id | 0 | 1 | 2 | 3 |
112
- bit id |7 .. 0|15 .. 8|23 .. 16|31 .. 24|
121
+ marray id | 0 | 1 | 2 | 3 |...
122
+ bit id |7 .. 0|15 .. 8|23 .. 16|31 .. 24|...
113
123
*/
114
124
template <typename Type, size_t Size,
115
125
typename = sycl::detail::enable_if_t <std::is_integral<Type>::value>>
@@ -158,7 +168,7 @@ struct sub_group_mask {
158
168
159
169
void set () { Bits = valuable_bits (bits_num); }
160
170
void set (id<1 > id, bool value = true ) { operator [](id) = value; }
161
- void reset () { Bits = uint32_t {0 }; }
171
+ void reset () { Bits = BitsType {0 }; }
162
172
void reset (id<1 > id) { operator [](id) = 0 ; }
163
173
void reset_low () { reset (find_low ()); }
164
174
void reset_high () { reset (find_high ()); }
@@ -240,13 +250,17 @@ struct sub_group_mask {
240
250
}
241
251
242
252
private:
243
- sub_group_mask (uint32_t rhs, size_t bn) : Bits(rhs), bits_num(bn) {
253
+ sub_group_mask (BitsType rhs, size_t bn) : Bits(rhs), bits_num(bn) {
244
254
assert (bits_num <= max_bits);
245
255
}
246
- inline uint32_t valuable_bits (size_t bn) const {
247
- return static_cast <uint32_t >((1ULL << bn) - 1ULL );
256
+ inline BitsType valuable_bits (size_t bn) const {
257
+ assert (bn <= max_bits);
258
+ BitsType one = 1 ;
259
+ if (bn == max_bits)
260
+ return -one;
261
+ return (one << bn) - one;
248
262
}
249
- uint32_t Bits;
263
+ BitsType Bits;
250
264
// Number of valuable bits
251
265
size_t bits_num;
252
266
};
@@ -259,15 +273,20 @@ group_ballot(Group g, bool predicate) {
259
273
#ifdef __SYCL_DEVICE_ONLY__
260
274
auto res = __spirv_GroupNonUniformBallot (
261
275
detail::spirv::group_scope<Group>::value, predicate);
276
+ BITS_TYPE val = res[0 ];
277
+ if constexpr (sizeof (BITS_TYPE) == 8 )
278
+ val |= ((BITS_TYPE)res[1 ]) << 32 ;
262
279
return detail::Builder::createSubGroupMask<sub_group_mask>(
263
- res[ 0 ] , g.get_max_local_range ()[0 ]);
280
+ val , g.get_max_local_range ()[0 ]);
264
281
#else
265
282
(void )predicate;
266
283
throw exception{errc::feature_not_supported,
267
284
" Sub-group mask is not supported on host device" };
268
285
#endif
269
286
}
270
287
288
+ #undef BITS_TYPE
289
+
271
290
} // namespace oneapi
272
291
} // namespace ext
273
292
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
0 commit comments