Skip to content

Commit 713a9dd

Browse files
authored
[SYCL] Implement the sycl_khr_group_interface extension (#17595)
Signed-off-by: Michael Aziz <[email protected]>
1 parent 51e8d95 commit 713a9dd

File tree

5 files changed

+608
-0
lines changed

5 files changed

+608
-0
lines changed
Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
//==----- group_interface.hpp --- sycl_khr_group_interface extension -------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
#pragma once
9+
10+
#include <sycl/ext/oneapi/free_function_queries.hpp>
11+
#include <sycl/id.hpp>
12+
#include <sycl/range.hpp>
13+
14+
#ifdef __DPCPP_ENABLE_UNFINISHED_KHR_EXTENSIONS
15+
#define SYCL_KHR_GROUP_INTERFACE 1
16+
#endif
17+
18+
#if __cplusplus >= 202302L && defined(__has_include)
19+
#if __has_include(<mdspan>)
20+
#include <mdspan>
21+
#endif
22+
#endif
23+
24+
namespace sycl {
25+
inline namespace _V1 {
26+
#ifdef __DPCPP_ENABLE_UNFINISHED_KHR_EXTENSIONS
27+
namespace khr {
28+
29+
// Forward declarations for traits.
30+
template <int Dimensions> class work_group;
31+
class sub_group;
32+
template <typename ParentGroup> class member_item;
33+
34+
} // namespace khr
35+
36+
namespace detail {
37+
#if defined(__cpp_lib_mdspan)
38+
template <typename IndexType, int Dimensions> struct single_extents;
39+
40+
template <typename IndexType> struct single_extents<IndexType, 1> {
41+
using type = std::extents<IndexType, 1>;
42+
};
43+
44+
template <typename IndexType> struct single_extents<IndexType, 2> {
45+
using type = std::extents<IndexType, 1, 1>;
46+
};
47+
48+
template <typename IndexType> struct single_extents<IndexType, 3> {
49+
using type = std::extents<IndexType, 1, 1, 1>;
50+
};
51+
#endif
52+
53+
template <typename T> struct is_khr_group : public std::false_type {};
54+
55+
template <int Dimensions>
56+
struct is_khr_group<khr::work_group<Dimensions>> : public std::true_type {};
57+
58+
template <> struct is_khr_group<khr::sub_group> : public std::true_type {};
59+
60+
} // namespace detail
61+
62+
namespace khr {
63+
64+
// Forward declaration for friend function.
65+
template <typename ParentGroup>
66+
std::enable_if_t<detail::is_khr_group<ParentGroup>::value,
67+
member_item<ParentGroup>>
68+
get_member_item(ParentGroup g) noexcept;
69+
70+
template <int Dimensions = 1> class work_group {
71+
public:
72+
using id_type = sycl::id<Dimensions>;
73+
using linear_id_type = size_t;
74+
using range_type = sycl::range<Dimensions>;
75+
#if defined(__cpp_lib_mdspan)
76+
using extents_type = std::dextents<size_t, Dimensions>;
77+
#endif
78+
using size_type = size_t;
79+
static constexpr int dimensions = Dimensions;
80+
static constexpr memory_scope fence_scope = memory_scope::work_group;
81+
82+
work_group(group<Dimensions>) noexcept {}
83+
84+
operator group<Dimensions>() const noexcept { return legacy(); }
85+
86+
id_type id() const noexcept { return legacy().get_group_id(); }
87+
88+
linear_id_type linear_id() const noexcept {
89+
return legacy().get_group_linear_id();
90+
}
91+
92+
range_type range() const noexcept { return legacy().get_group_range(); }
93+
94+
#if defined(__cpp_lib_mdspan)
95+
constexpr extents_type extents() const noexcept {
96+
auto LocalRange = legacy().get_local_range();
97+
if constexpr (dimensions == 1) {
98+
return extents_type(LocalRange[0]);
99+
} else if constexpr (dimensions == 2) {
100+
return extents_type(LocalRange[0], LocalRange[1]);
101+
} else if constexpr (dimensions == 3) {
102+
return extents_type(LocalRange[0], LocalRange[1], LocalRange[2]);
103+
}
104+
}
105+
106+
constexpr typename extents_type::index_type
107+
extent(typename extents_type::rank_type r) const noexcept {
108+
return extents().extent(r);
109+
}
110+
111+
static constexpr typename extents_type::rank_type rank() noexcept {
112+
return extents_type::rank();
113+
}
114+
115+
static constexpr typename extents_type::rank_type rank_dynamic() noexcept {
116+
return extents_type::rank_dynamic();
117+
}
118+
119+
static constexpr size_t
120+
static_extent(typename extents_type::rank_type r) noexcept {
121+
return extents_type::static_extent(r);
122+
}
123+
#endif
124+
125+
size_type size() const noexcept { return legacy().get_local_range().size(); }
126+
127+
private:
128+
group<Dimensions> legacy() const noexcept {
129+
return ext::oneapi::this_work_item::get_work_group<Dimensions>();
130+
}
131+
};
132+
133+
class sub_group {
134+
public:
135+
using id_type = sycl::id<1>;
136+
using linear_id_type = uint32_t;
137+
using range_type = sycl::range<1>;
138+
#if defined(__cpp_lib_mdspan)
139+
using extents_type = std::dextents<uint32_t, 1>;
140+
#endif
141+
using size_type = uint32_t;
142+
static constexpr int dimensions = 1;
143+
static constexpr memory_scope fence_scope = memory_scope::sub_group;
144+
145+
sub_group(sycl::sub_group) noexcept {}
146+
147+
operator sycl::sub_group() const noexcept { return legacy(); }
148+
149+
id_type id() const noexcept { return legacy().get_group_id(); }
150+
151+
linear_id_type linear_id() const noexcept {
152+
return legacy().get_group_linear_id();
153+
}
154+
155+
range_type range() const noexcept { return legacy().get_group_range(); }
156+
157+
#if defined(__cpp_lib_mdspan)
158+
extents_type extents() const noexcept {
159+
return extents_type(legacy().get_local_range()[0]);
160+
}
161+
162+
typename extents_type::index_type
163+
extent(typename extents_type::rank_type r) const noexcept {
164+
return extents().extent(r);
165+
}
166+
167+
static constexpr typename extents_type::rank_type rank() noexcept {
168+
return extents_type::rank();
169+
}
170+
171+
static constexpr typename extents_type::rank_type rank_dynamic() noexcept {
172+
return extents_type::rank_dynamic();
173+
}
174+
175+
static constexpr size_t
176+
static_extent(typename extents_type::rank_type r) noexcept {
177+
return extents_type::static_extent(r);
178+
}
179+
#endif
180+
181+
size_type size() const noexcept { return legacy().get_local_range().size(); }
182+
183+
size_type max_size() const noexcept {
184+
return legacy().get_max_local_range().size();
185+
}
186+
187+
private:
188+
sycl::sub_group legacy() const noexcept {
189+
return ext::oneapi::this_work_item::get_sub_group();
190+
}
191+
};
192+
193+
template <typename ParentGroup> class member_item {
194+
public:
195+
using id_type = typename ParentGroup::id_type;
196+
using linear_id_type = typename ParentGroup::linear_id_type;
197+
using range_type = typename ParentGroup::range_type;
198+
#if defined(__cpp_lib_mdspan)
199+
using extents_type = typename detail::single_extents<
200+
typename ParentGroup::extents_type::index_type,
201+
ParentGroup::dimensions>::type;
202+
#endif
203+
using size_type = typename ParentGroup::size_type;
204+
static constexpr int dimensions = ParentGroup::dimensions;
205+
static constexpr memory_scope fence_scope = memory_scope::work_item;
206+
207+
id_type id() const noexcept { return legacy().get_local_id(); }
208+
209+
linear_id_type linear_id() const noexcept {
210+
return legacy().get_local_linear_id();
211+
}
212+
213+
range_type range() const noexcept { return legacy().get_local_range(); }
214+
215+
#if defined(__cpp_lib_mdspan)
216+
constexpr extents_type extents() const noexcept { return extents_type(); }
217+
218+
constexpr typename extents_type::index_type
219+
extent(typename extents_type::rank_type r) const noexcept {
220+
return extents().extent(r);
221+
}
222+
223+
static constexpr typename extents_type::rank_type rank() noexcept {
224+
return extents_type::rank();
225+
}
226+
227+
static constexpr typename extents_type::rank_type rank_dynamic() noexcept {
228+
return extents_type::rank_dynamic();
229+
}
230+
231+
static constexpr size_t
232+
static_extent(typename extents_type::rank_type r) noexcept {
233+
return extents_type::static_extent(r);
234+
}
235+
#endif
236+
237+
constexpr size_type size() const noexcept { return 1; }
238+
239+
private:
240+
auto legacy() const noexcept {
241+
if constexpr (std::is_same_v<ParentGroup, sub_group>) {
242+
return ext::oneapi::this_work_item::get_sub_group();
243+
} else {
244+
return ext::oneapi::this_work_item::get_work_group<
245+
ParentGroup::dimensions>();
246+
}
247+
}
248+
249+
protected:
250+
member_item() {}
251+
252+
friend member_item<ParentGroup>
253+
get_member_item<ParentGroup>(ParentGroup) noexcept;
254+
};
255+
256+
template <typename ParentGroup>
257+
std::enable_if_t<detail::is_khr_group<ParentGroup>::value,
258+
member_item<ParentGroup>>
259+
get_member_item(ParentGroup) noexcept {
260+
return member_item<ParentGroup>{};
261+
}
262+
263+
template <typename Group> bool leader_of(Group g) {
264+
return get_member_item(g).linear_id() == 0;
265+
}
266+
267+
} // namespace khr
268+
#endif
269+
} // namespace _V1
270+
} // namespace sycl

sycl/include/sycl/sycl.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,3 +123,4 @@
123123
#include <sycl/ext/oneapi/virtual_mem/physical_mem.hpp>
124124
#include <sycl/ext/oneapi/virtual_mem/virtual_mem.hpp>
125125
#include <sycl/ext/oneapi/weak_object.hpp>
126+
#include <sycl/khr/group_interface.hpp>
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// RUN: %{build} -o %t.out
2+
// RUN: %{run} %t.out
3+
4+
#define __DPCPP_ENABLE_UNFINISHED_KHR_EXTENSIONS
5+
6+
#include <cassert>
7+
#include <sycl/detail/core.hpp>
8+
#include <sycl/group_algorithm.hpp>
9+
#include <sycl/khr/group_interface.hpp>
10+
11+
using namespace sycl;
12+
13+
void test(queue q) {
14+
int out = 0;
15+
size_t G = 4;
16+
17+
range<2> R(G, G);
18+
{
19+
buffer<int> out_buf(&out, 1);
20+
21+
q.submit([&](handler &cgh) {
22+
auto out = out_buf.template get_access<access::mode::read_write>(cgh);
23+
cgh.parallel_for(nd_range<2>(R, R), [=](nd_item<2> it) {
24+
khr::work_group<2> g = it.get_group();
25+
if (khr::leader_of(g)) {
26+
out[0] += 1;
27+
}
28+
});
29+
});
30+
}
31+
assert(out == 1);
32+
}
33+
34+
int main() {
35+
queue q;
36+
test(q);
37+
return 0;
38+
}

0 commit comments

Comments
 (0)