Skip to content

Commit f924aa0

Browse files
authored
[SYCL] Reintroduce sub_group_mask version 2 (#12404)
Was reverted by fb8c82d due to performance regressions on some devices. It turned out that the problem is not with this patch. So applying it again without any changes. Previously reviewed here: #11195
1 parent 24255a5 commit f924aa0

File tree

7 files changed

+179
-28
lines changed

7 files changed

+179
-28
lines changed

sycl/include/sycl/ext/oneapi/experimental/tangle_group.hpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,7 @@ get_tangle_group(Group group) {
153153
// TODO: Construct from compiler-generated mask. Return an invalid group in
154154
// in the meantime. CUDA devices will report false for the tangle_group
155155
// support aspect so kernels launch should ensure this is never run.
156-
return tangle_group<sycl::sub_group>(
157-
sycl::detail::Builder::createSubGroupMask<
158-
sycl::ext::oneapi::sub_group_mask>(0, 0));
156+
return tangle_group<sycl::sub_group>(0);
159157
#endif
160158
#else
161159
throw runtime_error("Non-uniform groups are not supported on host device.",

sycl/include/sycl/ext/oneapi/sub_group_mask.hpp

Lines changed: 55 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99

1010
#include <sycl/detail/helpers.hpp> // for Builder
1111
#include <sycl/detail/memcpy.hpp> // detail::memcpy
12-
#include <sycl/detail/type_traits.hpp> // for is_sub_group
1312
#include <sycl/exception.hpp> // for errc, exception
13+
#include <sycl/feature_test.hpp> // for SYCL_EXT_ONEAPI_SUB_GROUP_MASK
1414
#include <sycl/id.hpp> // for id
1515
#include <sycl/marray.hpp> // for marray
1616
#include <sycl/types.hpp> // for vec
@@ -35,25 +35,26 @@ template <typename Group> struct group_scope;
3535

3636
} // namespace detail
3737

38+
// forward decalre sycl::sub_group
39+
struct sub_group;
40+
3841
namespace ext::oneapi {
3942

40-
#if defined(__SYCL_DEVICE_ONLY__) && defined(__AMDGCN__) && \
41-
(__AMDGCN_WAVEFRONT_SIZE == 64)
42-
#define BITS_TYPE uint64_t
43-
#else
44-
#define BITS_TYPE uint32_t
45-
#endif
43+
// forward decalre sycl::ext::oneapi::sub_group
44+
struct sub_group;
4645

4746
// defining `group_ballot` here to make predicate default `true`
4847
// need to forward declare sub_group_mask first
4948
struct sub_group_mask;
5049
template <typename Group>
51-
std::enable_if_t<sycl::detail::is_sub_group<Group>::value, sub_group_mask>
50+
std::enable_if_t<std::is_same_v<std::decay_t<Group>, sub_group> ||
51+
std::is_same_v<std::decay_t<Group>, sycl::sub_group>,
52+
sub_group_mask>
5253
group_ballot(Group g, bool predicate = true);
5354

5455
struct sub_group_mask {
5556
friend class sycl::detail::Builder;
56-
using BitsType = BITS_TYPE;
57+
using BitsType = uint64_t;
5758

5859
static constexpr size_t max_bits =
5960
sizeof(BitsType) * CHAR_BIT /* implementation-defined */;
@@ -81,7 +82,8 @@ struct sub_group_mask {
8182
}
8283

8384
reference(sub_group_mask &gmask, size_t pos) : Ref(gmask.Bits) {
84-
RefBit = (pos < gmask.bits_num) ? (1UL << pos) : 0;
85+
BitsType one = 1;
86+
RefBit = (pos < gmask.bits_num) ? (one << pos) : 0;
8587
}
8688

8789
private:
@@ -91,8 +93,36 @@ struct sub_group_mask {
9193
BitsType RefBit;
9294
};
9395

96+
#if SYCL_EXT_ONEAPI_SUB_GROUP_MASK >= 2
97+
sub_group_mask() : sub_group_mask(0, GetMaxLocalRangeSize()){};
98+
99+
sub_group_mask(unsigned long long val)
100+
: sub_group_mask(0, GetMaxLocalRangeSize()) {
101+
Bits = val;
102+
};
103+
104+
template <typename T, std::size_t K,
105+
typename = std::enable_if_t<std::is_integral_v<T>>>
106+
sub_group_mask(const sycl::marray<T, K> &val)
107+
: sub_group_mask(0, GetMaxLocalRangeSize()) {
108+
for (size_t I = 0, BytesCopied = 0; I < K && BytesCopied < sizeof(Bits);
109+
++I) {
110+
size_t RemainingBytes = sizeof(Bits) - BytesCopied;
111+
size_t BytesToCopy =
112+
RemainingBytes < sizeof(T) ? RemainingBytes : sizeof(T);
113+
sycl::detail::memcpy(reinterpret_cast<char *>(&Bits) + BytesCopied,
114+
&val[I], BytesToCopy);
115+
BytesCopied += BytesToCopy;
116+
}
117+
}
118+
119+
sub_group_mask(const sub_group_mask &other) = default;
120+
sub_group_mask &operator=(const sub_group_mask &other) = default;
121+
#endif // SYCL_EXT_ONEAPI_SUB_GROUP_MASK
122+
94123
bool operator[](id<1> id) const {
95-
return (Bits & ((id.get(0) < bits_num) ? (1UL << id.get(0)) : 0));
124+
BitsType one = 1;
125+
return (Bits & ((id.get(0) < bits_num) ? (one << id.get(0)) : 0));
96126
}
97127

98128
reference operator[](id<1> id) { return {*this, id.get(0)}; }
@@ -254,10 +284,6 @@ struct sub_group_mask {
254284
return Tmp;
255285
}
256286

257-
sub_group_mask(const sub_group_mask &rhs) = default;
258-
259-
sub_group_mask &operator=(const sub_group_mask &rhs) = default;
260-
261287
template <typename Group>
262288
friend std::enable_if_t<std::is_same_v<std::decay_t<Group>, sub_group>,
263289
sub_group_mask>
@@ -285,6 +311,14 @@ struct sub_group_mask {
285311
}
286312

287313
private:
314+
static size_t GetMaxLocalRangeSize() {
315+
#ifdef __SYCL_DEVICE_ONLY__
316+
return __spirv_SubgroupMaxSize();
317+
#else
318+
return max_bits;
319+
#endif
320+
}
321+
288322
sub_group_mask(BitsType rhs, size_t bn)
289323
: Bits(rhs & valuable_bits(bn)), bits_num(bn) {
290324
assert(bits_num <= max_bits);
@@ -302,15 +336,17 @@ struct sub_group_mask {
302336
};
303337

304338
template <typename Group>
305-
std::enable_if_t<sycl::detail::is_sub_group<Group>::value, sub_group_mask>
339+
std::enable_if_t<std::is_same_v<std::decay_t<Group>, sub_group> ||
340+
std::is_same_v<std::decay_t<Group>, sycl::sub_group>,
341+
sub_group_mask>
306342
group_ballot(Group g, bool predicate) {
307343
(void)g;
308344
#ifdef __SYCL_DEVICE_ONLY__
309345
auto res = __spirv_GroupNonUniformBallot(
310346
sycl::detail::spirv::group_scope<Group>::value, predicate);
311-
BITS_TYPE val = res[0];
312-
if constexpr (sizeof(BITS_TYPE) == 8)
313-
val |= ((BITS_TYPE)res[1]) << 32;
347+
sub_group_mask::BitsType val = res[0];
348+
if constexpr (sizeof(sub_group_mask::BitsType) == 8)
349+
val |= ((sub_group_mask::BitsType)res[1]) << 32;
314350
return sycl::detail::Builder::createSubGroupMask<sub_group_mask>(
315351
val, g.get_max_local_range()[0]);
316352
#else
@@ -320,8 +356,6 @@ group_ballot(Group g, bool predicate) {
320356
#endif
321357
}
322358

323-
#undef BITS_TYPE
324-
325359
} // namespace ext::oneapi
326360
} // namespace _V1
327361
} // namespace sycl

sycl/include/syclcompat/memory.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
#include <utility>
4343

4444
#include <sycl/builtins.hpp>
45+
#include <sycl/ext/intel/experimental/usm_properties.hpp>
4546
#include <sycl/ext/oneapi/group_local_memory.hpp>
4647
#include <sycl/usm.hpp>
4748

sycl/source/feature_test.hpp.in

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ inline namespace _V1 {
3131
// TODO: Move these feature-test macros to compiler driver.
3232
#define SYCL_EXT_INTEL_DEVICE_INFO 6
3333
#define SYCL_EXT_ONEAPI_DEVICE_ARCHITECTURE 1
34-
#define SYCL_EXT_ONEAPI_SUB_GROUP_MASK 1
34+
#define SYCL_EXT_ONEAPI_SUB_GROUP_MASK 2
3535
#define SYCL_EXT_ONEAPI_LOCAL_MEMORY 1
3636
#define SYCL_EXT_ONEAPI_MATRIX 1
3737
#define SYCL_EXT_ONEAPI_ASSERT 1
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
// RUN: %{build} -o %t.out
2+
// RUN: %{run} %t.out
3+
4+
#include <iostream>
5+
#include <sycl/sycl.hpp>
6+
7+
#define TEST_ON_DEVICE(TEST_BODY) \
8+
{ \
9+
sycl::queue queue; \
10+
sycl::buffer<bool, 1> buf(&res, 1); \
11+
queue.submit([&](sycl::handler &h) { \
12+
auto acc = buf.get_access<sycl::access_mode::write>(h); \
13+
h.parallel_for(sycl::nd_range<1>(sycl::range<1>(1), sycl::range<1>(1)), \
14+
[=](sycl::nd_item<1> item) { TEST_BODY }); \
15+
}); \
16+
queue.wait(); \
17+
} \
18+
assert(res);
19+
20+
int main() {
21+
#if SYCL_EXT_ONEAPI_SUB_GROUP_MASK >= 2
22+
using mask_type = sycl::ext::oneapi::sub_group_mask;
23+
24+
// sub_group_mask()
25+
{
26+
mask_type mask;
27+
assert(mask.none() && mask_type::max_bits == mask.size());
28+
29+
bool res = false;
30+
// clang-format off
31+
TEST_ON_DEVICE(
32+
mask_type mask;
33+
auto sg = item.get_sub_group();
34+
acc[0] = mask.none() && (sg.get_max_local_range().size() == mask.size());
35+
)
36+
// clang-format on
37+
}
38+
// sub_group_mask(unsigned long long val)
39+
{
40+
unsigned long long val = 4815162342;
41+
mask_type mask(val);
42+
std::bitset<sizeof(val) * CHAR_BIT> bs(val);
43+
bool res = true;
44+
for (size_t i = 0;
45+
i < std::min(static_cast<size_t>(mask.size()), bs.size()); ++i)
46+
res &= mask[i] == bs[i];
47+
assert(res);
48+
49+
// clang-format off
50+
TEST_ON_DEVICE(
51+
mask_type mask(val);
52+
auto sg = item.get_sub_group();
53+
acc[0] = sg.get_max_local_range().size() == mask.size();
54+
for (size_t i = 0;
55+
i < sycl::min(static_cast<size_t>(mask.size()), bs.size());
56+
++i)
57+
acc[0] &= mask[i] == bs[i];
58+
)
59+
// clang-format on
60+
}
61+
// template <typename T, std::size_t K> sub_group_mask(const sycl::marray<T,
62+
// K>& &val)
63+
{
64+
sycl::marray<char, 4> marr{1, 2, 3, 4};
65+
mask_type mask(marr);
66+
std::bitset<CHAR_BIT> bs[4] = {1, 2, 3, 4};
67+
bool res = true;
68+
for (size_t i = 0; i < mask.size() && (i / CHAR_BIT) < 4; ++i)
69+
res &= mask[i] == bs[i / CHAR_BIT][i % CHAR_BIT];
70+
assert(res);
71+
72+
// clang-format off
73+
TEST_ON_DEVICE(
74+
mask_type mask(marr);
75+
auto sg = item.get_sub_group();
76+
acc[0] = sg.get_max_local_range().size() == mask.size();
77+
for (size_t i = 0; i < mask.size() && (i / CHAR_BIT) < 4; ++i)
78+
acc[0] &= mask[i] == bs[i / CHAR_BIT][i % CHAR_BIT];
79+
)
80+
// clang-format on
81+
}
82+
{
83+
// sub_group_mask(const sub_group_mask &other)
84+
unsigned long long val = 4815162342;
85+
mask_type mask1(val);
86+
mask_type mask2(mask1);
87+
assert(mask1 == mask2);
88+
89+
bool res = false;
90+
// clang-format off
91+
TEST_ON_DEVICE(
92+
mask_type mask1(val);
93+
mask_type mask2(mask1);
94+
acc[0] = mask1 == mask2;
95+
)
96+
// clang-format on
97+
}
98+
{
99+
// sub_group_mask& operator=(const sub_group_mask &other)
100+
unsigned long long val = 4815162342;
101+
mask_type mask1(val);
102+
mask_type mask2 = mask1;
103+
assert(mask1 == mask2);
104+
105+
bool res = false;
106+
// clang-format off
107+
TEST_ON_DEVICE(
108+
mask_type mask1(val);
109+
mask_type mask2 = mask1;
110+
acc[0] = mask1 == mask2;
111+
)
112+
// clang-format on
113+
}
114+
#else
115+
std::cout << "Test skipped due to unsupported extension." << std::endl;
116+
#endif
117+
118+
return 0;
119+
}

sycl/test/extensions/macro.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ constexpr bool backend_opencl_macro_defined = true;
99
constexpr bool backend_opencl_macro_defined = false;
1010
#endif
1111

12-
#if SYCL_EXT_ONEAPI_SUB_GROUP_MASK == 1
12+
#ifdef SYCL_EXT_ONEAPI_SUB_GROUP_MASK
1313
constexpr bool sub_group_mask_macro_defined = true;
1414
#else
1515
constexpr bool sub_group_mask_macro_defined = false;

sycl/test/extensions/sub_group_mask.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
// RUN: %clangxx -fsycl -fsycl-device-only -fsyntax-only %s
22
//
33
// This test is intended to check sycl::ext::oneapi::sub_group_mask interface.
4-
// There is a work in progress update to the spec: intel/llvm#8174
5-
// TODO: udpate this test once revision 2 of the extension is supported
4+
// test for spec ver.2: sycl/test-e2e/SubGroupMask/sub_group_mask_ver2.cpp
65

76
#include <sycl/sycl.hpp>
87

0 commit comments

Comments
 (0)