@@ -122,6 +122,7 @@ struct ReductionOverGroupWithAtomicFunctor
122
122
InputOutputIterIndexerT inp_out_iter_indexer_;
123
123
InputRedIndexerT inp_reduced_dims_indexer_;
124
124
size_t reduction_max_gid_ = 0 ;
125
+ size_t iter_gws_ = 1 ;
125
126
size_t reductions_per_wi = 16 ;
126
127
127
128
public:
@@ -133,22 +134,23 @@ struct ReductionOverGroupWithAtomicFunctor
133
134
InputOutputIterIndexerT arg_res_iter_indexer,
134
135
InputRedIndexerT arg_reduced_dims_indexer,
135
136
size_t reduction_size,
137
+ size_t iter_gws,
136
138
size_t reduction_size_per_wi)
137
139
: inp_(data), out_(res), reduction_op_(reduction_op),
138
140
identity_ (identity_val), inp_out_iter_indexer_(arg_res_iter_indexer),
139
141
inp_reduced_dims_indexer_(arg_reduced_dims_indexer),
140
- reduction_max_gid_(reduction_size),
142
+ reduction_max_gid_(reduction_size), iter_gws_(iter_gws),
141
143
reductions_per_wi(reduction_size_per_wi)
142
144
{
143
145
}
144
146
145
- void operator ()(sycl::nd_item<2 > it) const
147
+ void operator ()(sycl::nd_item<1 > it) const
146
148
{
147
-
148
- size_t iter_gid = it.get_global_id (0 );
149
- size_t reduction_batch_id = it. get_group ( 1 );
150
- size_t reduction_lid = it.get_local_id (1 );
151
- size_t wg = it.get_local_range (1 ); // 0 <= reduction_lid < wg
149
+ const size_t red_gws_ = it. get_global_range ( 0 ) / iter_gws_;
150
+ const size_t iter_gid = it.get_global_id (0 ) / red_gws_ ;
151
+ const size_t reduction_batch_id = get_reduction_batch_id (it );
152
+ const size_t reduction_lid = it.get_local_id (0 );
153
+ const size_t wg = it.get_local_range (0 ); // 0 <= reduction_lid < wg
152
154
153
155
// work-items sums over input with indices
154
156
// inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg
@@ -202,6 +204,14 @@ struct ReductionOverGroupWithAtomicFunctor
202
204
}
203
205
}
204
206
}
207
+
208
+ private:
209
+ size_t get_reduction_batch_id (sycl::nd_item<1 > const &it) const
210
+ {
211
+ const size_t n_reduction_groups = it.get_group_range (0 ) / iter_gws_;
212
+ const size_t reduction_batch_id = it.get_group (0 ) % n_reduction_groups;
213
+ return reduction_batch_id;
214
+ }
205
215
};
206
216
207
217
typedef sycl::event (*sum_reduction_strided_impl_fn_ptr)(
@@ -343,21 +353,21 @@ sycl::event sum_reduction_over_group_with_atomics_strided_impl(
343
353
}
344
354
345
355
auto globalRange =
346
- sycl::range<2 >{iter_nelems, reduction_groups * wg};
347
- auto localRange = sycl::range<2 >{ 1 , wg};
356
+ sycl::range<1 >{iter_nelems * reduction_groups * wg};
357
+ auto localRange = sycl::range<1 >{ wg};
348
358
349
359
using KernelName = class sum_reduction_over_group_with_atomics_krn <
350
360
argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
351
361
ReductionIndexerT>;
352
362
353
363
cgh.parallel_for <KernelName>(
354
- sycl::nd_range<2 >(globalRange, localRange),
364
+ sycl::nd_range<1 >(globalRange, localRange),
355
365
ReductionOverGroupWithAtomicFunctor<argTy, resTy, ReductionOpT,
356
366
InputOutputIterIndexerT,
357
367
ReductionIndexerT>(
358
368
arg_tp, res_tp, ReductionOpT (), identity_val,
359
369
in_out_iter_indexer, reduction_indexer, reduction_nelems,
360
- reductions_per_wi));
370
+ iter_nelems, reductions_per_wi));
361
371
});
362
372
363
373
return comp_ev;
@@ -480,21 +490,21 @@ sycl::event sum_reduction_over_group_with_atomics_contig_impl(
480
490
}
481
491
482
492
auto globalRange =
483
- sycl::range<2 >{iter_nelems, reduction_groups * wg};
484
- auto localRange = sycl::range<2 >{ 1 , wg};
493
+ sycl::range<1 >{iter_nelems * reduction_groups * wg};
494
+ auto localRange = sycl::range<1 >{ wg};
485
495
486
496
using KernelName = class sum_reduction_over_group_with_atomics_krn <
487
497
argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
488
498
ReductionIndexerT>;
489
499
490
500
cgh.parallel_for <KernelName>(
491
- sycl::nd_range<2 >(globalRange, localRange),
501
+ sycl::nd_range<1 >(globalRange, localRange),
492
502
ReductionOverGroupWithAtomicFunctor<argTy, resTy, ReductionOpT,
493
503
InputOutputIterIndexerT,
494
504
ReductionIndexerT>(
495
505
arg_tp, res_tp, ReductionOpT (), identity_val,
496
506
in_out_iter_indexer, reduction_indexer, reduction_nelems,
497
- reductions_per_wi));
507
+ iter_nelems, reductions_per_wi));
498
508
});
499
509
500
510
return comp_ev;
0 commit comments