9
9
10
10
#include < sycl/detail/helpers.hpp> // for Builder
11
11
#include < sycl/detail/memcpy.hpp> // detail::memcpy
12
+ #include < sycl/detail/type_traits.hpp> // for is_sub_group
12
13
#include < sycl/exception.hpp> // for errc, exception
13
- #include < sycl/feature_test.hpp> // for SYCL_EXT_ONEAPI_SUB_GROUP_MASK
14
14
#include < sycl/id.hpp> // for id
15
15
#include < sycl/marray.hpp> // for marray
16
16
#include < sycl/types.hpp> // for vec
@@ -35,26 +35,25 @@ template <typename Group> struct group_scope;
35
35
36
36
} // namespace detail
37
37
38
- // forward declare sycl::sub_group
39
- struct sub_group ;
40
-
41
38
namespace ext ::oneapi {
42
39
43
- // forward declare sycl::ext::oneapi::sub_group
44
- struct sub_group ;
40
+ #if defined(__SYCL_DEVICE_ONLY__) && defined(__AMDGCN__) && \
41
+ (__AMDGCN_WAVEFRONT_SIZE == 64 )
42
+ #define BITS_TYPE uint64_t
43
+ #else
44
+ #define BITS_TYPE uint32_t
45
+ #endif
45
46
46
47
// defining `group_ballot` here to make predicate default `true`
47
48
// need to forward declare sub_group_mask first
48
49
struct sub_group_mask ;
49
50
template <typename Group>
50
- std::enable_if_t <std::is_same_v<std::decay_t <Group>, sub_group> ||
51
- std::is_same_v<std::decay_t <Group>, sycl::sub_group>,
52
- sub_group_mask>
51
+ std::enable_if_t <sycl::detail::is_sub_group<Group>::value, sub_group_mask>
53
52
group_ballot (Group g, bool predicate = true );
54
53
55
54
struct sub_group_mask {
56
55
friend class sycl ::detail::Builder;
57
- using BitsType = uint64_t ;
56
+ using BitsType = BITS_TYPE ;
58
57
59
58
static constexpr size_t max_bits =
60
59
sizeof (BitsType) * CHAR_BIT /* implementation-defined */ ;
@@ -82,8 +81,7 @@ struct sub_group_mask {
82
81
}
83
82
84
83
reference (sub_group_mask &gmask, size_t pos) : Ref(gmask.Bits) {
85
- BitsType one = 1 ;
86
- RefBit = (pos < gmask.bits_num ) ? (one << pos) : 0 ;
84
+ RefBit = (pos < gmask.bits_num ) ? (1UL << pos) : 0 ;
87
85
}
88
86
89
87
private:
@@ -93,36 +91,8 @@ struct sub_group_mask {
93
91
BitsType RefBit;
94
92
};
95
93
96
- #if SYCL_EXT_ONEAPI_SUB_GROUP_MASK >= 2
97
- sub_group_mask () : sub_group_mask(0 , GetMaxLocalRangeSize()){};
98
-
99
- sub_group_mask (unsigned long long val)
100
- : sub_group_mask(0 , GetMaxLocalRangeSize()) {
101
- Bits = val;
102
- };
103
-
104
- template <typename T, std::size_t K,
105
- typename = std::enable_if_t <std::is_integral_v<T>>>
106
- sub_group_mask (const sycl::marray<T, K> &val)
107
- : sub_group_mask(0 , GetMaxLocalRangeSize()) {
108
- for (size_t I = 0 , BytesCopied = 0 ; I < K && BytesCopied < sizeof (Bits);
109
- ++I) {
110
- size_t RemainingBytes = sizeof (Bits) - BytesCopied;
111
- size_t BytesToCopy =
112
- RemainingBytes < sizeof (T) ? RemainingBytes : sizeof (T);
113
- sycl::detail::memcpy (reinterpret_cast <char *>(&Bits) + BytesCopied,
114
- &val[I], BytesToCopy);
115
- BytesCopied += BytesToCopy;
116
- }
117
- }
118
-
119
- sub_group_mask (const sub_group_mask &other) = default;
120
- sub_group_mask &operator =(const sub_group_mask &other) = default ;
121
- #endif // SYCL_EXT_ONEAPI_SUB_GROUP_MASK
122
-
123
94
bool operator [](id<1 > id) const {
124
- BitsType one = 1 ;
125
- return (Bits & ((id.get (0 ) < bits_num) ? (one << id.get (0 )) : 0 ));
95
+ return (Bits & ((id.get (0 ) < bits_num) ? (1UL << id.get (0 )) : 0 ));
126
96
}
127
97
128
98
reference operator [](id<1 > id) { return {*this , id.get (0 )}; }
@@ -284,6 +254,10 @@ struct sub_group_mask {
284
254
return Tmp;
285
255
}
286
256
257
+ sub_group_mask (const sub_group_mask &rhs) = default ;
258
+
259
+ sub_group_mask &operator =(const sub_group_mask &rhs) = default ;
260
+
287
261
template <typename Group>
288
262
friend std::enable_if_t <std::is_same_v<std::decay_t <Group>, sub_group>,
289
263
sub_group_mask>
@@ -311,14 +285,6 @@ struct sub_group_mask {
311
285
}
312
286
313
287
private:
314
- static size_t GetMaxLocalRangeSize () {
315
- #ifdef __SYCL_DEVICE_ONLY__
316
- return __spirv_SubgroupMaxSize ();
317
- #else
318
- return max_bits;
319
- #endif
320
- }
321
-
322
288
sub_group_mask (BitsType rhs, size_t bn)
323
289
: Bits(rhs & valuable_bits (bn)), bits_num(bn) {
324
290
assert (bits_num <= max_bits);
@@ -336,17 +302,15 @@ struct sub_group_mask {
336
302
};
337
303
338
304
template <typename Group>
339
- std::enable_if_t <std::is_same_v<std::decay_t <Group>, sub_group> ||
340
- std::is_same_v<std::decay_t <Group>, sycl::sub_group>,
341
- sub_group_mask>
305
+ std::enable_if_t <sycl::detail::is_sub_group<Group>::value, sub_group_mask>
342
306
group_ballot (Group g, bool predicate) {
343
307
(void )g;
344
308
#ifdef __SYCL_DEVICE_ONLY__
345
309
auto res = __spirv_GroupNonUniformBallot (
346
310
sycl::detail::spirv::group_scope<Group>::value, predicate);
347
- sub_group_mask::BitsType val = res[0 ];
348
- if constexpr (sizeof (sub_group_mask::BitsType ) == 8 )
349
- val |= ((sub_group_mask::BitsType )res[1 ]) << 32 ;
311
+ BITS_TYPE val = res[0 ];
312
+ if constexpr (sizeof (BITS_TYPE ) == 8 )
313
+ val |= ((BITS_TYPE )res[1 ]) << 32 ;
350
314
return sycl::detail::Builder::createSubGroupMask<sub_group_mask>(
351
315
val, g.get_max_local_range ()[0 ]);
352
316
#else
@@ -356,6 +320,8 @@ group_ballot(Group g, bool predicate) {
356
320
#endif
357
321
}
358
322
323
+ #undef BITS_TYPE
324
+
359
325
} // namespace ext::oneapi
360
326
} // namespace _V1
361
327
} // namespace sycl
0 commit comments