Skip to content

Commit 10d50ed

Browse files
authored
[SYCL] Enable sub group masks for 64 bit subgroups (#7491)
This patch is adding group ballot support for HIP (based on initial work from @abagusetty on #6734 ), but also extending the sub-group mask implementation to support 64 bit masks, as a lot of AMD GPUs use 64 bit wavefronts. Related to issue: #6718
1 parent a578c81 commit 10d50ed

File tree

4 files changed

+68
-15
lines changed

4 files changed

+68
-15
lines changed

libclc/amdgcn-amdhsa/libspirv/SOURCES

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11

22
workitem/get_global_offset.ll
3+
group/group_ballot.cl
34
group/collectives.cl
45
group/collectives_helpers.ll
56
atomic/loadstore_helpers.ll
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include <spirv/spirv.h>
10+
#include <spirv/spirv_types.h>
11+
12+
// from llvm/include/llvm/IR/InstrTypes.h
13+
#define ICMP_NE 33
14+
15+
_CLC_DEF _CLC_CONVERGENT __clc_vec4_uint32_t
16+
_Z29__spirv_GroupNonUniformBallotjb(unsigned flag, bool predicate) {
17+
// only support subgroup for now
18+
if (flag != Subgroup) {
19+
__builtin_trap();
20+
__builtin_unreachable();
21+
}
22+
23+
// prepare result, we only support the ballot operation on 64 threads maximum
24+
// so we only need the first two elements to represent the final mask
25+
__clc_vec4_uint32_t res;
26+
res[2] = 0;
27+
res[3] = 0;
28+
29+
// run the ballot operation
30+
res.xy = __builtin_amdgcn_uicmp((int)predicate, 0, ICMP_NE);
31+
32+
return res;
33+
}

sycl/include/sycl/detail/helpers.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,8 @@ class Builder {
8282
return group<Dims>(Global, Local, Global / Local, Index);
8383
}
8484

85-
template <class ResType>
86-
static ResType createSubGroupMask(uint32_t Bits, size_t BitsNum) {
85+
template <class ResType, typename BitsType>
86+
static ResType createSubGroupMask(BitsType Bits, size_t BitsNum) {
8787
return ResType(Bits, BitsNum);
8888
}
8989

sycl/include/sycl/ext/oneapi/sub_group_mask.hpp

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,19 @@ class Builder;
2323
namespace ext {
2424
namespace oneapi {
2525

26+
#if defined(__SYCL_DEVICE_ONLY__) && defined(__AMDGCN__) && \
27+
(__AMDGCN_WAVEFRONT_SIZE == 64)
28+
#define BITS_TYPE uint64_t
29+
#else
30+
#define BITS_TYPE uint32_t
31+
#endif
32+
2633
struct sub_group_mask {
2734
friend class detail::Builder;
28-
static constexpr size_t max_bits = 32 /* implementation-defined */;
35+
using BitsType = BITS_TYPE;
36+
37+
static constexpr size_t max_bits =
38+
sizeof(BitsType) * CHAR_BIT /* implementation-defined */;
2939
static constexpr size_t word_size = sizeof(uint32_t) * CHAR_BIT;
3040

3141
// enable reference to individual bit
@@ -55,9 +65,9 @@ struct sub_group_mask {
5565

5666
private:
5767
// Reference to the word containing the bit
58-
uint32_t &Ref;
68+
BitsType &Ref;
5969
// Bit mask where only referenced bit is set
60-
uint32_t RefBit;
70+
BitsType RefBit;
6171
};
6272

6373
bool operator[](id<1> id) const {
@@ -96,9 +106,9 @@ struct sub_group_mask {
96106
typename = sycl::detail::enable_if_t<std::is_integral<Type>::value>>
97107
void insert_bits(Type bits, id<1> pos = 0) {
98108
size_t insert_size = sizeof(Type) * CHAR_BIT;
99-
uint32_t insert_data = (uint32_t)bits;
109+
BitsType insert_data = (BitsType)bits;
100110
insert_data <<= pos.get(0);
101-
uint32_t mask = 0;
111+
BitsType mask = 0;
102112
if (pos.get(0) + insert_size < size())
103113
mask |= (valuable_bits(bits_num) << (pos.get(0) + insert_size));
104114
if (pos.get(0) < size() && pos.get(0))
@@ -108,8 +118,8 @@ struct sub_group_mask {
108118
}
109119

110120
/* The bits are stored in the memory in the following way:
111-
marray id | 0 | 1 | 2 | 3 |
112-
bit id |7 .. 0|15 .. 8|23 .. 16|31 .. 24|
121+
marray id | 0 | 1 | 2 | 3 |...
122+
bit id |7 .. 0|15 .. 8|23 .. 16|31 .. 24|...
113123
*/
114124
template <typename Type, size_t Size,
115125
typename = sycl::detail::enable_if_t<std::is_integral<Type>::value>>
@@ -158,7 +168,7 @@ struct sub_group_mask {
158168

159169
void set() { Bits = valuable_bits(bits_num); }
160170
void set(id<1> id, bool value = true) { operator[](id) = value; }
161-
void reset() { Bits = uint32_t{0}; }
171+
void reset() { Bits = BitsType{0}; }
162172
void reset(id<1> id) { operator[](id) = 0; }
163173
void reset_low() { reset(find_low()); }
164174
void reset_high() { reset(find_high()); }
@@ -240,13 +250,17 @@ struct sub_group_mask {
240250
}
241251

242252
private:
243-
sub_group_mask(uint32_t rhs, size_t bn) : Bits(rhs), bits_num(bn) {
253+
sub_group_mask(BitsType rhs, size_t bn) : Bits(rhs), bits_num(bn) {
244254
assert(bits_num <= max_bits);
245255
}
246-
inline uint32_t valuable_bits(size_t bn) const {
247-
return static_cast<uint32_t>((1ULL << bn) - 1ULL);
256+
inline BitsType valuable_bits(size_t bn) const {
257+
assert(bn <= max_bits);
258+
BitsType one = 1;
259+
if (bn == max_bits)
260+
return -one;
261+
return (one << bn) - one;
248262
}
249-
uint32_t Bits;
263+
BitsType Bits;
250264
// Number of valuable bits
251265
size_t bits_num;
252266
};
@@ -259,15 +273,20 @@ group_ballot(Group g, bool predicate) {
259273
#ifdef __SYCL_DEVICE_ONLY__
260274
auto res = __spirv_GroupNonUniformBallot(
261275
detail::spirv::group_scope<Group>::value, predicate);
276+
BITS_TYPE val = res[0];
277+
if constexpr (sizeof(BITS_TYPE) == 8)
278+
val |= ((BITS_TYPE)res[1]) << 32;
262279
return detail::Builder::createSubGroupMask<sub_group_mask>(
263-
res[0], g.get_max_local_range()[0]);
280+
val, g.get_max_local_range()[0]);
264281
#else
265282
(void)predicate;
266283
throw exception{errc::feature_not_supported,
267284
"Sub-group mask is not supported on host device"};
268285
#endif
269286
}
270287

288+
#undef BITS_TYPE
289+
271290
} // namespace oneapi
272291
} // namespace ext
273292
} // __SYCL_INLINE_VER_NAMESPACE(_V1)

0 commit comments

Comments
 (0)