Skip to content

Commit 2f79acb

Browse files
Replaced nd_range<2> with nd_range<1> for non-atomic case as well
1 parent 7e9f857 commit 2f79acb

File tree

1 file changed

+30
-25
lines changed

1 file changed

+30
-25
lines changed

dpctl/tensor/libtensor/include/kernels/reductions.hpp

Lines changed: 30 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -134,12 +134,12 @@ struct ReductionOverGroupWithAtomicFunctor
134134
InputOutputIterIndexerT arg_res_iter_indexer,
135135
InputRedIndexerT arg_reduced_dims_indexer,
136136
size_t reduction_size,
137-
size_t iter_gws,
137+
size_t iteration_size,
138138
size_t reduction_size_per_wi)
139139
: inp_(data), out_(res), reduction_op_(reduction_op),
140140
identity_(identity_val), inp_out_iter_indexer_(arg_res_iter_indexer),
141141
inp_reduced_dims_indexer_(arg_reduced_dims_indexer),
142-
reduction_max_gid_(reduction_size), iter_gws_(iter_gws),
142+
reduction_max_gid_(reduction_size), iter_gws_(iteration_size),
143143
reductions_per_wi(reduction_size_per_wi)
144144
{
145145
}
@@ -528,6 +528,7 @@ struct ReductionOverGroupNoAtomicFunctor
528528
InputOutputIterIndexerT inp_out_iter_indexer_;
529529
InputRedIndexerT inp_reduced_dims_indexer_;
530530
size_t reduction_max_gid_ = 0;
531+
size_t iter_gws_ = 1;
531532
size_t reductions_per_wi = 16;
532533

533534
public:
@@ -539,22 +540,25 @@ struct ReductionOverGroupNoAtomicFunctor
539540
InputOutputIterIndexerT arg_res_iter_indexer,
540541
InputRedIndexerT arg_reduced_dims_indexer,
541542
size_t reduction_size,
543+
size_t iteration_size,
542544
size_t reduction_size_per_wi)
543545
: inp_(data), out_(res), reduction_op_(reduction_op),
544546
identity_(identity_val), inp_out_iter_indexer_(arg_res_iter_indexer),
545547
inp_reduced_dims_indexer_(arg_reduced_dims_indexer),
546-
reduction_max_gid_(reduction_size),
548+
reduction_max_gid_(reduction_size), iter_gws_(iteration_size),
547549
reductions_per_wi(reduction_size_per_wi)
548550
{
549551
}
550552

551-
void operator()(sycl::nd_item<2> it) const
553+
void operator()(sycl::nd_item<1> it) const
552554
{
553555

554-
size_t iter_gid = it.get_global_id(0);
555-
size_t reduction_batch_id = it.get_group(1);
556-
size_t reduction_lid = it.get_local_id(1);
557-
size_t wg = it.get_local_range(1); // 0 <= reduction_lid < wg
556+
const size_t red_gws_ = it.get_global_range(0) / iter_gws_;
557+
const size_t iter_gid = it.get_global_id(0) / red_gws_;
558+
const size_t n_reduction_groups = it.get_group_range(0) / iter_gws_;
559+
const size_t reduction_batch_id = it.get_group(0) % n_reduction_groups;
560+
const size_t reduction_lid = it.get_local_id(0);
561+
const size_t wg = it.get_local_range(0); // 0 <= reduction_lid < wg
558562

559563
// work-items sums over input with indices
560564
// inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg
@@ -590,7 +594,7 @@ struct ReductionOverGroupNoAtomicFunctor
590594

591595
if (work_group.leader()) {
592596
// each group writes to a different memory location
593-
out_[out_iter_offset * it.get_group_range(1) + reduction_batch_id] =
597+
out_[out_iter_offset * n_reduction_groups + reduction_batch_id] =
594598
red_val_over_wg;
595599
}
596600
}
@@ -657,20 +661,20 @@ sycl::event sum_reduction_over_group_temps_strided_impl(
657661
assert(reduction_groups == 1);
658662

659663
auto globalRange =
660-
sycl::range<2>{iter_nelems, reduction_groups * wg};
661-
auto localRange = sycl::range<2>{1, wg};
664+
sycl::range<1>{iter_nelems * reduction_groups * wg};
665+
auto localRange = sycl::range<1>{wg};
662666

663667
using KernelName = class sum_reduction_over_group_temps_krn<
664668
argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
665669
ReductionIndexerT>;
666670
cgh.parallel_for<KernelName>(
667-
sycl::nd_range<2>(globalRange, localRange),
671+
sycl::nd_range<1>(globalRange, localRange),
668672
ReductionOverGroupNoAtomicFunctor<argTy, resTy, ReductionOpT,
669673
InputOutputIterIndexerT,
670674
ReductionIndexerT>(
671675
arg_tp, res_tp, ReductionOpT(), identity_val,
672676
in_out_iter_indexer, reduction_indexer, reduction_nelems,
673-
reductions_per_wi));
677+
iter_nelems, reductions_per_wi));
674678
});
675679

676680
return comp_ev;
@@ -723,20 +727,20 @@ sycl::event sum_reduction_over_group_temps_strided_impl(
723727
reduction_shape_stride};
724728

725729
auto globalRange =
726-
sycl::range<2>{iter_nelems, reduction_groups * wg};
727-
auto localRange = sycl::range<2>{1, wg};
730+
sycl::range<1>{iter_nelems * reduction_groups * wg};
731+
auto localRange = sycl::range<1>{wg};
728732

729733
using KernelName = class sum_reduction_over_group_temps_krn<
730734
argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
731735
ReductionIndexerT>;
732736
cgh.parallel_for<KernelName>(
733-
sycl::nd_range<2>(globalRange, localRange),
737+
sycl::nd_range<1>(globalRange, localRange),
734738
ReductionOverGroupNoAtomicFunctor<argTy, resTy, ReductionOpT,
735739
InputOutputIterIndexerT,
736740
ReductionIndexerT>(
737741
arg_tp, partially_reduced_tmp, ReductionOpT(), identity_val,
738742
in_out_iter_indexer, reduction_indexer, reduction_nelems,
739-
preferrered_reductions_per_wi));
743+
iter_nelems, preferrered_reductions_per_wi));
740744
});
741745

742746
size_t remaining_reduction_nelems = reduction_groups;
@@ -778,20 +782,20 @@ sycl::event sum_reduction_over_group_temps_strided_impl(
778782
ReductionIndexerT reduction_indexer{};
779783

780784
auto globalRange =
781-
sycl::range<2>{iter_nelems, reduction_groups_ * wg};
782-
auto localRange = sycl::range<2>{1, wg};
785+
sycl::range<1>{iter_nelems * reduction_groups_ * wg};
786+
auto localRange = sycl::range<1>{wg};
783787

784788
using KernelName = class sum_reduction_over_group_temps_krn<
785789
resTy, resTy, ReductionOpT, InputOutputIterIndexerT,
786790
ReductionIndexerT>;
787791
cgh.parallel_for<KernelName>(
788-
sycl::nd_range<2>(globalRange, localRange),
792+
sycl::nd_range<1>(globalRange, localRange),
789793
ReductionOverGroupNoAtomicFunctor<
790794
resTy, resTy, ReductionOpT, InputOutputIterIndexerT,
791795
ReductionIndexerT>(
792796
temp_arg, temp2_arg, ReductionOpT(), identity_val,
793797
in_out_iter_indexer, reduction_indexer,
794-
remaining_reduction_nelems,
798+
remaining_reduction_nelems, iter_nelems,
795799
preferrered_reductions_per_wi));
796800
});
797801

@@ -834,20 +838,21 @@ sycl::event sum_reduction_over_group_temps_strided_impl(
834838
assert(reduction_groups == 1);
835839

836840
auto globalRange =
837-
sycl::range<2>{iter_nelems, reduction_groups * wg};
838-
auto localRange = sycl::range<2>{1, wg};
841+
sycl::range<1>{iter_nelems * reduction_groups * wg};
842+
auto localRange = sycl::range<1>{wg};
839843

840844
using KernelName = class sum_reduction_over_group_temps_krn<
841845
argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
842846
ReductionIndexerT>;
843847
cgh.parallel_for<KernelName>(
844-
sycl::nd_range<2>(globalRange, localRange),
848+
sycl::nd_range<1>(globalRange, localRange),
845849
ReductionOverGroupNoAtomicFunctor<resTy, resTy, ReductionOpT,
846850
InputOutputIterIndexerT,
847851
ReductionIndexerT>(
848852
temp_arg, res_tp, ReductionOpT(), identity_val,
849853
in_out_iter_indexer, reduction_indexer,
850-
remaining_reduction_nelems, reductions_per_wi));
854+
remaining_reduction_nelems, iter_nelems,
855+
reductions_per_wi));
851856
});
852857

853858
sycl::event cleanup_host_task_event =

0 commit comments

Comments
 (0)