@@ -196,10 +196,12 @@ struct ReductionOverGroupWithAtomicFunctor
196
196
if constexpr (su_ns::IsPlus<outT, ReductionOp>::value) {
197
197
res_ref += red_val_over_wg;
198
198
}
199
- else if constexpr (su_ns::IsMaximum<outT, ReductionOp>::value) {
199
+ else if constexpr (std::is_same_v<ReductionOp, sycl::maximum<outT>>)
200
+ {
200
201
res_ref.fetch_max (red_val_over_wg);
201
202
}
202
- else if constexpr (su_ns::IsMinimum<outT, ReductionOp>::value) {
203
+ else if constexpr (std::is_same_v<ReductionOp, sycl::minimum<outT>>)
204
+ {
203
205
res_ref.fetch_min (red_val_over_wg);
204
206
}
205
207
else {
@@ -300,22 +302,11 @@ struct CustomReductionOverGroupWithAtomicFunctor
300
302
sycl::memory_scope::device,
301
303
sycl::access::address_space::global_space>
302
304
res_ref (out_[out_iter_offset]);
303
- if constexpr (su_ns::IsPlus<outT, ReductionOp>::value) {
304
- res_ref += red_val_over_wg;
305
- }
306
- else if constexpr (su_ns::IsMaximum<outT, ReductionOp>::value) {
307
- res_ref.fetch_max (red_val_over_wg);
308
- }
309
- else if constexpr (su_ns::IsMinimum<outT, ReductionOp>::value) {
310
- res_ref.fetch_min (red_val_over_wg);
311
- }
312
- else {
313
- outT read_val = res_ref.load ();
314
- outT new_val{};
315
- do {
316
- new_val = reduction_op_ (read_val, red_val_over_wg);
317
- } while (!res_ref.compare_exchange_strong (read_val, new_val));
318
- }
305
+ outT read_val = res_ref.load ();
306
+ outT new_val{};
307
+ do {
308
+ new_val = reduction_op_ (read_val, red_val_over_wg);
309
+ } while (!res_ref.compare_exchange_strong (read_val, new_val));
319
310
}
320
311
}
321
312
};
0 commit comments