@@ -32,6 +32,20 @@ struct initialize_to_identity_key
32
32
};
33
33
inline constexpr initialize_to_identity_key::value_t initialize_to_identity;
34
34
35
+ namespace detail {
36
+ struct reduction_property_check_anchor {};
37
+ } // namespace detail
38
+
39
+ template <>
40
+ struct is_property_key_of <deterministic_key,
41
+ detail::reduction_property_check_anchor>
42
+ : std::true_type {};
43
+
44
+ template <>
45
+ struct is_property_key_of <initialize_to_identity_key,
46
+ detail::reduction_property_check_anchor>
47
+ : std::true_type {};
48
+
35
49
} // namespace experimental
36
50
} // namespace oneapi
37
51
} // namespace ext
@@ -83,60 +97,88 @@ template <typename BinaryOperation>
83
97
struct IsDeterministicOperator <DeterministicOperatorWrapper<BinaryOperation>>
84
98
: std::true_type {};
85
99
100
+ template <typename PropertyList>
101
+ inline constexpr bool is_valid_reduction_prop_list =
102
+ ext::oneapi::experimental::detail::all_are_properties_of_v<
103
+ ext::oneapi::experimental::detail::reduction_property_check_anchor,
104
+ PropertyList>;
105
+
106
+ template <typename BinaryOperation, typename PropertyList, typename ... Args>
107
+ auto convert_reduction_properties (BinaryOperation combiner,
108
+ PropertyList properties, Args &&...args) {
109
+ if constexpr (is_valid_reduction_prop_list<PropertyList>) {
110
+ auto WrappedOp = WrapOp (combiner, properties);
111
+ auto RuntimeProps = GetReductionPropertyList (properties);
112
+ return sycl::reduction (std::forward<Args>(args)..., WrappedOp,
113
+ RuntimeProps);
114
+ } else {
115
+ // Invalid, will be disabled by SFINAE at the caller side. Make sure no hard
116
+ // error is emitted from here.
117
+ }
118
+ }
86
119
} // namespace detail
87
120
88
121
template <typename BufferT, typename BinaryOperation, typename PropertyList>
89
122
auto reduction (BufferT vars, handler &cgh, BinaryOperation combiner,
90
- PropertyList properties) {
123
+ PropertyList properties)
124
+ -> std::enable_if_t<detail::is_valid_reduction_prop_list<PropertyList>,
125
+ decltype(detail::convert_reduction_properties(
126
+ combiner, properties, vars, cgh))> {
91
127
detail::CheckReductionIdentity<typename BufferT::value_type, BinaryOperation>(
92
128
properties);
93
- auto WrappedOp = detail::WrapOp (combiner, properties);
94
- auto RuntimeProps = detail::GetReductionPropertyList (properties);
95
- return reduction (vars, cgh, WrappedOp, RuntimeProps);
129
+ return detail::convert_reduction_properties (combiner, properties, vars, cgh);
96
130
}
97
131
98
132
template <typename T, typename BinaryOperation, typename PropertyList>
99
- auto reduction (T *var, BinaryOperation combiner, PropertyList properties) {
133
+ auto reduction (T *var, BinaryOperation combiner, PropertyList properties)
134
+ -> std::enable_if_t<detail::is_valid_reduction_prop_list<PropertyList>,
135
+ decltype(detail::convert_reduction_properties(
136
+ combiner, properties, var))> {
100
137
detail::CheckReductionIdentity<T, BinaryOperation>(properties);
101
- auto WrappedOp = detail::WrapOp (combiner, properties);
102
- auto RuntimeProps = detail::GetReductionPropertyList (properties);
103
- return reduction (var, WrappedOp, RuntimeProps);
138
+ return detail::convert_reduction_properties (combiner, properties, var);
104
139
}
105
140
106
141
template <typename T, size_t Extent, typename BinaryOperation,
107
142
typename PropertyList>
108
143
auto reduction (span<T, Extent> vars, BinaryOperation combiner,
109
- PropertyList properties) {
144
+ PropertyList properties)
145
+ -> std::enable_if_t<detail::is_valid_reduction_prop_list<PropertyList>,
146
+ decltype(detail::convert_reduction_properties(
147
+ combiner, properties, vars))> {
110
148
detail::CheckReductionIdentity<T, BinaryOperation>(properties);
111
- auto WrappedOp = detail::WrapOp (combiner, properties);
112
- auto RuntimeProps = detail::GetReductionPropertyList (properties);
113
- return reduction (vars, WrappedOp, RuntimeProps);
149
+ return detail::convert_reduction_properties (combiner, properties, vars);
114
150
}
115
151
116
152
template <typename BufferT, typename BinaryOperation, typename PropertyList>
117
153
auto reduction (BufferT vars, handler &cgh,
118
154
const typename BufferT::value_type &identity,
119
- BinaryOperation combiner, PropertyList properties) {
120
- auto WrappedOp = detail::WrapOp (combiner, properties);
121
- auto RuntimeProps = detail::GetReductionPropertyList (properties);
122
- return reduction (vars, cgh, identity, WrappedOp, RuntimeProps);
155
+ BinaryOperation combiner, PropertyList properties)
156
+ -> std::enable_if_t<detail::is_valid_reduction_prop_list<PropertyList>,
157
+ decltype(detail::convert_reduction_properties(
158
+ combiner, properties, vars, cgh, identity))> {
159
+ return detail::convert_reduction_properties (combiner, properties, vars, cgh,
160
+ identity);
123
161
}
124
162
125
163
template <typename T, typename BinaryOperation, typename PropertyList>
126
164
auto reduction (T *var, const T &identity, BinaryOperation combiner,
127
- PropertyList properties) {
128
- auto WrappedOp = detail::WrapOp (combiner, properties);
129
- auto RuntimeProps = detail::GetReductionPropertyList (properties);
130
- return reduction (var, identity, WrappedOp, RuntimeProps);
165
+ PropertyList properties)
166
+ -> std::enable_if_t<detail::is_valid_reduction_prop_list<PropertyList>,
167
+ decltype(detail::convert_reduction_properties(
168
+ combiner, properties, var, identity))> {
169
+ return detail::convert_reduction_properties (combiner, properties, var,
170
+ identity);
131
171
}
132
172
133
173
template <typename T, size_t Extent, typename BinaryOperation,
134
174
typename PropertyList>
135
175
auto reduction (span<T, Extent> vars, const T &identity,
136
- BinaryOperation combiner, PropertyList properties) {
137
- auto WrappedOp = detail::WrapOp (combiner, properties);
138
- auto RuntimeProps = detail::GetReductionPropertyList (properties);
139
- return reduction (vars, identity, WrappedOp, RuntimeProps);
176
+ BinaryOperation combiner, PropertyList properties)
177
+ -> std::enable_if_t<detail::is_valid_reduction_prop_list<PropertyList>,
178
+ decltype(detail::convert_reduction_properties(
179
+ combiner, properties, vars, identity))> {
180
+ return detail::convert_reduction_properties (combiner, properties, vars,
181
+ identity);
140
182
}
141
183
142
184
} // namespace _V1
0 commit comments