Skip to content

Commit 5709f99

Browse files
committed
Adds SequentialSearchReduction functor to search reductions
1 parent 097ecf5 commit 5709f99

File tree

1 file changed

+235
-0
lines changed

1 file changed

+235
-0
lines changed

dpctl/tensor/libtensor/include/kernels/reductions.hpp

Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3401,6 +3401,125 @@ struct LogSumExpOverAxis0TempsContigFactory
34013401

34023402
// Argmax and Argmin
34033403

3404+
/* Sequential search reduction */
3405+
3406+
template <typename argT,
3407+
typename outT,
3408+
typename ReductionOp,
3409+
typename IdxReductionOp,
3410+
typename InputOutputIterIndexerT,
3411+
typename InputRedIndexerT>
3412+
struct SequentialSearchReduction
3413+
{
3414+
private:
3415+
const argT *inp_ = nullptr;
3416+
outT *out_ = nullptr;
3417+
ReductionOp reduction_op_;
3418+
argT identity_;
3419+
IdxReductionOp idx_reduction_op_;
3420+
outT idx_identity_;
3421+
InputOutputIterIndexerT inp_out_iter_indexer_;
3422+
InputRedIndexerT inp_reduced_dims_indexer_;
3423+
size_t reduction_max_gid_ = 0;
3424+
3425+
public:
3426+
SequentialSearchReduction(const argT *inp,
3427+
outT *res,
3428+
ReductionOp reduction_op,
3429+
const argT &identity_val,
3430+
IdxReductionOp idx_reduction_op,
3431+
const outT &idx_identity_val,
3432+
InputOutputIterIndexerT arg_res_iter_indexer,
3433+
InputRedIndexerT arg_reduced_dims_indexer,
3434+
size_t reduction_size)
3435+
: inp_(inp), out_(res), reduction_op_(reduction_op),
3436+
identity_(identity_val), idx_reduction_op_(idx_reduction_op),
3437+
idx_identity_(idx_identity_val),
3438+
inp_out_iter_indexer_(arg_res_iter_indexer),
3439+
inp_reduced_dims_indexer_(arg_reduced_dims_indexer),
3440+
reduction_max_gid_(reduction_size)
3441+
{
3442+
}
3443+
3444+
void operator()(sycl::id<1> id) const
3445+
{
3446+
3447+
auto const &inp_out_iter_offsets_ = inp_out_iter_indexer_(id[0]);
3448+
const py::ssize_t &inp_iter_offset =
3449+
inp_out_iter_offsets_.get_first_offset();
3450+
const py::ssize_t &out_iter_offset =
3451+
inp_out_iter_offsets_.get_second_offset();
3452+
3453+
argT red_val(identity_);
3454+
outT idx_val(idx_identity_);
3455+
for (size_t m = 0; m < reduction_max_gid_; ++m) {
3456+
const py::ssize_t inp_reduction_offset =
3457+
inp_reduced_dims_indexer_(m);
3458+
const py::ssize_t inp_offset =
3459+
inp_iter_offset + inp_reduction_offset;
3460+
3461+
argT val = inp_[inp_offset];
3462+
if (val == red_val) {
3463+
idx_val = idx_reduction_op_(idx_val, static_cast<outT>(m));
3464+
}
3465+
else {
3466+
if constexpr (su_ns::IsMinimum<argT, ReductionOp>::value) {
3467+
using dpctl::tensor::type_utils::is_complex;
3468+
if constexpr (is_complex<argT>::value) {
3469+
using dpctl::tensor::math_utils::less_complex;
3470+
// less_complex always returns false for NaNs, so check
3471+
if (less_complex<argT>(val, red_val) ||
3472+
std::isnan(std::real(val)) ||
3473+
std::isnan(std::imag(val)))
3474+
{
3475+
red_val = val;
3476+
idx_val = static_cast<outT>(m);
3477+
}
3478+
}
3479+
else if constexpr (std::is_floating_point_v<argT>) {
3480+
if (val < red_val || std::isnan(val)) {
3481+
red_val = val;
3482+
idx_val = static_cast<outT>(m);
3483+
}
3484+
}
3485+
else {
3486+
if (val < red_val) {
3487+
red_val = val;
3488+
idx_val = static_cast<outT>(m);
3489+
}
3490+
}
3491+
}
3492+
else if constexpr (su_ns::IsMaximum<argT, ReductionOp>::value) {
3493+
using dpctl::tensor::type_utils::is_complex;
3494+
if constexpr (is_complex<argT>::value) {
3495+
using dpctl::tensor::math_utils::greater_complex;
3496+
if (greater_complex<argT>(val, red_val) ||
3497+
std::isnan(std::real(val)) ||
3498+
std::isnan(std::imag(val)))
3499+
{
3500+
red_val = val;
3501+
idx_val = static_cast<outT>(m);
3502+
}
3503+
}
3504+
else if constexpr (std::is_floating_point_v<argT>) {
3505+
if (val > red_val || std::isnan(val)) {
3506+
red_val = val;
3507+
idx_val = static_cast<outT>(m);
3508+
}
3509+
}
3510+
else {
3511+
if (val > red_val) {
3512+
red_val = val;
3513+
idx_val = static_cast<outT>(m);
3514+
}
3515+
}
3516+
}
3517+
}
3518+
}
3519+
out_[out_iter_offset] = idx_val;
3520+
}
3521+
};
3522+
34043523
/* = Search reduction using reduce_over_group*/
34053524

34063525
template <typename argT,
@@ -3799,6 +3918,14 @@ typedef sycl::event (*search_strided_impl_fn_ptr)(
37993918
py::ssize_t,
38003919
const std::vector<sycl::event> &);
38013920

3921+
template <typename T1,
3922+
typename T2,
3923+
typename T3,
3924+
typename T4,
3925+
typename T5,
3926+
typename T6>
3927+
class search_seq_strided_krn;
3928+
38023929
template <typename T1,
38033930
typename T2,
38043931
typename T3,
@@ -3820,6 +3947,14 @@ template <typename T1,
38203947
bool b2>
38213948
class custom_search_over_group_temps_strided_krn;
38223949

3950+
template <typename T1,
3951+
typename T2,
3952+
typename T3,
3953+
typename T4,
3954+
typename T5,
3955+
typename T6>
3956+
class search_seq_contig_krn;
3957+
38233958
template <typename T1,
38243959
typename T2,
38253960
typename T3,
@@ -4019,6 +4154,36 @@ sycl::event search_over_group_temps_strided_impl(
40194154
const auto &sg_sizes = d.get_info<sycl::info::device::sub_group_sizes>();
40204155
size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes);
40214156

4157+
if (reduction_nelems < wg) {
4158+
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
4159+
cgh.depends_on(depends);
4160+
4161+
using InputOutputIterIndexerT =
4162+
dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
4163+
using ReductionIndexerT =
4164+
dpctl::tensor::offset_utils::StridedIndexer;
4165+
4166+
InputOutputIterIndexerT in_out_iter_indexer{
4167+
iter_nd, iter_arg_offset, iter_res_offset,
4168+
iter_shape_and_strides};
4169+
ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset,
4170+
reduction_shape_stride};
4171+
4172+
cgh.parallel_for<class search_seq_strided_krn<
4173+
argTy, resTy, ReductionOpT, IndexOpT, InputOutputIterIndexerT,
4174+
ReductionIndexerT>>(
4175+
sycl::range<1>(iter_nelems),
4176+
SequentialSearchReduction<argTy, resTy, ReductionOpT, IndexOpT,
4177+
InputOutputIterIndexerT,
4178+
ReductionIndexerT>(
4179+
arg_tp, res_tp, ReductionOpT(), identity_val, IndexOpT(),
4180+
idx_identity_val, in_out_iter_indexer, reduction_indexer,
4181+
reduction_nelems));
4182+
});
4183+
4184+
return comp_ev;
4185+
}
4186+
40224187
constexpr size_t preferred_reductions_per_wi = 4;
40234188
// max_max_wg prevents running out of resources on CPU
40244189
size_t max_wg =
@@ -4419,6 +4584,39 @@ sycl::event search_axis1_over_group_temps_contig_impl(
44194584
const auto &sg_sizes = d.get_info<sycl::info::device::sub_group_sizes>();
44204585
size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes);
44214586

4587+
if (reduction_nelems < wg) {
4588+
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
4589+
cgh.depends_on(depends);
4590+
4591+
using InputIterIndexerT =
4592+
dpctl::tensor::offset_utils::Strided1DIndexer;
4593+
using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
4594+
using InputOutputIterIndexerT =
4595+
dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
4596+
InputIterIndexerT, NoOpIndexerT>;
4597+
using ReductionIndexerT = NoOpIndexerT;
4598+
4599+
InputOutputIterIndexerT in_out_iter_indexer{
4600+
InputIterIndexerT{0, static_cast<py::ssize_t>(iter_nelems),
4601+
static_cast<py::ssize_t>(reduction_nelems)},
4602+
NoOpIndexerT{}};
4603+
ReductionIndexerT reduction_indexer{};
4604+
4605+
cgh.parallel_for<class search_seq_contig_krn<
4606+
argTy, resTy, ReductionOpT, IndexOpT, InputOutputIterIndexerT,
4607+
ReductionIndexerT>>(
4608+
sycl::range<1>(iter_nelems),
4609+
SequentialSearchReduction<argTy, resTy, ReductionOpT, IndexOpT,
4610+
InputOutputIterIndexerT,
4611+
ReductionIndexerT>(
4612+
arg_tp, res_tp, ReductionOpT(), identity_val, IndexOpT(),
4613+
idx_identity_val, in_out_iter_indexer, reduction_indexer,
4614+
reduction_nelems));
4615+
});
4616+
4617+
return comp_ev;
4618+
}
4619+
44224620
constexpr size_t preferred_reductions_per_wi = 8;
44234621
// max_max_wg prevents running out of resources on CPU
44244622
size_t max_wg =
@@ -4801,6 +4999,43 @@ sycl::event search_axis0_over_group_temps_contig_impl(
48014999
const auto &sg_sizes = d.get_info<sycl::info::device::sub_group_sizes>();
48025000
size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes);
48035001

5002+
if (reduction_nelems < wg) {
5003+
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
5004+
cgh.depends_on(depends);
5005+
5006+
using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
5007+
using InputOutputIterIndexerT =
5008+
dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
5009+
NoOpIndexerT, NoOpIndexerT>;
5010+
using ReductionIndexerT =
5011+
dpctl::tensor::offset_utils::Strided1DIndexer;
5012+
5013+
InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{},
5014+
NoOpIndexerT{}};
5015+
ReductionIndexerT reduction_indexer{
5016+
0, static_cast<py::ssize_t>(reduction_nelems),
5017+
static_cast<py::ssize_t>(iter_nelems)};
5018+
5019+
using KernelName =
5020+
class search_seq_contig_krn<argTy, resTy, ReductionOpT,
5021+
IndexOpT, InputOutputIterIndexerT,
5022+
ReductionIndexerT>;
5023+
5024+
sycl::range<1> iter_range{iter_nelems};
5025+
5026+
cgh.parallel_for<KernelName>(
5027+
iter_range,
5028+
SequentialSearchReduction<argTy, resTy, ReductionOpT, IndexOpT,
5029+
InputOutputIterIndexerT,
5030+
ReductionIndexerT>(
5031+
arg_tp, res_tp, ReductionOpT(), identity_val, IndexOpT(),
5032+
idx_identity_val, in_out_iter_indexer, reduction_indexer,
5033+
reduction_nelems));
5034+
});
5035+
5036+
return comp_ev;
5037+
}
5038+
48045039
constexpr size_t preferred_reductions_per_wi = 8;
48055040
// max_max_wg prevents running out of resources on CPU
48065041
size_t max_wg =

0 commit comments

Comments
 (0)