Skip to content

Commit e5d20db

Browse files
committed
Adds SequentialSearchReduction functor to search reductions
1 parent 421b270 commit e5d20db

File tree

1 file changed

+235
-0
lines changed

1 file changed

+235
-0
lines changed

dpctl/tensor/libtensor/include/kernels/reductions.hpp

Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3400,6 +3400,125 @@ struct LogSumExpOverAxis0TempsContigFactory
34003400

34013401
// Argmax and Argmin
34023402

3403+
/* Sequential search reduction */
3404+
3405+
template <typename argT,
3406+
typename outT,
3407+
typename ReductionOp,
3408+
typename IdxReductionOp,
3409+
typename InputOutputIterIndexerT,
3410+
typename InputRedIndexerT>
3411+
struct SequentialSearchReduction
3412+
{
3413+
private:
3414+
const argT *inp_ = nullptr;
3415+
outT *out_ = nullptr;
3416+
ReductionOp reduction_op_;
3417+
argT identity_;
3418+
IdxReductionOp idx_reduction_op_;
3419+
outT idx_identity_;
3420+
InputOutputIterIndexerT inp_out_iter_indexer_;
3421+
InputRedIndexerT inp_reduced_dims_indexer_;
3422+
size_t reduction_max_gid_ = 0;
3423+
3424+
public:
3425+
SequentialSearchReduction(const argT *inp,
3426+
outT *res,
3427+
ReductionOp reduction_op,
3428+
const argT &identity_val,
3429+
IdxReductionOp idx_reduction_op,
3430+
const outT &idx_identity_val,
3431+
InputOutputIterIndexerT arg_res_iter_indexer,
3432+
InputRedIndexerT arg_reduced_dims_indexer,
3433+
size_t reduction_size)
3434+
: inp_(inp), out_(res), reduction_op_(reduction_op),
3435+
identity_(identity_val), idx_reduction_op_(idx_reduction_op),
3436+
idx_identity_(idx_identity_val),
3437+
inp_out_iter_indexer_(arg_res_iter_indexer),
3438+
inp_reduced_dims_indexer_(arg_reduced_dims_indexer),
3439+
reduction_max_gid_(reduction_size)
3440+
{
3441+
}
3442+
3443+
void operator()(sycl::id<1> id) const
3444+
{
3445+
3446+
auto const &inp_out_iter_offsets_ = inp_out_iter_indexer_(id[0]);
3447+
const py::ssize_t &inp_iter_offset =
3448+
inp_out_iter_offsets_.get_first_offset();
3449+
const py::ssize_t &out_iter_offset =
3450+
inp_out_iter_offsets_.get_second_offset();
3451+
3452+
argT red_val(identity_);
3453+
outT idx_val(idx_identity_);
3454+
for (size_t m = 0; m < reduction_max_gid_; ++m) {
3455+
const py::ssize_t inp_reduction_offset =
3456+
inp_reduced_dims_indexer_(m);
3457+
const py::ssize_t inp_offset =
3458+
inp_iter_offset + inp_reduction_offset;
3459+
3460+
argT val = inp_[inp_offset];
3461+
if (val == red_val) {
3462+
idx_val = idx_reduction_op_(idx_val, static_cast<outT>(m));
3463+
}
3464+
else {
3465+
if constexpr (su_ns::IsMinimum<argT, ReductionOp>::value) {
3466+
using dpctl::tensor::type_utils::is_complex;
3467+
if constexpr (is_complex<argT>::value) {
3468+
using dpctl::tensor::math_utils::less_complex;
3469+
// less_complex always returns false for NaNs, so check
3470+
if (less_complex<argT>(val, red_val) ||
3471+
std::isnan(std::real(val)) ||
3472+
std::isnan(std::imag(val)))
3473+
{
3474+
red_val = val;
3475+
idx_val = static_cast<outT>(m);
3476+
}
3477+
}
3478+
else if constexpr (std::is_floating_point_v<argT>) {
3479+
if (val < red_val || std::isnan(val)) {
3480+
red_val = val;
3481+
idx_val = static_cast<outT>(m);
3482+
}
3483+
}
3484+
else {
3485+
if (val < red_val) {
3486+
red_val = val;
3487+
idx_val = static_cast<outT>(m);
3488+
}
3489+
}
3490+
}
3491+
else if constexpr (su_ns::IsMaximum<argT, ReductionOp>::value) {
3492+
using dpctl::tensor::type_utils::is_complex;
3493+
if constexpr (is_complex<argT>::value) {
3494+
using dpctl::tensor::math_utils::greater_complex;
3495+
if (greater_complex<argT>(val, red_val) ||
3496+
std::isnan(std::real(val)) ||
3497+
std::isnan(std::imag(val)))
3498+
{
3499+
red_val = val;
3500+
idx_val = static_cast<outT>(m);
3501+
}
3502+
}
3503+
else if constexpr (std::is_floating_point_v<argT>) {
3504+
if (val > red_val || std::isnan(val)) {
3505+
red_val = val;
3506+
idx_val = static_cast<outT>(m);
3507+
}
3508+
}
3509+
else {
3510+
if (val > red_val) {
3511+
red_val = val;
3512+
idx_val = static_cast<outT>(m);
3513+
}
3514+
}
3515+
}
3516+
}
3517+
}
3518+
out_[out_iter_offset] = idx_val;
3519+
}
3520+
};
3521+
34033522
/* Search reduction using reduce_over_group */
34043523

34053524
template <typename argT,
@@ -3798,6 +3917,14 @@ typedef sycl::event (*search_strided_impl_fn_ptr)(
37983917
py::ssize_t,
37993918
const std::vector<sycl::event> &);
38003919

3920+
// Kernel-name type for the SequentialSearchReduction kernel submitted with
// strided indexers. It is only forward-declared; it is instantiated with the
// argument type, result type, reduction operator, index operator,
// input/output iteration indexer and reduction indexer types.
template <typename T1,
3921+
typename T2,
3922+
typename T3,
3923+
typename T4,
3924+
typename T5,
3925+
typename T6>
3926+
class search_seq_strided_krn;
3927+
38013928
template <typename T1,
38023929
typename T2,
38033930
typename T3,
@@ -3819,6 +3946,14 @@ template <typename T1,
38193946
bool b2>
38203947
class custom_search_over_group_temps_strided_krn;
38213948

3949+
// Kernel-name type for the SequentialSearchReduction kernels submitted by the
// contiguous (axis1 and axis0) implementations. It is only forward-declared;
// the differing indexer types passed as template arguments keep the two
// instantiations distinct.
template <typename T1,
3950+
typename T2,
3951+
typename T3,
3952+
typename T4,
3953+
typename T5,
3954+
typename T6>
3955+
class search_seq_contig_krn;
3956+
38223957
template <typename T1,
38233958
typename T2,
38243959
typename T3,
@@ -4018,6 +4153,36 @@ sycl::event search_over_group_temps_strided_impl(
40184153
const auto &sg_sizes = d.get_info<sycl::info::device::sub_group_sizes>();
40194154
size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes);
40204155

4156+
if (reduction_nelems < wg) {
4157+
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
4158+
cgh.depends_on(depends);
4159+
4160+
using InputOutputIterIndexerT =
4161+
dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
4162+
using ReductionIndexerT =
4163+
dpctl::tensor::offset_utils::StridedIndexer;
4164+
4165+
InputOutputIterIndexerT in_out_iter_indexer{
4166+
iter_nd, iter_arg_offset, iter_res_offset,
4167+
iter_shape_and_strides};
4168+
ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset,
4169+
reduction_shape_stride};
4170+
4171+
cgh.parallel_for<class search_seq_strided_krn<
4172+
argTy, resTy, ReductionOpT, IndexOpT, InputOutputIterIndexerT,
4173+
ReductionIndexerT>>(
4174+
sycl::range<1>(iter_nelems),
4175+
SequentialSearchReduction<argTy, resTy, ReductionOpT, IndexOpT,
4176+
InputOutputIterIndexerT,
4177+
ReductionIndexerT>(
4178+
arg_tp, res_tp, ReductionOpT(), identity_val, IndexOpT(),
4179+
idx_identity_val, in_out_iter_indexer, reduction_indexer,
4180+
reduction_nelems));
4181+
});
4182+
4183+
return comp_ev;
4184+
}
4185+
40214186
constexpr size_t preferred_reductions_per_wi = 4;
40224187
// max_max_wg prevents running out of resources on CPU
40234188
size_t max_wg =
@@ -4418,6 +4583,39 @@ sycl::event search_axis1_over_group_temps_contig_impl(
44184583
const auto &sg_sizes = d.get_info<sycl::info::device::sub_group_sizes>();
44194584
size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes);
44204585

4586+
if (reduction_nelems < wg) {
4587+
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
4588+
cgh.depends_on(depends);
4589+
4590+
using InputIterIndexerT =
4591+
dpctl::tensor::offset_utils::Strided1DIndexer;
4592+
using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
4593+
using InputOutputIterIndexerT =
4594+
dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
4595+
InputIterIndexerT, NoOpIndexerT>;
4596+
using ReductionIndexerT = NoOpIndexerT;
4597+
4598+
InputOutputIterIndexerT in_out_iter_indexer{
4599+
InputIterIndexerT{0, static_cast<py::ssize_t>(iter_nelems),
4600+
static_cast<py::ssize_t>(reduction_nelems)},
4601+
NoOpIndexerT{}};
4602+
ReductionIndexerT reduction_indexer{};
4603+
4604+
cgh.parallel_for<class search_seq_contig_krn<
4605+
argTy, resTy, ReductionOpT, IndexOpT, InputOutputIterIndexerT,
4606+
ReductionIndexerT>>(
4607+
sycl::range<1>(iter_nelems),
4608+
SequentialSearchReduction<argTy, resTy, ReductionOpT, IndexOpT,
4609+
InputOutputIterIndexerT,
4610+
ReductionIndexerT>(
4611+
arg_tp, res_tp, ReductionOpT(), identity_val, IndexOpT(),
4612+
idx_identity_val, in_out_iter_indexer, reduction_indexer,
4613+
reduction_nelems));
4614+
});
4615+
4616+
return comp_ev;
4617+
}
4618+
44214619
constexpr size_t preferred_reductions_per_wi = 8;
44224620
// max_max_wg prevents running out of resources on CPU
44234621
size_t max_wg =
@@ -4800,6 +4998,43 @@ sycl::event search_axis0_over_group_temps_contig_impl(
48004998
const auto &sg_sizes = d.get_info<sycl::info::device::sub_group_sizes>();
48014999
size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes);
48025000

5001+
if (reduction_nelems < wg) {
5002+
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
5003+
cgh.depends_on(depends);
5004+
5005+
using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
5006+
using InputOutputIterIndexerT =
5007+
dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
5008+
NoOpIndexerT, NoOpIndexerT>;
5009+
using ReductionIndexerT =
5010+
dpctl::tensor::offset_utils::Strided1DIndexer;
5011+
5012+
InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{},
5013+
NoOpIndexerT{}};
5014+
ReductionIndexerT reduction_indexer{
5015+
0, static_cast<py::ssize_t>(reduction_nelems),
5016+
static_cast<py::ssize_t>(iter_nelems)};
5017+
5018+
using KernelName =
5019+
class search_seq_contig_krn<argTy, resTy, ReductionOpT,
5020+
IndexOpT, InputOutputIterIndexerT,
5021+
ReductionIndexerT>;
5022+
5023+
sycl::range<1> iter_range{iter_nelems};
5024+
5025+
cgh.parallel_for<KernelName>(
5026+
iter_range,
5027+
SequentialSearchReduction<argTy, resTy, ReductionOpT, IndexOpT,
5028+
InputOutputIterIndexerT,
5029+
ReductionIndexerT>(
5030+
arg_tp, res_tp, ReductionOpT(), identity_val, IndexOpT(),
5031+
idx_identity_val, in_out_iter_indexer, reduction_indexer,
5032+
reduction_nelems));
5033+
});
5034+
5035+
return comp_ev;
5036+
}
5037+
48035038
constexpr size_t preferred_reductions_per_wi = 8;
48045039
// max_max_wg prevents running out of resources on CPU
48055040
size_t max_wg =

0 commit comments

Comments
 (0)