Skip to content

Commit 3bd09b9

Browse files
[SYCL] Implement sub_group_mask version 2 (#11195)
https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/supported/sycl_ext_oneapi_sub_group_mask.asciidoc --------- Signed-off-by: Larsen, Steffen <[email protected]> Co-authored-by: Larsen, Steffen <[email protected]>
1 parent 68af512 commit 3bd09b9

File tree

6 files changed

+180
-25
lines changed

6 files changed

+180
-25
lines changed

sycl/include/sycl/ext/oneapi/sub_group_mask.hpp

Lines changed: 57 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,15 @@
88
#pragma once
99

1010
#include <sycl/detail/helpers.hpp> // for Builder
11-
#include <sycl/detail/type_traits.hpp> // for is_sub_group
1211
#include <sycl/exception.hpp> // for errc, exception
12+
#include <sycl/feature_test.hpp> // for SYCL_EXT_ONEAPI_SUB_GROUP_MASK
1313
#include <sycl/id.hpp> // for id
1414
#include <sycl/marray.hpp> // for marray
1515
#include <sycl/types.hpp> // for vec
1616

1717
#include <assert.h> // for assert
1818
#include <climits> // for CHAR_BIT
19+
#include <cstring> // for memcpy
1920
#include <stddef.h> // for size_t
2021
#include <stdint.h> // for uint32_t
2122
#include <system_error> // for error_code
@@ -34,25 +35,26 @@ template <typename Group> struct group_scope;
3435

3536
} // namespace detail
3637

38+
// forward declare sycl::sub_group
39+
struct sub_group;
40+
3741
namespace ext::oneapi {
3842

39-
#if defined(__SYCL_DEVICE_ONLY__) && defined(__AMDGCN__) && \
40-
(__AMDGCN_WAVEFRONT_SIZE == 64)
41-
#define BITS_TYPE uint64_t
42-
#else
43-
#define BITS_TYPE uint32_t
44-
#endif
43+
// forward declare sycl::ext::oneapi::sub_group
44+
struct sub_group;
4545

4646
// defining `group_ballot` here to make predicate default `true`
4747
// need to forward declare sub_group_mask first
4848
struct sub_group_mask;
4949
template <typename Group>
50-
std::enable_if_t<sycl::detail::is_sub_group<Group>::value, sub_group_mask>
50+
std::enable_if_t<std::is_same_v<std::decay_t<Group>, sub_group> ||
51+
std::is_same_v<std::decay_t<Group>, sycl::sub_group>,
52+
sub_group_mask>
5153
group_ballot(Group g, bool predicate = true);
5254

5355
struct sub_group_mask {
5456
friend class sycl::detail::Builder;
55-
using BitsType = BITS_TYPE;
57+
using BitsType = uint64_t;
5658

5759
static constexpr size_t max_bits =
5860
sizeof(BitsType) * CHAR_BIT /* implementation-defined */;
@@ -80,7 +82,8 @@ struct sub_group_mask {
8082
}
8183

8284
reference(sub_group_mask &gmask, size_t pos) : Ref(gmask.Bits) {
83-
RefBit = (pos < gmask.bits_num) ? (1UL << pos) : 0;
85+
BitsType one = 1;
86+
RefBit = (pos < gmask.bits_num) ? (one << pos) : 0;
8487
}
8588

8689
private:
@@ -90,8 +93,37 @@ struct sub_group_mask {
9093
BitsType RefBit;
9194
};
9295

96+
#if SYCL_EXT_ONEAPI_SUB_GROUP_MASK >= 2
97+
sub_group_mask() : sub_group_mask(0, GetMaxLocalRangeSize()){};
98+
99+
sub_group_mask(unsigned long long val)
100+
: sub_group_mask(0, GetMaxLocalRangeSize()) {
101+
Bits = val;
102+
};
103+
104+
template <typename T, std::size_t K,
105+
typename = std::enable_if_t<std::is_integral_v<T>>>
106+
sub_group_mask(const sycl::marray<T, K> &val)
107+
: sub_group_mask(0, GetMaxLocalRangeSize()) {
108+
for (size_t I = 0, BytesCopied = 0; I < K && BytesCopied < sizeof(Bits);
109+
++I) {
110+
size_t RemainingBytes = sizeof(Bits) - BytesCopied;
111+
size_t BytesToCopy =
112+
RemainingBytes < sizeof(T) ? RemainingBytes : sizeof(T);
113+
// TODO: memcpy is not guaranteed to work in kernels. Find alternative.
114+
std::memcpy(reinterpret_cast<char *>(&Bits) + BytesCopied, &val[I],
115+
BytesToCopy);
116+
BytesCopied += BytesToCopy;
117+
}
118+
}
119+
120+
sub_group_mask(const sub_group_mask &other) = default;
121+
sub_group_mask &operator=(const sub_group_mask &other) = default;
122+
#endif // SYCL_EXT_ONEAPI_SUB_GROUP_MASK
123+
93124
bool operator[](id<1> id) const {
94-
return (Bits & ((id.get(0) < bits_num) ? (1UL << id.get(0)) : 0));
125+
BitsType one = 1;
126+
return (Bits & ((id.get(0) < bits_num) ? (one << id.get(0)) : 0));
95127
}
96128

97129
reference operator[](id<1> id) { return {*this, id.get(0)}; }
@@ -253,10 +285,6 @@ struct sub_group_mask {
253285
return Tmp;
254286
}
255287

256-
sub_group_mask(const sub_group_mask &rhs) = default;
257-
258-
sub_group_mask &operator=(const sub_group_mask &rhs) = default;
259-
260288
template <typename Group>
261289
friend std::enable_if_t<std::is_same_v<std::decay_t<Group>, sub_group>,
262290
sub_group_mask>
@@ -284,6 +312,14 @@ struct sub_group_mask {
284312
}
285313

286314
private:
315+
static size_t GetMaxLocalRangeSize() {
316+
#ifdef __SYCL_DEVICE_ONLY__
317+
return __spirv_SubgroupMaxSize();
318+
#else
319+
return max_bits;
320+
#endif
321+
}
322+
287323
sub_group_mask(BitsType rhs, size_t bn)
288324
: Bits(rhs & valuable_bits(bn)), bits_num(bn) {
289325
assert(bits_num <= max_bits);
@@ -301,15 +337,17 @@ struct sub_group_mask {
301337
};
302338

303339
template <typename Group>
304-
std::enable_if_t<sycl::detail::is_sub_group<Group>::value, sub_group_mask>
340+
std::enable_if_t<std::is_same_v<std::decay_t<Group>, sub_group> ||
341+
std::is_same_v<std::decay_t<Group>, sycl::sub_group>,
342+
sub_group_mask>
305343
group_ballot(Group g, bool predicate) {
306344
(void)g;
307345
#ifdef __SYCL_DEVICE_ONLY__
308346
auto res = __spirv_GroupNonUniformBallot(
309347
sycl::detail::spirv::group_scope<Group>::value, predicate);
310-
BITS_TYPE val = res[0];
311-
if constexpr (sizeof(BITS_TYPE) == 8)
312-
val |= ((BITS_TYPE)res[1]) << 32;
348+
sub_group_mask::BitsType val = res[0];
349+
if constexpr (sizeof(sub_group_mask::BitsType) == 8)
350+
val |= ((sub_group_mask::BitsType)res[1]) << 32;
313351
return sycl::detail::Builder::createSubGroupMask<sub_group_mask>(
314352
val, g.get_max_local_range()[0]);
315353
#else
@@ -319,8 +357,6 @@ group_ballot(Group g, bool predicate) {
319357
#endif
320358
}
321359

322-
#undef BITS_TYPE
323-
324360
} // namespace ext::oneapi
325361
} // namespace _V1
326362
} // namespace sycl

sycl/include/syclcompat/memory.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
#include <utility>
4343

4444
#include <sycl/builtins.hpp>
45+
#include <sycl/ext/intel/experimental/usm_properties.hpp>
4546
#include <sycl/ext/oneapi/group_local_memory.hpp>
4647
#include <sycl/usm.hpp>
4748

sycl/source/feature_test.hpp.in

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ inline namespace _V1 {
3131
// TODO: Move these feature-test macros to compiler driver.
3232
#define SYCL_EXT_INTEL_DEVICE_INFO 6
3333
#define SYCL_EXT_ONEAPI_DEVICE_ARCHITECTURE 1
34-
#define SYCL_EXT_ONEAPI_SUB_GROUP_MASK 1
34+
#define SYCL_EXT_ONEAPI_SUB_GROUP_MASK 2
3535
#define SYCL_EXT_ONEAPI_LOCAL_MEMORY 1
3636
#define SYCL_EXT_ONEAPI_MATRIX 1
3737
#define SYCL_EXT_ONEAPI_ASSERT 1
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
// RUN: %{build} -o %t.out
2+
// RUN: %{run} %t.out
3+
4+
#include <iostream>
5+
#include <sycl/sycl.hpp>
6+
7+
#define TEST_ON_DEVICE(TEST_BODY) \
8+
{ \
9+
sycl::queue queue; \
10+
sycl::buffer<bool, 1> buf(&res, 1); \
11+
queue.submit([&](sycl::handler &h) { \
12+
auto acc = buf.get_access<sycl::access_mode::write>(h); \
13+
h.parallel_for(sycl::nd_range<1>(sycl::range<1>(1), sycl::range<1>(1)), \
14+
[=](sycl::nd_item<1> item) { TEST_BODY }); \
15+
}); \
16+
queue.wait(); \
17+
} \
18+
assert(res);
19+
20+
int main() {
21+
#if SYCL_EXT_ONEAPI_SUB_GROUP_MASK >= 2
22+
using mask_type = sycl::ext::oneapi::sub_group_mask;
23+
24+
// sub_group_mask()
25+
{
26+
mask_type mask;
27+
assert(mask.none() && mask_type::max_bits == mask.size());
28+
29+
bool res = false;
30+
// clang-format off
31+
TEST_ON_DEVICE(
32+
mask_type mask;
33+
auto sg = item.get_sub_group();
34+
acc[0] = mask.none() && (sg.get_max_local_range().size() == mask.size());
35+
)
36+
// clang-format on
37+
}
38+
// sub_group_mask(unsigned long long val)
39+
{
40+
unsigned long long val = 4815162342;
41+
mask_type mask(val);
42+
std::bitset<sizeof(val) * CHAR_BIT> bs(val);
43+
bool res = true;
44+
for (size_t i = 0;
45+
i < std::min(static_cast<size_t>(mask.size()), bs.size()); ++i)
46+
res &= mask[i] == bs[i];
47+
assert(res);
48+
49+
// clang-format off
50+
TEST_ON_DEVICE(
51+
mask_type mask(val);
52+
auto sg = item.get_sub_group();
53+
acc[0] = sg.get_max_local_range().size() == mask.size();
54+
for (size_t i = 0;
55+
i < sycl::min(static_cast<size_t>(mask.size()), bs.size());
56+
++i)
57+
acc[0] &= mask[i] == bs[i];
58+
)
59+
// clang-format on
60+
}
61+
// template <typename T, std::size_t K> sub_group_mask(const sycl::marray<T,
62+
// K> &val)
63+
{
64+
sycl::marray<char, 4> marr{1, 2, 3, 4};
65+
mask_type mask(marr);
66+
std::bitset<CHAR_BIT> bs[4] = {1, 2, 3, 4};
67+
bool res = true;
68+
for (size_t i = 0; i < mask.size() && (i / CHAR_BIT) < 4; ++i)
69+
res &= mask[i] == bs[i / CHAR_BIT][i % CHAR_BIT];
70+
assert(res);
71+
72+
// clang-format off
73+
TEST_ON_DEVICE(
74+
mask_type mask(marr);
75+
auto sg = item.get_sub_group();
76+
acc[0] = sg.get_max_local_range().size() == mask.size();
77+
for (size_t i = 0; i < mask.size() && (i / CHAR_BIT) < 4; ++i)
78+
acc[0] &= mask[i] == bs[i / CHAR_BIT][i % CHAR_BIT];
79+
)
80+
// clang-format on
81+
}
82+
{
83+
// sub_group_mask(const sub_group_mask &other)
84+
unsigned long long val = 4815162342;
85+
mask_type mask1(val);
86+
mask_type mask2(mask1);
87+
assert(mask1 == mask2);
88+
89+
bool res = false;
90+
// clang-format off
91+
TEST_ON_DEVICE(
92+
mask_type mask1(val);
93+
mask_type mask2(mask1);
94+
acc[0] = mask1 == mask2;
95+
)
96+
// clang-format on
97+
}
98+
{
99+
// sub_group_mask& operator=(const sub_group_mask &other)
100+
unsigned long long val = 4815162342;
101+
mask_type mask1(val);
102+
mask_type mask2 = mask1;
103+
assert(mask1 == mask2);
104+
105+
bool res = false;
106+
// clang-format off
107+
TEST_ON_DEVICE(
108+
mask_type mask1(val);
109+
mask_type mask2 = mask1;
110+
acc[0] = mask1 == mask2;
111+
)
112+
// clang-format on
113+
}
114+
#else
115+
std::cout << "Test skipped due to unsupported extension." << std::endl;
116+
#endif
117+
118+
return 0;
119+
}

sycl/test/extensions/macro.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ constexpr bool backend_opencl_macro_defined = true;
99
constexpr bool backend_opencl_macro_defined = false;
1010
#endif
1111

12-
#if SYCL_EXT_ONEAPI_SUB_GROUP_MASK == 1
12+
#ifdef SYCL_EXT_ONEAPI_SUB_GROUP_MASK
1313
constexpr bool sub_group_mask_macro_defined = true;
1414
#else
1515
constexpr bool sub_group_mask_macro_defined = false;

sycl/test/extensions/sub_group_mask.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
// RUN: %clangxx -fsycl -fsycl-device-only -fsyntax-only %s
22
//
33
// This test is intended to check sycl::ext::oneapi::sub_group_mask interface.
4-
// There is a work in progress update to the spec: intel/llvm#8174
5-
// TODO: udpate this test once revision 2 of the extension is supported
4+
// test for spec ver.2: sycl/test-e2e/SubGroupMask/sub_group_mask_ver2.cpp
65

76
#include <sycl/sycl.hpp>
87

0 commit comments

Comments
 (0)