@@ -134,12 +134,12 @@ struct ReductionOverGroupWithAtomicFunctor
134
134
InputOutputIterIndexerT arg_res_iter_indexer,
135
135
InputRedIndexerT arg_reduced_dims_indexer,
136
136
size_t reduction_size,
137
- size_t iter_gws ,
137
+ size_t iteration_size ,
138
138
size_t reduction_size_per_wi)
139
139
: inp_(data), out_(res), reduction_op_(reduction_op),
140
140
identity_ (identity_val), inp_out_iter_indexer_(arg_res_iter_indexer),
141
141
inp_reduced_dims_indexer_(arg_reduced_dims_indexer),
142
- reduction_max_gid_(reduction_size), iter_gws_(iter_gws ),
142
+ reduction_max_gid_(reduction_size), iter_gws_(iteration_size ),
143
143
reductions_per_wi(reduction_size_per_wi)
144
144
{
145
145
}
@@ -528,6 +528,7 @@ struct ReductionOverGroupNoAtomicFunctor
528
528
InputOutputIterIndexerT inp_out_iter_indexer_;
529
529
InputRedIndexerT inp_reduced_dims_indexer_;
530
530
size_t reduction_max_gid_ = 0 ;
531
+ size_t iter_gws_ = 1 ;
531
532
size_t reductions_per_wi = 16 ;
532
533
533
534
public:
@@ -539,22 +540,25 @@ struct ReductionOverGroupNoAtomicFunctor
539
540
InputOutputIterIndexerT arg_res_iter_indexer,
540
541
InputRedIndexerT arg_reduced_dims_indexer,
541
542
size_t reduction_size,
543
+ size_t iteration_size,
542
544
size_t reduction_size_per_wi)
543
545
: inp_(data), out_(res), reduction_op_(reduction_op),
544
546
identity_ (identity_val), inp_out_iter_indexer_(arg_res_iter_indexer),
545
547
inp_reduced_dims_indexer_(arg_reduced_dims_indexer),
546
- reduction_max_gid_(reduction_size),
548
+ reduction_max_gid_(reduction_size), iter_gws_(iteration_size),
547
549
reductions_per_wi(reduction_size_per_wi)
548
550
{
549
551
}
550
552
551
- void operator ()(sycl::nd_item<2 > it) const
553
+ void operator ()(sycl::nd_item<1 > it) const
552
554
{
553
555
554
- size_t iter_gid = it.get_global_id (0 );
555
- size_t reduction_batch_id = it.get_group (1 );
556
- size_t reduction_lid = it.get_local_id (1 );
557
- size_t wg = it.get_local_range (1 ); // 0 <= reduction_lid < wg
556
+ const size_t red_gws_ = it.get_global_range (0 ) / iter_gws_;
557
+ const size_t iter_gid = it.get_global_id (0 ) / red_gws_;
558
+ const size_t n_reduction_groups = it.get_group_range (0 ) / iter_gws_;
559
+ const size_t reduction_batch_id = it.get_group (0 ) % n_reduction_groups;
560
+ const size_t reduction_lid = it.get_local_id (0 );
561
+ const size_t wg = it.get_local_range (0 ); // 0 <= reduction_lid < wg
558
562
559
563
// work-items sums over input with indices
560
564
// inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg
@@ -590,7 +594,7 @@ struct ReductionOverGroupNoAtomicFunctor
590
594
591
595
if (work_group.leader ()) {
592
596
// each group writes to a different memory location
593
- out_[out_iter_offset * it. get_group_range ( 1 ) + reduction_batch_id] =
597
+ out_[out_iter_offset * n_reduction_groups + reduction_batch_id] =
594
598
red_val_over_wg;
595
599
}
596
600
}
@@ -657,20 +661,20 @@ sycl::event sum_reduction_over_group_temps_strided_impl(
657
661
assert (reduction_groups == 1 );
658
662
659
663
auto globalRange =
660
- sycl::range<2 >{iter_nelems, reduction_groups * wg};
661
- auto localRange = sycl::range<2 >{ 1 , wg};
664
+ sycl::range<1 >{iter_nelems * reduction_groups * wg};
665
+ auto localRange = sycl::range<1 >{ wg};
662
666
663
667
using KernelName = class sum_reduction_over_group_temps_krn <
664
668
argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
665
669
ReductionIndexerT>;
666
670
cgh.parallel_for <KernelName>(
667
- sycl::nd_range<2 >(globalRange, localRange),
671
+ sycl::nd_range<1 >(globalRange, localRange),
668
672
ReductionOverGroupNoAtomicFunctor<argTy, resTy, ReductionOpT,
669
673
InputOutputIterIndexerT,
670
674
ReductionIndexerT>(
671
675
arg_tp, res_tp, ReductionOpT (), identity_val,
672
676
in_out_iter_indexer, reduction_indexer, reduction_nelems,
673
- reductions_per_wi));
677
+ iter_nelems, reductions_per_wi));
674
678
});
675
679
676
680
return comp_ev;
@@ -723,20 +727,20 @@ sycl::event sum_reduction_over_group_temps_strided_impl(
723
727
reduction_shape_stride};
724
728
725
729
auto globalRange =
726
- sycl::range<2 >{iter_nelems, reduction_groups * wg};
727
- auto localRange = sycl::range<2 >{ 1 , wg};
730
+ sycl::range<1 >{iter_nelems * reduction_groups * wg};
731
+ auto localRange = sycl::range<1 >{ wg};
728
732
729
733
using KernelName = class sum_reduction_over_group_temps_krn <
730
734
argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
731
735
ReductionIndexerT>;
732
736
cgh.parallel_for <KernelName>(
733
- sycl::nd_range<2 >(globalRange, localRange),
737
+ sycl::nd_range<1 >(globalRange, localRange),
734
738
ReductionOverGroupNoAtomicFunctor<argTy, resTy, ReductionOpT,
735
739
InputOutputIterIndexerT,
736
740
ReductionIndexerT>(
737
741
arg_tp, partially_reduced_tmp, ReductionOpT (), identity_val,
738
742
in_out_iter_indexer, reduction_indexer, reduction_nelems,
739
- preferrered_reductions_per_wi));
743
+ iter_nelems, preferrered_reductions_per_wi));
740
744
});
741
745
742
746
size_t remaining_reduction_nelems = reduction_groups;
@@ -778,20 +782,20 @@ sycl::event sum_reduction_over_group_temps_strided_impl(
778
782
ReductionIndexerT reduction_indexer{};
779
783
780
784
auto globalRange =
781
- sycl::range<2 >{iter_nelems, reduction_groups_ * wg};
782
- auto localRange = sycl::range<2 >{ 1 , wg};
785
+ sycl::range<1 >{iter_nelems * reduction_groups_ * wg};
786
+ auto localRange = sycl::range<1 >{ wg};
783
787
784
788
using KernelName = class sum_reduction_over_group_temps_krn <
785
789
resTy, resTy, ReductionOpT, InputOutputIterIndexerT,
786
790
ReductionIndexerT>;
787
791
cgh.parallel_for <KernelName>(
788
- sycl::nd_range<2 >(globalRange, localRange),
792
+ sycl::nd_range<1 >(globalRange, localRange),
789
793
ReductionOverGroupNoAtomicFunctor<
790
794
resTy, resTy, ReductionOpT, InputOutputIterIndexerT,
791
795
ReductionIndexerT>(
792
796
temp_arg, temp2_arg, ReductionOpT (), identity_val,
793
797
in_out_iter_indexer, reduction_indexer,
794
- remaining_reduction_nelems,
798
+ remaining_reduction_nelems, iter_nelems,
795
799
preferrered_reductions_per_wi));
796
800
});
797
801
@@ -834,20 +838,21 @@ sycl::event sum_reduction_over_group_temps_strided_impl(
834
838
assert (reduction_groups == 1 );
835
839
836
840
auto globalRange =
837
- sycl::range<2 >{iter_nelems, reduction_groups * wg};
838
- auto localRange = sycl::range<2 >{ 1 , wg};
841
+ sycl::range<1 >{iter_nelems * reduction_groups * wg};
842
+ auto localRange = sycl::range<1 >{ wg};
839
843
840
844
using KernelName = class sum_reduction_over_group_temps_krn <
841
845
argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
842
846
ReductionIndexerT>;
843
847
cgh.parallel_for <KernelName>(
844
- sycl::nd_range<2 >(globalRange, localRange),
848
+ sycl::nd_range<1 >(globalRange, localRange),
845
849
ReductionOverGroupNoAtomicFunctor<resTy, resTy, ReductionOpT,
846
850
InputOutputIterIndexerT,
847
851
ReductionIndexerT>(
848
852
temp_arg, res_tp, ReductionOpT (), identity_val,
849
853
in_out_iter_indexer, reduction_indexer,
850
- remaining_reduction_nelems, reductions_per_wi));
854
+ remaining_reduction_nelems, iter_nelems,
855
+ reductions_per_wi));
851
856
});
852
857
853
858
sycl::event cleanup_host_task_event =
0 commit comments