@@ -195,10 +195,12 @@ struct ReductionOverGroupWithAtomicFunctor
195
195
if constexpr (su_ns::IsPlus<outT, ReductionOp>::value) {
196
196
res_ref += red_val_over_wg;
197
197
}
198
- else if constexpr (su_ns::IsMaximum<outT, ReductionOp>::value) {
198
+ else if constexpr (std::is_same_v<ReductionOp, sycl::maximum<outT>>)
199
+ {
199
200
res_ref.fetch_max (red_val_over_wg);
200
201
}
201
- else if constexpr (su_ns::IsMinimum<outT, ReductionOp>::value) {
202
+ else if constexpr (std::is_same_v<ReductionOp, sycl::minimum<outT>>)
203
+ {
202
204
res_ref.fetch_min (red_val_over_wg);
203
205
}
204
206
else {
@@ -299,22 +301,11 @@ struct CustomReductionOverGroupWithAtomicFunctor
299
301
sycl::memory_scope::device,
300
302
sycl::access::address_space::global_space>
301
303
res_ref (out_[out_iter_offset]);
302
- if constexpr (su_ns::IsPlus<outT, ReductionOp>::value) {
303
- res_ref += red_val_over_wg;
304
- }
305
- else if constexpr (su_ns::IsMaximum<outT, ReductionOp>::value) {
306
- res_ref.fetch_max (red_val_over_wg);
307
- }
308
- else if constexpr (su_ns::IsMinimum<outT, ReductionOp>::value) {
309
- res_ref.fetch_min (red_val_over_wg);
310
- }
311
- else {
312
- outT read_val = res_ref.load ();
313
- outT new_val{};
314
- do {
315
- new_val = reduction_op_ (read_val, red_val_over_wg);
316
- } while (!res_ref.compare_exchange_strong (read_val, new_val));
317
- }
304
+ outT read_val = res_ref.load ();
305
+ outT new_val{};
306
+ do {
307
+ new_val = reduction_op_ (read_val, red_val_over_wg);
308
+ } while (!res_ref.compare_exchange_strong (read_val, new_val));
318
309
}
319
310
}
320
311
};
0 commit comments