Skip to content

Commit 1e85b1e

Browse files
committed
Changed WG traversal pattern in boolean reductions
Similar to changes in sum, now traverses the iteration dimension the fastest
1 parent 83fff33 commit 1e85b1e

File tree

1 file changed

+10
-26
lines changed

1 file changed

+10
-26
lines changed

dpctl/tensor/libtensor/include/kernels/boolean_reductions.hpp

Lines changed: 10 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "pybind11/pybind11.h"
3535

3636
#include "utils/offset_utils.hpp"
37+
#include "utils/sycl_utils.hpp"
3738
#include "utils/type_dispatch.hpp"
3839
#include "utils/type_utils.hpp"
3940

@@ -227,9 +228,8 @@ struct ContigBooleanReduction
227228

228229
void operator()(sycl::nd_item<1> it) const
229230
{
230-
const size_t red_gws_ = it.get_global_range(0) / iter_gws_;
231-
const size_t reduction_id = it.get_global_id(0) / red_gws_;
232-
const size_t reduction_batch_id = get_reduction_batch_id(it);
231+
const size_t reduction_id = it.get_group(0) % iter_gws_;
232+
const size_t reduction_batch_id = it.get_group(0) / iter_gws_;
233233
const size_t wg_size = it.get_local_range(0);
234234

235235
const size_t base = reduction_id * reduction_max_gid_;
@@ -241,14 +241,6 @@ struct ContigBooleanReduction
241241
// in group_op_
242242
group_op_(it, out_, reduction_id, inp_ + start, inp_ + end);
243243
}
244-
245-
private:
246-
size_t get_reduction_batch_id(sycl::nd_item<1> const &it) const
247-
{
248-
const size_t n_reduction_groups = it.get_group_range(0) / iter_gws_;
249-
const size_t reduction_batch_id = it.get_group(0) % n_reduction_groups;
250-
return reduction_batch_id;
251-
}
252244
};
253245

254246
typedef sycl::event (*boolean_reduction_contig_impl_fn_ptr)(
@@ -268,6 +260,8 @@ class boolean_reduction_contig_krn;
268260
template <typename T1, typename T2, typename T3, typename T4, typename T5>
269261
class boolean_reduction_seq_contig_krn;
270262

263+
using dpctl::tensor::sycl_utils::choose_workgroup_size;
264+
271265
template <typename argTy, typename resTy, typename RedOpT, typename GroupOpT>
272266
sycl::event
273267
boolean_reduction_contig_impl(sycl::queue exec_q,
@@ -288,8 +282,7 @@ boolean_reduction_contig_impl(sycl::queue exec_q,
288282

289283
const sycl::device &d = exec_q.get_device();
290284
const auto &sg_sizes = d.get_info<sycl::info::device::sub_group_sizes>();
291-
size_t wg =
292-
4 * (*std::max_element(std::begin(sg_sizes), std::end(sg_sizes)));
285+
size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes);
293286

294287
sycl::event red_ev;
295288
if (reduction_nelems < wg) {
@@ -433,9 +426,9 @@ struct StridedBooleanReduction
433426

434427
void operator()(sycl::nd_item<1> it) const
435428
{
436-
const size_t red_gws_ = it.get_global_range(0) / iter_gws_;
437-
const size_t reduction_id = it.get_global_id(0) / red_gws_;
438-
const size_t reduction_batch_id = get_reduction_batch_id(it);
429+
const size_t reduction_id = it.get_group(0) % iter_gws_;
430+
const size_t reduction_batch_id = it.get_group(0) / iter_gws_;
431+
439432
const size_t reduction_lid = it.get_local_id(0);
440433
const size_t wg_size = it.get_local_range(0);
441434

@@ -468,14 +461,6 @@ struct StridedBooleanReduction
468461
// in group_op_
469462
group_op_(it, out_, out_iter_offset, local_red_val);
470463
}
471-
472-
private:
473-
size_t get_reduction_batch_id(sycl::nd_item<1> const &it) const
474-
{
475-
const size_t n_reduction_groups = it.get_group_range(0) / iter_gws_;
476-
const size_t reduction_batch_id = it.get_group(0) % n_reduction_groups;
477-
return reduction_batch_id;
478-
}
479464
};
480465

481466
template <typename T1,
@@ -527,8 +512,7 @@ boolean_reduction_strided_impl(sycl::queue exec_q,
527512

528513
const sycl::device &d = exec_q.get_device();
529514
const auto &sg_sizes = d.get_info<sycl::info::device::sub_group_sizes>();
530-
size_t wg =
531-
4 * (*std::max_element(std::begin(sg_sizes), std::end(sg_sizes)));
515+
size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes);
532516

533517
sycl::event red_ev;
534518
if (reduction_nelems < wg) {

0 commit comments

Comments
 (0)