Skip to content

Commit 8311d79

Browse files
[SYCL] Implement user-defined reduction extension (#7587)
Spec: #7202 Tests: intel/llvm-test-suite#1395
1 parent 8f4b781 commit 8311d79

File tree

6 files changed

+144
-2
lines changed

6 files changed

+144
-2
lines changed

sycl/include/sycl/detail/spirv.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,8 @@ EnableIfBitcastBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
176176
}
177177
template <typename Group, typename T, typename IdT>
178178
EnableIfGenericBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
179-
T Result;
179+
// Initialize with x to support type T without default constructor
180+
T Result = x;
180181
char *XBytes = reinterpret_cast<char *>(&x);
181182
char *ResultBytes = reinterpret_cast<char *>(&Result);
182183
auto BroadcastBytes = [=](size_t Offset, size_t Size) {
@@ -219,7 +220,8 @@ EnableIfGenericBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
219220
if (Dimensions == 1) {
220221
return GroupBroadcast<Group>(x, local_id[0]);
221222
}
222-
T Result;
223+
// Initialize with x to support type T without default constructor
224+
T Result = x;
223225
char *XBytes = reinterpret_cast<char *>(&x);
224226
char *ResultBytes = reinterpret_cast<char *>(&Result);
225227
auto BroadcastBytes = [=](size_t Offset, size_t Size) {

sycl/include/sycl/detail/type_traits.hpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,18 @@ template <int Dimensions> class group;
2323
namespace ext {
2424
namespace oneapi {
2525
struct sub_group;
26+
27+
namespace experimental {
28+
template <typename Group, std::size_t Extent> class group_with_scratchpad;
29+
30+
namespace detail {
31+
template <typename T> struct is_group_helper : std::false_type {};
32+
33+
template <typename Group, std::size_t Extent>
34+
struct is_group_helper<group_with_scratchpad<Group, Extent>> : std::true_type {
35+
};
36+
} // namespace detail
37+
} // namespace experimental
2638
} // namespace oneapi
2739
} // namespace ext
2840

@@ -57,6 +69,12 @@ template <class T>
5769
__SYCL_INLINE_CONSTEXPR bool is_group_v =
5870
detail::is_group<T>::value || detail::is_sub_group<T>::value;
5971

72+
namespace ext::oneapi::experimental {
73+
template <class T>
74+
__SYCL_INLINE_CONSTEXPR bool is_group_helper_v =
75+
detail::is_group_helper<std::decay_t<T>>::value;
76+
} // namespace ext::oneapi::experimental
77+
6078
namespace detail {
6179
// Type for Intel device UUID extension.
6280
// For details about this extension, see
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
//==--- user_defined_reductions.hpp -- SYCL ext header file -=--*- C++ -*---==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#pragma once
10+
11+
#include <sycl/detail/defines.hpp>
12+
#include <sycl/group_algorithm.hpp>
13+
14+
namespace sycl {
15+
__SYCL_INLINE_VER_NAMESPACE(_V1) {
16+
namespace ext::oneapi::experimental {
17+
18+
// ---- reduce_over_group
19+
template <typename GroupHelper, typename T, typename BinaryOperation>
20+
sycl::detail::enable_if_t<(is_group_helper_v<GroupHelper>), T>
21+
reduce_over_group(GroupHelper group_helper, T x, BinaryOperation binary_op) {
22+
if constexpr (sycl::detail::is_native_op<T, BinaryOperation>::value) {
23+
return sycl::reduce_over_group(group_helper.get_group(), x, binary_op);
24+
}
25+
#ifdef __SYCL_DEVICE_ONLY__
26+
T *Memory = reinterpret_cast<T *>(group_helper.get_memory().data());
27+
auto g = group_helper.get_group();
28+
Memory[g.get_local_linear_id()] = x;
29+
group_barrier(g);
30+
T result = Memory[0];
31+
if (g.leader()) {
32+
for (int i = 1; i < g.get_local_linear_range(); i++) {
33+
result = binary_op(result, Memory[i]);
34+
}
35+
}
36+
group_barrier(g);
37+
return group_broadcast(g, result);
38+
#else
39+
std::ignore = group_helper;
40+
throw runtime_error("Group algorithms are not supported on host.",
41+
PI_ERROR_INVALID_DEVICE);
42+
#endif
43+
}
44+
45+
template <typename GroupHelper, typename V, typename T,
46+
typename BinaryOperation>
47+
sycl::detail::enable_if_t<(is_group_helper_v<GroupHelper>), T>
48+
reduce_over_group(GroupHelper group_helper, V x, T init,
49+
BinaryOperation binary_op) {
50+
if constexpr (sycl::detail::is_native_op<V, BinaryOperation>::value &&
51+
sycl::detail::is_native_op<T, BinaryOperation>::value) {
52+
return sycl::reduce_over_group(group_helper.get_group(), x, init,
53+
binary_op);
54+
}
55+
#ifdef __SYCL_DEVICE_ONLY__
56+
return binary_op(init, reduce_over_group(group_helper, x, binary_op));
57+
#else
58+
std::ignore = group_helper;
59+
throw runtime_error("Group algorithms are not supported on host.",
60+
PI_ERROR_INVALID_DEVICE);
61+
#endif
62+
}
63+
64+
// ---- joint_reduce
65+
template <typename GroupHelper, typename Ptr, typename BinaryOperation>
66+
sycl::detail::enable_if_t<(is_group_helper_v<GroupHelper> &&
67+
sycl::detail::is_pointer<Ptr>::value),
68+
typename std::iterator_traits<Ptr>::value_type>
69+
joint_reduce(GroupHelper group_helper, Ptr first, Ptr last,
70+
BinaryOperation binary_op) {
71+
if constexpr (sycl::detail::is_native_op<
72+
typename std::iterator_traits<Ptr>::value_type,
73+
BinaryOperation>::value) {
74+
return sycl::joint_reduce(group_helper.get_group(), first, last, binary_op);
75+
}
76+
#ifdef __SYCL_DEVICE_ONLY__
77+
// TODO: the complexity is linear and not logarithmic. Something like
78+
// https://github.com/intel/llvm/blob/8ebd912679f27943d8ef6c33a9775347dce6b80d/sycl/include/sycl/reduction.hpp#L1810-L1818
79+
// might be applicable here.
80+
using T = typename std::iterator_traits<Ptr>::value_type;
81+
auto g = group_helper.get_group();
82+
T partial = *(first + g.get_local_linear_id());
83+
Ptr second = first + g.get_local_linear_range();
84+
sycl::detail::for_each(g, second, last,
85+
[&](const T &x) { partial = binary_op(partial, x); });
86+
group_barrier(g);
87+
return reduce_over_group(group_helper, partial, binary_op);
88+
#else
89+
std::ignore = group_helper;
90+
std::ignore = first;
91+
std::ignore = last;
92+
std::ignore = binary_op;
93+
throw runtime_error("Group algorithms are not supported on host.",
94+
PI_ERROR_INVALID_DEVICE);
95+
#endif
96+
}
97+
98+
template <typename GroupHelper, typename Ptr, typename T,
99+
typename BinaryOperation>
100+
sycl::detail::enable_if_t<
101+
(is_group_helper_v<GroupHelper> && sycl::detail::is_pointer<Ptr>::value), T>
102+
joint_reduce(GroupHelper group_helper, Ptr first, Ptr last, T init,
103+
BinaryOperation binary_op) {
104+
if constexpr (sycl::detail::is_native_op<T, BinaryOperation>::value) {
105+
return sycl::joint_reduce(group_helper.get_group(), first, last, init,
106+
binary_op);
107+
}
108+
#ifdef __SYCL_DEVICE_ONLY__
109+
return binary_op(init, joint_reduce(group_helper, first, last, binary_op));
110+
#else
111+
std::ignore = group_helper;
112+
std::ignore = last;
113+
throw runtime_error("Group algorithms are not supported on host.",
114+
PI_ERROR_INVALID_DEVICE);
115+
#endif
116+
}
117+
} // namespace ext::oneapi::experimental
118+
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
119+
} // namespace sycl

sycl/include/sycl/feature_test.hpp.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ __SYCL_INLINE_VER_NAMESPACE(_V1) {
6767
#define SYCL_EXT_ONEAPI_BACKEND_LEVEL_ZERO 3
6868
#define SYCL_EXT_ONEAPI_USM_DEVICE_READ_ONLY 1
6969
#define SYCL_EXT_ONEAPI_KERNEL_PROPERTIES 1
70+
#define SYCL_EXT_ONEAPI_USER_DEFINED_REDUCTIONS 1
7071
#cmakedefine01 SYCL_BUILD_PI_CUDA
7172
#if SYCL_BUILD_PI_CUDA
7273
#define SYCL_EXT_ONEAPI_BACKEND_CUDA 1

sycl/include/sycl/group_algorithm.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <CL/__spirv/spirv_ops.hpp>
1313
#include <CL/__spirv/spirv_types.hpp>
1414
#include <CL/__spirv/spirv_vars.hpp>
15+
#include <sycl/builtins.hpp>
1516
#include <sycl/detail/spirv.hpp>
1617
#include <sycl/detail/type_traits.hpp>
1718
#include <sycl/ext/oneapi/experimental/group_sort.hpp>

sycl/include/sycl/group_barrier.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <sycl/detail/spirv.hpp>
1616
#include <sycl/detail/type_traits.hpp>
1717
#include <sycl/group.hpp>
18+
#include <sycl/sub_group.hpp>
1819

1920
namespace sycl {
2021
__SYCL_INLINE_VER_NAMESPACE(_V1) {

0 commit comments

Comments
 (0)