Skip to content

Commit 743c35b

Browse files
authored
[SYCL] Implement sycl_ext_oneapi_root_group (#9396)
Implements the extension proposed in `llvm/sycl/sycl/doc/extensions/proposed/sycl_ext_oneapi_root_group.asciidoc`. The implementation is complete, but the `max_num_work_group_sync` query always returns `1`. The implemented barrier has work group scope since all work items in a root group are currently in the same work group. The two `TODO` annotations (in the query and barrier code) can be updated once the backends include the required functionality. --------- Signed-off-by: Michael Aziz <[email protected]>
1 parent 931203f commit 743c35b

File tree

5 files changed

+280
-1
lines changed

5 files changed

+280
-1
lines changed
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
//==--- root_group.hpp --- SYCL extension for root groups ------------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#pragma once
10+
11+
#include <sycl/builtins.hpp>
12+
#include <sycl/ext/oneapi/properties/properties.hpp>
13+
#include <sycl/memory_enums.hpp>
14+
#include <sycl/queue.hpp>
15+
16+
#define SYCL_EXT_ONEAPI_ROOT_GROUP 1
17+
18+
namespace sycl {
19+
__SYCL_INLINE_VER_NAMESPACE(_V1) {
20+
namespace ext::oneapi::experimental {
21+
22+
namespace info::kernel_queue_specific {
23+
// TODO: Revisit and align with sycl_ext_oneapi_forward_progress extension once
24+
// #7598 is merged.
25+
// Information descriptor (tag type) passed as the template argument of
// kernel::ext_oneapi_get_info; its return_type alias fixes the type of the
// value produced by that query.
struct max_num_work_group_sync {
26+
using return_type = size_t;
27+
};
28+
} // namespace info::kernel_queue_specific
29+
30+
// Property key for the use_root_sync kernel property. The associated value
// type carries no payload: it is property_value instantiated on the key only.
struct use_root_sync_key {
31+
using value_t = property_value<use_root_sync_key>;
32+
};
33+
34+
// Ready-made property value that callers place in a properties{} list.
inline constexpr use_root_sync_key::value_t use_root_sync;
35+
36+
// Registers use_root_sync_key as a valid property key.
template <> struct is_property_key<use_root_sync_key> : std::true_type {};
37+
38+
// Associates the key with its PropKind enumerator (PropKind::UseRootSync).
template <> struct detail::PropertyToKind<use_root_sync_key> {
39+
static constexpr PropKind Kind = PropKind::UseRootSync;
40+
};
41+
42+
// Flags the property as a compile-time property.
template <>
43+
struct detail::IsCompileTimeProperty<use_root_sync_key> : std::true_type {};
44+
45+
template <int Dimensions> class root_group {
46+
public:
47+
using id_type = id<Dimensions>;
48+
using range_type = range<Dimensions>;
49+
using linear_id_type = size_t;
50+
static constexpr int dimensions = Dimensions;
51+
static constexpr memory_scope fence_scope = memory_scope::device;
52+
53+
id<Dimensions> get_group_id() const { return id<Dimensions>{}; };
54+
55+
id<Dimensions> get_local_id() const { return it.get_global_id(); }
56+
57+
range<Dimensions> get_group_range() const {
58+
if constexpr (Dimensions == 3) {
59+
return range<3>{1, 1, 1};
60+
} else if constexpr (Dimensions == 2) {
61+
return range<2>{1, 1};
62+
} else {
63+
return range<1>{1};
64+
}
65+
}
66+
67+
range<Dimensions> get_local_range() const { return it.get_global_range(); };
68+
69+
range<Dimensions> get_max_local_range() const { return get_local_range(); };
70+
71+
size_t get_group_linear_id() const { return 0; };
72+
73+
size_t get_local_linear_id() const { return it.get_global_linear_id(); }
74+
75+
size_t get_group_linear_range() const { return get_group_range().size(); };
76+
77+
size_t get_local_linear_range() const { return get_local_range().size(); };
78+
79+
bool leader() const { return get_local_id() == 0; };
80+
81+
private:
82+
friend root_group<Dimensions>
83+
nd_item<Dimensions>::ext_oneapi_get_root_group() const;
84+
85+
root_group(nd_item<Dimensions> it) : it{it} {}
86+
87+
sycl::nd_item<Dimensions> it;
88+
};
89+
90+
template <int Dimensions>
91+
group<Dimensions> get_child_group(root_group<Dimensions> g) {
92+
(void)g;
93+
return this_group<Dimensions>();
94+
}
95+
96+
/// Returns the sub-group of the calling work-item, obtained through
/// this_sub_group(); the work-group argument itself is not inspected.
template <int Dimensions>
sub_group get_child_group([[maybe_unused]] group<Dimensions> g) {
  return this_sub_group();
}
100+
101+
namespace this_kernel {
102+
template <int Dimensions> root_group<Dimensions> get_root_group() {
103+
return this_nd_item<Dimensions>().ext_oneapi_get_root_group();
104+
}
105+
} // namespace this_kernel
106+
107+
} // namespace ext::oneapi::experimental
108+
109+
template <>
110+
typename ext::oneapi::experimental::info::kernel_queue_specific::
111+
max_num_work_group_sync::return_type
112+
kernel::ext_oneapi_get_info<
113+
ext::oneapi::experimental::info::kernel_queue_specific::
114+
max_num_work_group_sync>(const queue &q) const {
115+
// TODO: query the backend to return a value >= 1.
116+
return 1;
117+
}
118+
119+
template <int dimensions>
120+
void group_barrier(ext::oneapi::experimental::root_group<dimensions> G,
121+
memory_scope FenceScope = decltype(G)::fence_scope) {
122+
(void)G;
123+
(void)FenceScope;
124+
#ifdef __SYCL_DEVICE_ONLY__
125+
// TODO: Change __spv::Scope::Workgroup to __spv::Scope::Device once backends
126+
// support device scope. __spv::Scope::Workgroup is only valid when
127+
// max_num_work_group_sync is 1, so that all work items in a root group will
128+
// also be in the same work group.
129+
__spirv_ControlBarrier(__spv::Scope::Workgroup, __spv::Scope::Workgroup,
130+
__spv::MemorySemanticsMask::SubgroupMemory |
131+
__spv::MemorySemanticsMask::WorkgroupMemory |
132+
__spv::MemorySemanticsMask::CrossWorkgroupMemory);
133+
#else
134+
throw sycl::runtime_error("Barriers are not supported on host device",
135+
PI_ERROR_INVALID_DEVICE);
136+
#endif
137+
}
138+
139+
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
140+
} // namespace sycl

sycl/include/sycl/ext/oneapi/properties/property.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,9 @@ enum PropKind : uint32_t {
193193
PipeProtocol = 27,
194194
ReadyLatency = 28,
195195
UsesValid = 29,
196+
UseRootSync = 30,
196197
// PropKindSize must always be the last value.
197-
PropKindSize = 30,
198+
PropKindSize = 31,
198199
};
199200

200201
// This trait must be specialized for all properties and must have a unique

sycl/include/sycl/kernel.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ namespace sycl {
2525
__SYCL_INLINE_VER_NAMESPACE(_V1) {
2626
// Forward declaration
2727
class context;
28+
class queue;
2829
template <backend Backend> class backend_traits;
2930
template <bundle_state State> class kernel_bundle;
3031
template <backend BackendName, class SyclObjectT>
@@ -157,6 +158,11 @@ class __SYCL_EXPORT kernel : public detail::OwnerLessBase<kernel> {
157158
typename detail::is_kernel_device_specific_info_desc<Param>::return_type
158159
get_info(const device &Device, const range<3> &WGSize) const;
159160

161+
// TODO: Revisit and align with sycl_ext_oneapi_forward_progress extension
162+
// once #7598 is merged.
163+
template <typename Param>
164+
typename Param::return_type ext_oneapi_get_info(const queue &q) const;
165+
160166
private:
161167
/// Constructs a SYCL kernel object from a valid kernel_impl instance.
162168
kernel(std::shared_ptr<detail::kernel_impl> Impl);

sycl/include/sycl/nd_item.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ namespace detail {
2929
class Builder;
3030
}
3131

32+
namespace ext::oneapi::experimental {
33+
template <int dimensions> class root_group;
34+
}
35+
3236
/// Identifies an instance of the function object executing at each point in an
3337
/// nd_range.
3438
///
@@ -198,6 +202,11 @@ template <int dimensions = 1> class nd_item {
198202
Group.wait_for(events...);
199203
}
200204

205+
sycl::ext::oneapi::experimental::root_group<dimensions>
206+
ext_oneapi_get_root_group() const {
207+
return sycl::ext::oneapi::experimental::root_group<dimensions>{*this};
208+
}
209+
201210
nd_item(const nd_item &rhs) = default;
202211

203212
nd_item(nd_item &&rhs) = default;
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
// RUN: %{build} -I . -o %t.out
2+
// RUN: %{run} %t.out
3+
4+
#include <cassert>
5+
#include <cstdlib>
6+
#include <type_traits>
7+
8+
#include <sycl/ext/oneapi/experimental/root_group.hpp>
9+
#include <sycl/sycl.hpp>
10+
11+
static constexpr int WorkGroupSize = 32;
12+
13+
// Compile-time check that including the extension header defines the feature
// test macro SYCL_EXT_ONEAPI_ROOT_GROUP with the value 1. Has no runtime
// effect.
void testFeatureMacro() {
14+
static_assert(SYCL_EXT_ONEAPI_ROOT_GROUP == 1,
15+
"SYCL_EXT_ONEAPI_ROOT_GROUP must have a value of 1");
16+
}
17+
18+
void testQueriesAndProperties() {
19+
sycl::queue q;
20+
const auto bundle =
21+
sycl::get_kernel_bundle<sycl::bundle_state::executable>(q.get_context());
22+
const auto kernel = bundle.get_kernel<class QueryKernel>();
23+
const auto maxWGs = kernel.ext_oneapi_get_info<
24+
sycl::ext::oneapi::experimental::info::kernel_queue_specific::
25+
max_num_work_group_sync>(q);
26+
const auto props = sycl::ext::oneapi::experimental::properties{
27+
sycl::ext::oneapi::experimental::use_root_sync};
28+
q.single_task<class QueryKernel>(props, []() {});
29+
static_assert(std::is_same_v<std::remove_cv<decltype(maxWGs)>::type, size_t>,
30+
"max_num_work_group_sync query must return size_t");
31+
assert(maxWGs >= 1 && "max_num_work_group_sync query failed");
32+
}
33+
34+
void testRootGroup() {
35+
sycl::queue q;
36+
const auto bundle =
37+
sycl::get_kernel_bundle<sycl::bundle_state::executable>(q.get_context());
38+
const auto kernel = bundle.get_kernel<class RootGroupKernel>();
39+
const auto maxWGs = kernel.ext_oneapi_get_info<
40+
sycl::ext::oneapi::experimental::info::kernel_queue_specific::
41+
max_num_work_group_sync>(q);
42+
const auto props = sycl::ext::oneapi::experimental::properties{
43+
sycl::ext::oneapi::experimental::use_root_sync};
44+
45+
int *data = sycl::malloc_shared<int>(maxWGs * WorkGroupSize, q);
46+
const auto range = sycl::nd_range<1>{maxWGs * WorkGroupSize, WorkGroupSize};
47+
q.parallel_for<class RootGroupKernel>(range, props, [=](sycl::nd_item<1> it) {
48+
auto root = it.ext_oneapi_get_root_group();
49+
data[root.get_local_id()] = root.get_local_id();
50+
sycl::group_barrier(root);
51+
52+
root = sycl::ext::oneapi::experimental::this_kernel::get_root_group<1>();
53+
int sum = data[root.get_local_id()] +
54+
data[root.get_local_range() - root.get_local_id() - 1];
55+
sycl::group_barrier(root);
56+
data[root.get_local_id()] = sum;
57+
});
58+
q.wait();
59+
60+
const int workItemCount = static_cast<int>(range.get_global_range().size());
61+
for (int i = 0; i < workItemCount; i++) {
62+
assert(data[i] == (workItemCount - 1));
63+
}
64+
}
65+
66+
void testRootGroupFunctions() {
67+
sycl::queue q;
68+
const auto bundle =
69+
sycl::get_kernel_bundle<sycl::bundle_state::executable>(q.get_context());
70+
const auto kernel = bundle.get_kernel<class RootGroupFunctionsKernel>();
71+
const auto maxWGs = kernel.ext_oneapi_get_info<
72+
sycl::ext::oneapi::experimental::info::kernel_queue_specific::
73+
max_num_work_group_sync>(q);
74+
const auto props = sycl::ext::oneapi::experimental::properties{
75+
sycl::ext::oneapi::experimental::use_root_sync};
76+
77+
constexpr int testCount = 10;
78+
bool *testResults = sycl::malloc_shared<bool>(testCount, q);
79+
const auto range = sycl::nd_range<1>{maxWGs * WorkGroupSize, WorkGroupSize};
80+
q.parallel_for<class RootGroupFunctionsKernel>(
81+
range, props, [=](sycl::nd_item<1> it) {
82+
const auto root = it.ext_oneapi_get_root_group();
83+
if (root.leader() || root.get_local_id() == 3) {
84+
testResults[0] = root.get_group_id() == sycl::id<1>(0);
85+
testResults[1] = root.leader()
86+
? root.get_local_id() == sycl::id<1>(0)
87+
: root.get_local_id() == sycl::id<1>(3);
88+
testResults[2] = root.get_group_range() == sycl::range<1>(1);
89+
testResults[3] =
90+
root.get_local_range() == sycl::range<1>(WorkGroupSize);
91+
testResults[4] =
92+
root.get_max_local_range() == sycl::range<1>(WorkGroupSize);
93+
testResults[5] = root.get_group_linear_id() == 0;
94+
testResults[6] =
95+
root.get_local_linear_id() == root.get_local_id().get(0);
96+
testResults[7] = root.get_group_linear_range() == 1;
97+
testResults[8] = root.get_local_linear_range() == WorkGroupSize;
98+
99+
const auto child =
100+
sycl::ext::oneapi::experimental::get_child_group(root);
101+
const auto grandchild =
102+
sycl::ext::oneapi::experimental::get_child_group(child);
103+
testResults[9] = child == it.get_group();
104+
static_assert(
105+
std::is_same_v<std::remove_cv<decltype(grandchild)>::type,
106+
sycl::sub_group>,
107+
"get_child_group(sycl::group) must return a sycl::sub_group");
108+
}
109+
});
110+
q.wait();
111+
112+
for (int i = 0; i < testCount; i++) {
113+
assert(testResults[i]);
114+
}
115+
}
116+
117+
// Test driver: runs the feature-macro check, the query/property check and the
// two kernel-based root-group tests in order, then reports success. The tests
// report failure through assert, so reaching the return means no assertion
// fired.
int main() {
118+
testFeatureMacro();
119+
testQueriesAndProperties();
120+
testRootGroup();
121+
testRootGroupFunctions();
122+
return EXIT_SUCCESS;
123+
}

0 commit comments

Comments
 (0)