Skip to content

Commit 8f469a8

Browse files
committed
Implements boolean reduction kernel for axis 0
- Aligns with similar changes to sum
1 parent 1e85b1e commit 8f469a8

File tree

3 files changed

+227
-65
lines changed

3 files changed

+227
-65
lines changed

dpctl/tensor/libtensor/include/kernels/boolean_reductions.hpp

Lines changed: 124 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -264,15 +264,15 @@ using dpctl::tensor::sycl_utils::choose_workgroup_size;
264264

265265
template <typename argTy, typename resTy, typename RedOpT, typename GroupOpT>
266266
sycl::event
267-
boolean_reduction_contig_impl(sycl::queue exec_q,
268-
size_t iter_nelems,
269-
size_t reduction_nelems,
270-
const char *arg_cp,
271-
char *res_cp,
272-
py::ssize_t iter_arg_offset,
273-
py::ssize_t iter_res_offset,
274-
py::ssize_t red_arg_offset,
275-
const std::vector<sycl::event> &depends)
267+
boolean_reduction_axis1_contig_impl(sycl::queue exec_q,
268+
size_t iter_nelems,
269+
size_t reduction_nelems,
270+
const char *arg_cp,
271+
char *res_cp,
272+
py::ssize_t iter_arg_offset,
273+
py::ssize_t iter_res_offset,
274+
py::ssize_t red_arg_offset,
275+
const std::vector<sycl::event> &depends)
276276
{
277277
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp) +
278278
iter_arg_offset + red_arg_offset;
@@ -315,18 +315,8 @@ boolean_reduction_contig_impl(sycl::queue exec_q,
315315
});
316316
}
317317
else {
318-
sycl::event init_ev = exec_q.submit([&](sycl::handler &cgh) {
319-
using IndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
320-
321-
IndexerT res_indexer{};
322-
323-
cgh.depends_on(depends);
324-
325-
cgh.parallel_for(sycl::range<1>(iter_nelems), [=](sycl::id<1> id) {
326-
auto res_offset = res_indexer(id[0]);
327-
res_tp[res_offset] = identity_val;
328-
});
329-
});
318+
sycl::event init_ev = exec_q.fill<resTy>(res_tp, resTy(identity_val),
319+
iter_nelems, depends);
330320
red_ev = exec_q.submit([&](sycl::handler &cgh) {
331321
cgh.depends_on(init_ev);
332322

@@ -356,7 +346,7 @@ boolean_reduction_contig_impl(sycl::queue exec_q,
356346
return red_ev;
357347
}
358348

359-
template <typename fnT, typename srcTy> struct AllContigFactory
349+
template <typename fnT, typename srcTy> struct AllAxis1ContigFactory
360350
{
361351
fnT get() const
362352
{
@@ -365,12 +355,12 @@ template <typename fnT, typename srcTy> struct AllContigFactory
365355
using GroupOpT =
366356
all_reduce_wg_contig<srcTy, resTy, boolean_predicate<srcTy>>;
367357

368-
return dpctl::tensor::kernels::boolean_reduction_contig_impl<
358+
return dpctl::tensor::kernels::boolean_reduction_axis1_contig_impl<
369359
srcTy, resTy, RedOpT, GroupOpT>;
370360
}
371361
};
372362

373-
template <typename fnT, typename srcTy> struct AnyContigFactory
363+
template <typename fnT, typename srcTy> struct AnyAxis1ContigFactory
374364
{
375365
fnT get() const
376366
{
@@ -379,7 +369,7 @@ template <typename fnT, typename srcTy> struct AnyContigFactory
379369
using GroupOpT =
380370
any_reduce_wg_contig<srcTy, resTy, boolean_predicate<srcTy>>;
381371

382-
return dpctl::tensor::kernels::boolean_reduction_contig_impl<
372+
return dpctl::tensor::kernels::boolean_reduction_axis1_contig_impl<
383373
srcTy, resTy, RedOpT, GroupOpT>;
384374
}
385375
};
@@ -463,6 +453,113 @@ struct StridedBooleanReduction
463453
}
464454
};
465455

456+
// Forward declaration of a class template that is never defined here; it is
// used only as the kernel name passed to cgh.parallel_for<...> inside
// boolean_reduction_axis0_contig_impl. The six type parameters receive, in
// order, the argument type, result type, reduction op, group op,
// input/output iteration indexer type and reduction indexer type at that
// call site.
template <typename T1,
457+
typename T2,
458+
typename T3,
459+
typename T4,
460+
typename T5,
461+
typename T6>
462+
class boolean_reduction_axis0_contig_krn;
463+
464+
// Submits a boolean reduction over axis 0 for contiguous input and returns
// the event of the reduction kernel.
//
// Template parameters:
//   argTy    - element type the input buffer is reinterpreted as
//   resTy    - element type the output buffer is reinterpreted as
//   RedOpT   - reduction operator; also selects the identity value
//   GroupOpT - group-level operator, default-constructed and handed to the
//              StridedBooleanReduction functor
//
// Parameters:
//   exec_q           - queue that receives both the fill and the kernel
//   iter_nelems      - number of result elements (one per reduction)
//   reduction_nelems - number of input elements folded into each result
//   arg_cp           - untyped pointer to the input data
//   res_cp           - untyped pointer to the output data
//   iter_arg_offset,
//   red_arg_offset   - element offsets (in units of argTy) added to arg_cp
//   iter_res_offset  - element offset (in units of resTy) added to res_cp
//   depends          - events the initial fill must wait for
//
// Two operations are enqueued: a fill of the result with the identity of
// RedOpT, then the reduction kernel, which depends on that fill.
template <typename argTy, typename resTy, typename RedOpT, typename GroupOpT>
465+
sycl::event
466+
boolean_reduction_axis0_contig_impl(sycl::queue exec_q,
467+
size_t iter_nelems,
468+
size_t reduction_nelems,
469+
const char *arg_cp,
470+
char *res_cp,
471+
py::ssize_t iter_arg_offset,
472+
py::ssize_t iter_res_offset,
473+
py::ssize_t red_arg_offset,
474+
const std::vector<sycl::event> &depends)
475+
{
476+
// Typed views of the raw buffers, advanced by the caller-supplied offsets.
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp) +
477+
iter_arg_offset + red_arg_offset;
478+
resTy *res_tp = reinterpret_cast<resTy *>(res_cp) + iter_res_offset;
479+
480+
// Identity element of RedOpT over resTy; used both to initialize the result
// and as an argument to the reduction functor.
constexpr resTy identity_val = sycl::known_identity<RedOpT, resTy>::value;
481+
482+
const sycl::device &d = exec_q.get_device();
483+
const auto &sg_sizes = d.get_info<sycl::info::device::sub_group_sizes>();
484+
// Work-group size is derived from the reduction length and the sub-group
// sizes the device reports.
size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes);
485+
486+
{
487+
// Initialize all iter_nelems result elements to the identity once the
// incoming dependencies have completed.
// NOTE(review): presumably required because the reduction functor combines
// partial results into the output in place -- confirm against
// StridedBooleanReduction.
sycl::event init_ev = exec_q.fill<resTy>(res_tp, resTy(identity_val),
488+
iter_nelems, depends);
489+
sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) {
490+
// The reduction must not start before the result has been initialized.
cgh.depends_on(init_ev);
491+
492+
constexpr std::uint8_t dim = 1;
493+
494+
using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
495+
using ColsIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer;
496+
using InputOutputIterIndexerT =
497+
dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
498+
NoOpIndexerT, NoOpIndexerT>;
499+
using ReductionIndexerT = ColsIndexerT;
500+
501+
// Both the input and the result use no-op indexers along the iteration
// dimension.
NoOpIndexerT columns_indexer{};
502+
NoOpIndexerT result_indexer{};
503+
InputOutputIterIndexerT in_out_iter_indexer{columns_indexer,
504+
result_indexer};
505+
// Constructed from (0, reduction_nelems, iter_nelems).
// NOTE(review): presumably (offset, size, step), i.e. consecutive reduction
// elements are iter_nelems apart -- confirm against Strided1DIndexer.
ReductionIndexerT reduction_indexer{
506+
0, static_cast<py::ssize_t>(reduction_nelems),
507+
static_cast<py::ssize_t>(iter_nelems)};
508+
509+
// Each work-item handles up to 4 reduction elements; when fewer than 4 * wg
// elements are reduced, the count is ceil(reduction_nelems / wg) instead.
constexpr size_t preferred_reductions_per_wi = 4;
510+
size_t reductions_per_wi =
511+
(reduction_nelems < preferred_reductions_per_wi * wg)
512+
? ((reduction_nelems + wg - 1) / wg)
513+
: preferred_reductions_per_wi;
514+
515+
// Number of work-groups per reduction:
// ceil(reduction_nelems / (reductions_per_wi * wg)).
size_t reduction_groups =
516+
(reduction_nelems + reductions_per_wi * wg - 1) /
517+
(reductions_per_wi * wg);
518+
519+
// 1-D launch: reduction_groups work-groups of size wg for each of the
// iter_nelems results.
auto gws = sycl::range<dim>{iter_nelems * reduction_groups * wg};
520+
auto lws = sycl::range<dim>{wg};
521+
522+
// Kernel body is the StridedBooleanReduction functor; the kernel name is
// specialized on the same types so each instantiation gets a unique name.
cgh.parallel_for<class boolean_reduction_axis0_contig_krn<
523+
argTy, resTy, RedOpT, GroupOpT, InputOutputIterIndexerT,
524+
ReductionIndexerT>>(
525+
sycl::nd_range<dim>(gws, lws),
526+
StridedBooleanReduction<argTy, resTy, RedOpT, GroupOpT,
527+
InputOutputIterIndexerT,
528+
ReductionIndexerT>(
529+
arg_tp, res_tp, RedOpT(), GroupOpT(), identity_val,
530+
in_out_iter_indexer, reduction_indexer, reduction_nelems,
531+
iter_nelems, reductions_per_wi));
532+
});
533+
// Callers synchronize on the reduction event; the fill is ordered before it
// through the depends_on above.
return red_ev;
534+
}
535+
}
536+
537+
// Factory whose get() yields the axis-0 contiguous implementation for an
// "all"-style reduction over inputs of type srcTy. fnT is the function
// pointer type the returned function is converted to.
template <typename fnT, typename srcTy> struct AllAxis0ContigFactory
538+
{
539+
// Returns boolean_reduction_axis0_contig_impl instantiated with a 32-bit
// integer result type, logical-and as the reduction operator and
// all_reduce_wg_strided as the group operator.
fnT get() const
540+
{
541+
using resTy = std::int32_t;
542+
using RedOpT = sycl::logical_and<resTy>;
543+
using GroupOpT = all_reduce_wg_strided<resTy>;
544+
545+
return dpctl::tensor::kernels::boolean_reduction_axis0_contig_impl<
546+
srcTy, resTy, RedOpT, GroupOpT>;
547+
}
548+
};
549+
550+
// Factory whose get() yields the axis-0 contiguous implementation for an
// "any"-style reduction over inputs of type srcTy. fnT is the function
// pointer type the returned function is converted to.
template <typename fnT, typename srcTy> struct AnyAxis0ContigFactory
551+
{
552+
// Returns boolean_reduction_axis0_contig_impl instantiated with a 32-bit
// integer result type, logical-or as the reduction operator and
// any_reduce_wg_strided as the group operator.
fnT get() const
553+
{
554+
using resTy = std::int32_t;
555+
using RedOpT = sycl::logical_or<resTy>;
556+
using GroupOpT = any_reduce_wg_strided<resTy>;
557+
558+
return dpctl::tensor::kernels::boolean_reduction_axis0_contig_impl<
559+
srcTy, resTy, RedOpT, GroupOpT>;
560+
}
561+
};
562+
466563
template <typename T1,
467564
typename T2,
468565
typename T3,
@@ -542,7 +639,7 @@ boolean_reduction_strided_impl(sycl::queue exec_q,
542639
});
543640
}
544641
else {
545-
sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) {
642+
sycl::event init_ev = exec_q.submit([&](sycl::handler &cgh) {
546643
using IndexerT =
547644
dpctl::tensor::offset_utils::UnpackedStridedIndexer;
548645

@@ -560,7 +657,7 @@ boolean_reduction_strided_impl(sycl::queue exec_q,
560657
});
561658
});
562659
red_ev = exec_q.submit([&](sycl::handler &cgh) {
563-
cgh.depends_on(res_init_ev);
660+
cgh.depends_on(init_ev);
564661

565662
constexpr std::uint8_t dim = 1;
566663

dpctl/tensor/libtensor/source/boolean_reductions.cpp

Lines changed: 42 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,9 @@ using dpctl::tensor::kernels::boolean_reduction_strided_impl_fn_ptr;
5858
static boolean_reduction_strided_impl_fn_ptr
5959
all_reduction_strided_dispatch_vector[td_ns::num_types];
6060
static boolean_reduction_contig_impl_fn_ptr
61-
all_reduction_contig_dispatch_vector[td_ns::num_types];
61+
all_reduction_axis1_contig_dispatch_vector[td_ns::num_types];
62+
static boolean_reduction_contig_impl_fn_ptr
63+
all_reduction_axis0_contig_dispatch_vector[td_ns::num_types];
6264

6365
void populate_all_dispatch_vectors(void)
6466
{
@@ -74,11 +76,19 @@ void populate_all_dispatch_vectors(void)
7476

7577
using dpctl::tensor::kernels::boolean_reduction_contig_impl_fn_ptr;
7678

77-
using dpctl::tensor::kernels::AllContigFactory;
79+
using dpctl::tensor::kernels::AllAxis1ContigFactory;
7880
DispatchVectorBuilder<boolean_reduction_contig_impl_fn_ptr,
79-
AllContigFactory, td_ns::num_types>
81+
AllAxis1ContigFactory, td_ns::num_types>
8082
all_dvb2;
81-
all_dvb2.populate_dispatch_vector(all_reduction_contig_dispatch_vector);
83+
all_dvb2.populate_dispatch_vector(
84+
all_reduction_axis1_contig_dispatch_vector);
85+
86+
using dpctl::tensor::kernels::AllAxis0ContigFactory;
87+
DispatchVectorBuilder<boolean_reduction_contig_impl_fn_ptr,
88+
AllAxis0ContigFactory, td_ns::num_types>
89+
all_dvb3;
90+
all_dvb3.populate_dispatch_vector(
91+
all_reduction_axis0_contig_dispatch_vector);
8292
};
8393

8494
} // namespace impl
@@ -91,7 +101,9 @@ static boolean_reduction_strided_impl_fn_ptr
91101
any_reduction_strided_dispatch_vector[td_ns::num_types];
92102
using dpctl::tensor::kernels::boolean_reduction_contig_impl_fn_ptr;
93103
static boolean_reduction_contig_impl_fn_ptr
94-
any_reduction_contig_dispatch_vector[td_ns::num_types];
104+
any_reduction_axis1_contig_dispatch_vector[td_ns::num_types];
105+
static boolean_reduction_contig_impl_fn_ptr
106+
any_reduction_axis0_contig_dispatch_vector[td_ns::num_types];
95107

96108
void populate_any_dispatch_vectors(void)
97109
{
@@ -107,11 +119,19 @@ void populate_any_dispatch_vectors(void)
107119

108120
using dpctl::tensor::kernels::boolean_reduction_contig_impl_fn_ptr;
109121

110-
using dpctl::tensor::kernels::AnyContigFactory;
122+
using dpctl::tensor::kernels::AnyAxis1ContigFactory;
111123
DispatchVectorBuilder<boolean_reduction_contig_impl_fn_ptr,
112-
AnyContigFactory, td_ns::num_types>
124+
AnyAxis1ContigFactory, td_ns::num_types>
113125
any_dvb2;
114-
any_dvb2.populate_dispatch_vector(any_reduction_contig_dispatch_vector);
126+
any_dvb2.populate_dispatch_vector(
127+
any_reduction_axis1_contig_dispatch_vector);
128+
129+
using dpctl::tensor::kernels::AnyAxis0ContigFactory;
130+
DispatchVectorBuilder<boolean_reduction_contig_impl_fn_ptr,
131+
AnyAxis0ContigFactory, td_ns::num_types>
132+
any_dvb3;
133+
any_dvb3.populate_dispatch_vector(
134+
any_reduction_axis0_contig_dispatch_vector);
115135
};
116136

117137
} // namespace impl
@@ -124,16 +144,18 @@ void init_boolean_reduction_functions(py::module_ m)
124144
// ALL
125145
{
126146
impl::populate_all_dispatch_vectors();
127-
using impl::all_reduction_contig_dispatch_vector;
147+
using impl::all_reduction_axis0_contig_dispatch_vector;
148+
using impl::all_reduction_axis1_contig_dispatch_vector;
128149
using impl::all_reduction_strided_dispatch_vector;
129150

130151
auto all_pyapi = [&](arrayT src, int trailing_dims_to_reduce,
131152
arrayT dst, sycl::queue exec_q,
132153
const event_vecT &depends = {}) {
133-
return py_boolean_reduction(src, trailing_dims_to_reduce, dst,
134-
exec_q, depends,
135-
all_reduction_contig_dispatch_vector,
136-
all_reduction_strided_dispatch_vector);
154+
return py_boolean_reduction(
155+
src, trailing_dims_to_reduce, dst, exec_q, depends,
156+
all_reduction_axis1_contig_dispatch_vector,
157+
all_reduction_axis0_contig_dispatch_vector,
158+
all_reduction_strided_dispatch_vector);
137159
};
138160
m.def("_all", all_pyapi, "", py::arg("src"),
139161
py::arg("trailing_dims_to_reduce"), py::arg("dst"),
@@ -143,16 +165,18 @@ void init_boolean_reduction_functions(py::module_ m)
143165
// ANY
144166
{
145167
impl::populate_any_dispatch_vectors();
146-
using impl::any_reduction_contig_dispatch_vector;
168+
using impl::any_reduction_axis0_contig_dispatch_vector;
169+
using impl::any_reduction_axis1_contig_dispatch_vector;
147170
using impl::any_reduction_strided_dispatch_vector;
148171

149172
auto any_pyapi = [&](arrayT src, int trailing_dims_to_reduce,
150173
arrayT dst, sycl::queue exec_q,
151174
const event_vecT &depends = {}) {
152-
return py_boolean_reduction(src, trailing_dims_to_reduce, dst,
153-
exec_q, depends,
154-
any_reduction_contig_dispatch_vector,
155-
any_reduction_strided_dispatch_vector);
175+
return py_boolean_reduction(
176+
src, trailing_dims_to_reduce, dst, exec_q, depends,
177+
any_reduction_axis1_contig_dispatch_vector,
178+
any_reduction_axis0_contig_dispatch_vector,
179+
any_reduction_strided_dispatch_vector);
156180
};
157181
m.def("_any", any_pyapi, "", py::arg("src"),
158182
py::arg("trailing_dims_to_reduce"), py::arg("dst"),

0 commit comments

Comments
 (0)