Skip to content

Commit f2c7869

Browse files
[SYCL] Enforce constraints from sycl_ext_oneapi_reduction_properties (#16238)
1 parent 9a467fa commit f2c7869

File tree

2 files changed

+94
-24
lines changed

2 files changed

+94
-24
lines changed

sycl/include/sycl/ext/oneapi/experimental/reduction_properties.hpp

Lines changed: 66 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,20 @@ struct initialize_to_identity_key
3232
};
3333
inline constexpr initialize_to_identity_key::value_t initialize_to_identity;
3434

35+
namespace detail {
36+
struct reduction_property_check_anchor {};
37+
} // namespace detail
38+
39+
template <>
40+
struct is_property_key_of<deterministic_key,
41+
detail::reduction_property_check_anchor>
42+
: std::true_type {};
43+
44+
template <>
45+
struct is_property_key_of<initialize_to_identity_key,
46+
detail::reduction_property_check_anchor>
47+
: std::true_type {};
48+
3549
} // namespace experimental
3650
} // namespace oneapi
3751
} // namespace ext
@@ -83,60 +97,88 @@ template <typename BinaryOperation>
8397
struct IsDeterministicOperator<DeterministicOperatorWrapper<BinaryOperation>>
8498
: std::true_type {};
8599

100+
template <typename PropertyList>
101+
inline constexpr bool is_valid_reduction_prop_list =
102+
ext::oneapi::experimental::detail::all_are_properties_of_v<
103+
ext::oneapi::experimental::detail::reduction_property_check_anchor,
104+
PropertyList>;
105+
106+
template <typename BinaryOperation, typename PropertyList, typename... Args>
107+
auto convert_reduction_properties(BinaryOperation combiner,
108+
PropertyList properties, Args &&...args) {
109+
if constexpr (is_valid_reduction_prop_list<PropertyList>) {
110+
auto WrappedOp = WrapOp(combiner, properties);
111+
auto RuntimeProps = GetReductionPropertyList(properties);
112+
return sycl::reduction(std::forward<Args>(args)..., WrappedOp,
113+
RuntimeProps);
114+
} else {
115+
// Invalid, will be disabled by SFINAE at the caller side. Make sure no hard
116+
// error is emitted from here.
117+
}
118+
}
86119
} // namespace detail
87120

88121
template <typename BufferT, typename BinaryOperation, typename PropertyList>
89122
auto reduction(BufferT vars, handler &cgh, BinaryOperation combiner,
90-
PropertyList properties) {
123+
PropertyList properties)
124+
-> std::enable_if_t<detail::is_valid_reduction_prop_list<PropertyList>,
125+
decltype(detail::convert_reduction_properties(
126+
combiner, properties, vars, cgh))> {
91127
detail::CheckReductionIdentity<typename BufferT::value_type, BinaryOperation>(
92128
properties);
93-
auto WrappedOp = detail::WrapOp(combiner, properties);
94-
auto RuntimeProps = detail::GetReductionPropertyList(properties);
95-
return reduction(vars, cgh, WrappedOp, RuntimeProps);
129+
return detail::convert_reduction_properties(combiner, properties, vars, cgh);
96130
}
97131

98132
template <typename T, typename BinaryOperation, typename PropertyList>
99-
auto reduction(T *var, BinaryOperation combiner, PropertyList properties) {
133+
auto reduction(T *var, BinaryOperation combiner, PropertyList properties)
134+
-> std::enable_if_t<detail::is_valid_reduction_prop_list<PropertyList>,
135+
decltype(detail::convert_reduction_properties(
136+
combiner, properties, var))> {
100137
detail::CheckReductionIdentity<T, BinaryOperation>(properties);
101-
auto WrappedOp = detail::WrapOp(combiner, properties);
102-
auto RuntimeProps = detail::GetReductionPropertyList(properties);
103-
return reduction(var, WrappedOp, RuntimeProps);
138+
return detail::convert_reduction_properties(combiner, properties, var);
104139
}
105140

106141
template <typename T, size_t Extent, typename BinaryOperation,
107142
typename PropertyList>
108143
auto reduction(span<T, Extent> vars, BinaryOperation combiner,
109-
PropertyList properties) {
144+
PropertyList properties)
145+
-> std::enable_if_t<detail::is_valid_reduction_prop_list<PropertyList>,
146+
decltype(detail::convert_reduction_properties(
147+
combiner, properties, vars))> {
110148
detail::CheckReductionIdentity<T, BinaryOperation>(properties);
111-
auto WrappedOp = detail::WrapOp(combiner, properties);
112-
auto RuntimeProps = detail::GetReductionPropertyList(properties);
113-
return reduction(vars, WrappedOp, RuntimeProps);
149+
return detail::convert_reduction_properties(combiner, properties, vars);
114150
}
115151

116152
template <typename BufferT, typename BinaryOperation, typename PropertyList>
117153
auto reduction(BufferT vars, handler &cgh,
118154
const typename BufferT::value_type &identity,
119-
BinaryOperation combiner, PropertyList properties) {
120-
auto WrappedOp = detail::WrapOp(combiner, properties);
121-
auto RuntimeProps = detail::GetReductionPropertyList(properties);
122-
return reduction(vars, cgh, identity, WrappedOp, RuntimeProps);
155+
BinaryOperation combiner, PropertyList properties)
156+
-> std::enable_if_t<detail::is_valid_reduction_prop_list<PropertyList>,
157+
decltype(detail::convert_reduction_properties(
158+
combiner, properties, vars, cgh, identity))> {
159+
return detail::convert_reduction_properties(combiner, properties, vars, cgh,
160+
identity);
123161
}
124162

125163
template <typename T, typename BinaryOperation, typename PropertyList>
126164
auto reduction(T *var, const T &identity, BinaryOperation combiner,
127-
PropertyList properties) {
128-
auto WrappedOp = detail::WrapOp(combiner, properties);
129-
auto RuntimeProps = detail::GetReductionPropertyList(properties);
130-
return reduction(var, identity, WrappedOp, RuntimeProps);
165+
PropertyList properties)
166+
-> std::enable_if_t<detail::is_valid_reduction_prop_list<PropertyList>,
167+
decltype(detail::convert_reduction_properties(
168+
combiner, properties, var, identity))> {
169+
return detail::convert_reduction_properties(combiner, properties, var,
170+
identity);
131171
}
132172

133173
template <typename T, size_t Extent, typename BinaryOperation,
134174
typename PropertyList>
135175
auto reduction(span<T, Extent> vars, const T &identity,
136-
BinaryOperation combiner, PropertyList properties) {
137-
auto WrappedOp = detail::WrapOp(combiner, properties);
138-
auto RuntimeProps = detail::GetReductionPropertyList(properties);
139-
return reduction(vars, identity, WrappedOp, RuntimeProps);
176+
BinaryOperation combiner, PropertyList properties)
177+
-> std::enable_if_t<detail::is_valid_reduction_prop_list<PropertyList>,
178+
decltype(detail::convert_reduction_properties(
179+
combiner, properties, vars, identity))> {
180+
return detail::convert_reduction_properties(combiner, properties, vars,
181+
identity);
140182
}
141183

142184
} // namespace _V1
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple -fsyntax-only -Xclang -verify -Xclang -verify-ignore-unexpected=note %s
2+
3+
#include <sycl/sycl.hpp>
4+
5+
int main() {
6+
int *r = nullptr;
7+
// Must not use `sycl_ext_oneapi_reduction_properties`'s overloads:
8+
std::ignore =
9+
sycl::reduction(r, sycl::plus<int>{},
10+
sycl::property::reduction::initialize_to_identity{});
11+
12+
namespace sycl_exp = sycl::ext::oneapi::experimental;
13+
std::ignore =
14+
sycl::reduction(r, sycl::plus<int>{},
15+
sycl_exp::properties(sycl_exp::initialize_to_identity));
16+
17+
// Not a property list:
18+
// expected-error@+2 {{no matching function for call to 'reduction'}}
19+
std::ignore =
20+
sycl::reduction(r, sycl::plus<int>{}, sycl_exp::initialize_to_identity);
21+
22+
// Not a reduction property:
23+
// expected-error@+2 {{no matching function for call to 'reduction'}}
24+
std::ignore =
25+
sycl::reduction(r, sycl::plus<int>{},
26+
sycl_exp::properties(sycl_exp::initialize_to_identity,
27+
sycl_exp::full_group));
28+
}

0 commit comments

Comments
 (0)