8
8
#pragma once
9
9
10
10
#include < sycl/detail/helpers.hpp> // for Builder
11
- #include < sycl/detail/type_traits.hpp> // for is_sub_group
12
11
#include < sycl/exception.hpp> // for errc, exception
12
+ #include < sycl/feature_test.hpp> // for SYCL_EXT_ONEAPI_SUB_GROUP_MASK
13
13
#include < sycl/id.hpp> // for id
14
14
#include < sycl/marray.hpp> // for marray
15
15
#include < sycl/types.hpp> // for vec
16
16
17
17
#include < assert.h> // for assert
18
18
#include < climits> // for CHAR_BIT
19
+ #include < cstring> // for memcpy
19
20
#include < stddef.h> // for size_t
20
21
#include < stdint.h> // for uint32_t
21
22
#include < system_error> // for error_code
@@ -34,25 +35,26 @@ template <typename Group> struct group_scope;
34
35
35
36
} // namespace detail
36
37
38
+ // forward declare sycl::sub_group
39
+ struct sub_group ;
40
+
37
41
namespace ext ::oneapi {
38
42
39
- #if defined(__SYCL_DEVICE_ONLY__) && defined(__AMDGCN__) && \
40
- (__AMDGCN_WAVEFRONT_SIZE == 64 )
41
- #define BITS_TYPE uint64_t
42
- #else
43
- #define BITS_TYPE uint32_t
44
- #endif
43
+ // forward declare sycl::ext::oneapi::sub_group
44
+ struct sub_group ;
45
45
46
46
// defining `group_ballot` here to make predicate default `true`
47
47
// need to forward declare sub_group_mask first
48
48
struct sub_group_mask ;
49
49
template <typename Group>
50
- std::enable_if_t <sycl::detail::is_sub_group<Group>::value, sub_group_mask>
50
+ std::enable_if_t <std::is_same_v<std::decay_t <Group>, sub_group> ||
51
+ std::is_same_v<std::decay_t <Group>, sycl::sub_group>,
52
+ sub_group_mask>
51
53
group_ballot (Group g, bool predicate = true );
52
54
53
55
struct sub_group_mask {
54
56
friend class sycl ::detail::Builder;
55
- using BitsType = BITS_TYPE ;
57
+ using BitsType = uint64_t ;
56
58
57
59
static constexpr size_t max_bits =
58
60
sizeof (BitsType) * CHAR_BIT /* implementation-defined */ ;
@@ -80,7 +82,8 @@ struct sub_group_mask {
80
82
}
81
83
82
84
// Binds this proxy to the storage word of `gmask` and selects bit `pos`.
//
// gmask: mask whose `Bits` member this proxy refers to.
// pos:   zero-based bit position. When `pos` is not below `gmask.bits_num`
//        the selector `RefBit` is left as 0, i.e. no bit is selected.
reference(sub_group_mask &gmask, size_t pos) : Ref(gmask.Bits) {
  // Shift a BitsType-wide one. A `1UL` literal is only as wide as
  // `unsigned long`, which may be narrower than the 64-bit BitsType.
  const BitsType One = 1;
  RefBit = (pos < gmask.bits_num) ? (One << pos) : 0;
}
85
88
86
89
private:
@@ -90,8 +93,37 @@ struct sub_group_mask {
90
93
BitsType RefBit;
91
94
};
92
95
96
+ #if SYCL_EXT_ONEAPI_SUB_GROUP_MASK >= 2
97
+ sub_group_mask () : sub_group_mask(0 , GetMaxLocalRangeSize()){};
98
+
99
+ sub_group_mask (unsigned long long val)
100
+ : sub_group_mask(0 , GetMaxLocalRangeSize()) {
101
+ Bits = val;
102
+ };
103
+
104
+ template <typename T, std::size_t K,
105
+ typename = std::enable_if_t <std::is_integral_v<T>>>
106
+ sub_group_mask (const sycl::marray<T, K> &val)
107
+ : sub_group_mask(0 , GetMaxLocalRangeSize()) {
108
+ for (size_t I = 0 , BytesCopied = 0 ; I < K && BytesCopied < sizeof (Bits);
109
+ ++I) {
110
+ size_t RemainingBytes = sizeof (Bits) - BytesCopied;
111
+ size_t BytesToCopy =
112
+ RemainingBytes < sizeof (T) ? RemainingBytes : sizeof (T);
113
+ // TODO: memcpy is not guaranteed to work in kernels. Find alternative.
114
+ std::memcpy (reinterpret_cast <char *>(&Bits) + BytesCopied, &val[I],
115
+ BytesToCopy);
116
+ BytesCopied += BytesToCopy;
117
+ }
118
+ }
119
+
120
+ sub_group_mask (const sub_group_mask &other) = default;
121
+ sub_group_mask &operator =(const sub_group_mask &other) = default ;
122
+ #endif // SYCL_EXT_ONEAPI_SUB_GROUP_MASK
123
+
93
124
// Read-only bit access.
//
// id: one-dimensional index; `id.get(0)` is the bit position to query.
// Returns true when the position is below `bits_num` and that bit is set in
// `Bits`; returns false otherwise, including for out-of-range positions.
bool operator[](id<1> id) const {
  const size_t Pos = id.get(0);
  if (Pos >= bits_num)
    return false;
  // Shift a BitsType-wide one. A `1UL` literal is only as wide as
  // `unsigned long`, which may be narrower than the 64-bit BitsType.
  const BitsType One = 1;
  return (Bits & (One << Pos)) != 0;
}
96
128
97
129
// Mutable bit access: returns a `reference` proxy constructed from this mask
// and the bit position `id.get(0)`.
reference operator[](id<1> id) {
  return reference(*this, id.get(0));
}
@@ -253,10 +285,6 @@ struct sub_group_mask {
253
285
return Tmp;
254
286
}
255
287
256
- sub_group_mask (const sub_group_mask &rhs) = default ;
257
-
258
- sub_group_mask &operator =(const sub_group_mask &rhs) = default ;
259
-
260
288
template <typename Group>
261
289
friend std::enable_if_t <std::is_same_v<std::decay_t <Group>, sub_group>,
262
290
sub_group_mask>
@@ -284,6 +312,14 @@ struct sub_group_mask {
284
312
}
285
313
286
314
private:
315
+ static size_t GetMaxLocalRangeSize () {
316
+ #ifdef __SYCL_DEVICE_ONLY__
317
+ return __spirv_SubgroupMaxSize ();
318
+ #else
319
+ return max_bits;
320
+ #endif
321
+ }
322
+
287
323
sub_group_mask (BitsType rhs, size_t bn)
288
324
: Bits(rhs & valuable_bits (bn)), bits_num(bn) {
289
325
assert (bits_num <= max_bits);
@@ -301,15 +337,17 @@ struct sub_group_mask {
301
337
};
302
338
303
339
template <typename Group>
304
- std::enable_if_t <sycl::detail::is_sub_group<Group>::value, sub_group_mask>
340
+ std::enable_if_t <std::is_same_v<std::decay_t <Group>, sub_group> ||
341
+ std::is_same_v<std::decay_t <Group>, sycl::sub_group>,
342
+ sub_group_mask>
305
343
group_ballot (Group g, bool predicate) {
306
344
(void )g;
307
345
#ifdef __SYCL_DEVICE_ONLY__
308
346
auto res = __spirv_GroupNonUniformBallot (
309
347
sycl::detail::spirv::group_scope<Group>::value, predicate);
310
- BITS_TYPE val = res[0 ];
311
- if constexpr (sizeof (BITS_TYPE ) == 8 )
312
- val |= ((BITS_TYPE )res[1 ]) << 32 ;
348
+ sub_group_mask::BitsType val = res[0 ];
349
+ if constexpr (sizeof (sub_group_mask::BitsType ) == 8 )
350
+ val |= ((sub_group_mask::BitsType )res[1 ]) << 32 ;
313
351
return sycl::detail::Builder::createSubGroupMask<sub_group_mask>(
314
352
val, g.get_max_local_range ()[0 ]);
315
353
#else
@@ -319,8 +357,6 @@ group_ballot(Group g, bool predicate) {
319
357
#endif
320
358
}
321
359
322
- #undef BITS_TYPE
323
-
324
360
} // namespace ext::oneapi
325
361
} // namespace _V1
326
362
} // namespace sycl
0 commit comments