Skip to content

Commit 064f332

Browse files
authored
[SYCL] Fix sycl::sub_group by-value semantics (#9984)
According to the SYCL 2020 specification, sycl::sub_group must follow by-value semantics.
1 parent 5a91df9 commit 064f332

File tree

2 files changed

+60
-0
lines changed

2 files changed

+60
-0
lines changed

sycl/include/sycl/ext/oneapi/sub_group.hpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -769,6 +769,25 @@ struct sub_group {
769769
#endif
770770
}
771771

772+
// Common member functions for by-value semantics
773+
/// Equality comparison for sub_group.
///
/// When __SYCL_DEVICE_ONLY__ is defined, the two operands compare equal
/// exactly when their get_group_id() values compare equal. Otherwise the
/// call always throws runtime_error with PI_ERROR_INVALID_DEVICE.
friend bool operator==(const sub_group &lhs, const sub_group &rhs) {
#ifndef __SYCL_DEVICE_ONLY__
  // Host compilation: neither operand is inspected.
  throw runtime_error("Sub-groups are not supported on host device.",
                      PI_ERROR_INVALID_DEVICE);
#else
  const bool same_group_id = (lhs.get_group_id() == rhs.get_group_id());
  return same_group_id;
#endif
}
781+
782+
/// Inequality comparison for sub_group.
///
/// When __SYCL_DEVICE_ONLY__ is defined, this is the logical negation of
/// operator== on the same operands. Otherwise the call always throws
/// runtime_error with PI_ERROR_INVALID_DEVICE.
friend bool operator!=(const sub_group &lhs, const sub_group &rhs) {
#ifndef __SYCL_DEVICE_ONLY__
  // Host compilation: neither operand is inspected.
  throw runtime_error("Sub-groups are not supported on host device.",
                      PI_ERROR_INVALID_DEVICE);
#else
  const bool are_equal = (lhs == rhs);
  return !are_equal;
#endif
}
790+
772791
protected:
773792
template <int dimensions> friend class sycl::nd_item;
774793
friend sub_group this_sub_group();
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// RUN: %{build} -o %t.out
2+
// RUN: %{run} %t.out
3+
4+
#include <sycl.hpp>
5+
6+
int main() {
7+
bool result = true;
8+
sycl::queue queue;
9+
{
10+
sycl::buffer<bool, 1> res_buf(&result, 1);
11+
queue.submit([&](sycl::handler &cgh) {
12+
auto res_acc = res_buf.get_access<sycl::access_mode::read_write>(cgh);
13+
14+
cgh.parallel_for<class kernel>(sycl::nd_range<3>({1, 1, 1}, {1, 1, 1}),
15+
[=](sycl::nd_item<3> item) {
16+
sycl::sub_group a = item.get_sub_group();
17+
18+
// check for reflexivity
19+
res_acc[0] &= (a == a);
20+
res_acc[0] &= !(a != a);
21+
22+
// check for symmetry
23+
auto copied = a;
24+
auto &b = copied;
25+
res_acc[0] &= (a == b);
26+
res_acc[0] &= (b == a);
27+
res_acc[0] &= !(a != b);
28+
res_acc[0] &= !(b != a);
29+
30+
// check for transitivity
31+
auto copiedTwice = copied;
32+
const auto &c = copiedTwice;
33+
res_acc[0] &= (c == a);
34+
});
35+
});
36+
queue.wait_and_throw();
37+
}
38+
39+
assert(result);
40+
return 0;
41+
}

0 commit comments

Comments
 (0)