@@ -225,6 +225,27 @@ struct known_identity_impl<BinaryOperation, AccumulatorT,
225
225
: std::numeric_limits<AccumulatorT>::lowest();
226
226
};
227
227
228
+ #ifdef __SYCL_REDUCER_OP_EQ_CHECK_TRAIT
229
+ #error "__SYCL_REDUCER_OP_EQ_CHECK_TRAIT must not be defined"
230
+ #endif
231
+
232
+ #define __SYCL_REDUCER_OP_EQ_CHECK_TRAIT (OpName, Op ) \
233
+ template <typename , typename = void > \
234
+ struct HasSameTypeArg ##OpName##Eq : public std::false_type {}; \
235
+ template <typename T> \
236
+ struct HasSameTypeArg ##OpName##Eq< \
237
+ T, std::enable_if_t <std::is_same< \
238
+ decltype (static_cast <T &(T::*)(const T &)>(&T::operator +=)), \
239
+ T &(T::*)(const T &)>::value>> : public std::true_type {};
240
+
241
+ __SYCL_REDUCER_OP_EQ_CHECK_TRAIT (Plus, +)
242
+ __SYCL_REDUCER_OP_EQ_CHECK_TRAIT(Multiplies, *)
243
+ __SYCL_REDUCER_OP_EQ_CHECK_TRAIT(BitwiseOR, |)
244
+ __SYCL_REDUCER_OP_EQ_CHECK_TRAIT(BitwiseXOR, ^)
245
+ __SYCL_REDUCER_OP_EQ_CHECK_TRAIT(BitwiseAND, &)
246
+
247
+ #undef __SYCL_REDUCER_OP_EQ_CHECK_TRAIT
248
+
228
249
// / Class that is used to represent objects that are passed to user's lambda
229
250
// / functions and representing users' reduction variable.
230
251
// / The generic version of the class represents those reductions of those
@@ -238,6 +259,41 @@ class reducer {
238
259
239
260
T getIdentity () const { return MIdentity; }
240
261
262
+ template <typename _T = T>
263
+ enable_if_t <HasSameTypeArgPlusEq<_T>::value, reducer &>
264
+ operator +=(const _T &Partial) {
265
+ MValue += Partial;
266
+ return *this ;
267
+ }
268
+
269
+ template <typename _T = T>
270
+ enable_if_t <HasSameTypeArgMultipliesEq<_T>::value, reducer &>
271
+ operator *=(const _T &Partial) {
272
+ MValue *= Partial;
273
+ return *this ;
274
+ }
275
+
276
+ template <typename _T = T>
277
+ enable_if_t <HasSameTypeArgBitwiseOREq<_T>::value, reducer &>
278
+ operator |=(const _T &Partial) {
279
+ MValue |= Partial;
280
+ return *this ;
281
+ }
282
+
283
+ template <typename _T = T>
284
+ enable_if_t <HasSameTypeArgBitwiseXOREq<_T>::value, reducer &>
285
+ operator ^=(const _T &Partial) {
286
+ MValue ^= Partial;
287
+ return *this ;
288
+ }
289
+
290
+ template <typename _T = T>
291
+ enable_if_t <HasSameTypeArgBitwiseANDEq<_T>::value, reducer &>
292
+ operator &=(const _T &Partial) {
293
+ MValue &= Partial;
294
+ return *this ;
295
+ }
296
+
241
297
T MValue;
242
298
243
299
private:
0 commit comments