@@ -3401,6 +3401,125 @@ struct LogSumExpOverAxis0TempsContigFactory

 // Argmax and Argmin

+/* Sequential search reduction */
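+//
+// One work-item per output element: each work-item walks its reduction
+// axis sequentially, tracking the running extremum and its position.
+// Ties (val == red_val) are combined with IdxReductionOp (e.g. a minimum,
+// so the smallest index wins), and NaNs win comparisons so that they
+// propagate, matching the group-reduction search kernels below.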
+
+template <typename argT,
+          typename outT,
+          typename ReductionOp,
+          typename IdxReductionOp,
+          typename InputOutputIterIndexerT,
+          typename InputRedIndexerT>
+struct SequentialSearchReduction
+{
+private:
+    const argT *inp_ = nullptr;
+    outT *out_ = nullptr;
+    ReductionOp reduction_op_;
+    argT identity_;
+    IdxReductionOp idx_reduction_op_;
+    outT idx_identity_;
+    InputOutputIterIndexerT inp_out_iter_indexer_;
+    InputRedIndexerT inp_reduced_dims_indexer_;
+    size_t reduction_max_gid_ = 0;
+
+public:
+    SequentialSearchReduction(const argT *inp,
+                              outT *res,
+                              ReductionOp reduction_op,
+                              const argT &identity_val,
+                              IdxReductionOp idx_reduction_op,
+                              const outT &idx_identity_val,
+                              InputOutputIterIndexerT arg_res_iter_indexer,
+                              InputRedIndexerT arg_reduced_dims_indexer,
+                              size_t reduction_size)
+        : inp_(inp), out_(res), reduction_op_(reduction_op),
+          identity_(identity_val), idx_reduction_op_(idx_reduction_op),
+          idx_identity_(idx_identity_val),
+          inp_out_iter_indexer_(arg_res_iter_indexer),
+          inp_reduced_dims_indexer_(arg_reduced_dims_indexer),
+          reduction_max_gid_(reduction_size)
+    {
+    }
+
+    void operator()(sycl::id<1> id) const
+    {
+
+        auto const &inp_out_iter_offsets_ = inp_out_iter_indexer_(id[0]);
+        const py::ssize_t &inp_iter_offset =
+            inp_out_iter_offsets_.get_first_offset();
+        const py::ssize_t &out_iter_offset =
+            inp_out_iter_offsets_.get_second_offset();
+
+        argT red_val(identity_);
+        outT idx_val(idx_identity_);
+        for (size_t m = 0; m < reduction_max_gid_; ++m) {
+            const py::ssize_t inp_reduction_offset =
+                inp_reduced_dims_indexer_(m);
+            const py::ssize_t inp_offset =
+                inp_iter_offset + inp_reduction_offset;
+
+            argT val = inp_[inp_offset];
+            if (val == red_val) {
+                idx_val = idx_reduction_op_(idx_val, static_cast<outT>(m));
+            }
+            else {
+                if constexpr (su_ns::IsMinimum<argT, ReductionOp>::value) {
+                    using dpctl::tensor::type_utils::is_complex;
+                    if constexpr (is_complex<argT>::value) {
+                        using dpctl::tensor::math_utils::less_complex;
+                        // less_complex always returns false for NaNs, so check
+                        if (less_complex<argT>(val, red_val) ||
+                            std::isnan(std::real(val)) ||
+                            std::isnan(std::imag(val)))
+                        {
+                            red_val = val;
+                            idx_val = static_cast<outT>(m);
+                        }
+                    }
+                    else if constexpr (std::is_floating_point_v<argT>) {
+                        if (val < red_val || std::isnan(val)) {
+                            red_val = val;
+                            idx_val = static_cast<outT>(m);
+                        }
+                    }
+                    else {
+                        if (val < red_val) {
+                            red_val = val;
+                            idx_val = static_cast<outT>(m);
+                        }
+                    }
+                }
+                else if constexpr (su_ns::IsMaximum<argT, ReductionOp>::value) {
+                    using dpctl::tensor::type_utils::is_complex;
+                    if constexpr (is_complex<argT>::value) {
+                        using dpctl::tensor::math_utils::greater_complex;
+                        if (greater_complex<argT>(val, red_val) ||
+                            std::isnan(std::real(val)) ||
+                            std::isnan(std::imag(val)))
+                        {
+                            red_val = val;
+                            idx_val = static_cast<outT>(m);
+                        }
+                    }
+                    else if constexpr (std::is_floating_point_v<argT>) {
+                        if (val > red_val || std::isnan(val)) {
+                            red_val = val;
+                            idx_val = static_cast<outT>(m);
+                        }
+                    }
+                    else {
+                        if (val > red_val) {
+                            red_val = val;
+                            idx_val = static_cast<outT>(m);
+                        }
+                    }
+                }
+            }
+        }
+        out_[out_iter_offset] = idx_val;
+    }
+};
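+
+// Worked example, assuming max-semantics with identity_val == -inf:
+// over {2.0f, NaN, 5.0f}, m=0 installs 2.0f; m=1 installs NaN via the
+// isnan check; m=2 changes nothing (5.0f > NaN is false and 5.0f is not
+// NaN), so the stored result index is 1.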
+
 /* = Search reduction using reduce_over_group*/

 template <typename argT,
@@ -3799,6 +3918,14 @@ typedef sycl::event (*search_strided_impl_fn_ptr)(
     py::ssize_t,
     const std::vector<sycl::event> &);

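+// Kernel name type for the sequential strided search fallback defined
+// in search_over_group_temps_strided_impl below.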
+template <typename T1,
+          typename T2,
+          typename T3,
+          typename T4,
+          typename T5,
+          typename T6>
+class search_seq_strided_krn;
+
 template <typename T1,
           typename T2,
           typename T3,
@@ -3820,6 +3947,14 @@ template <typename T1,
           bool b2>
 class custom_search_over_group_temps_strided_krn;

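+// Kernel name type for the sequential contiguous (axis-1 and axis-0)
+// search fallbacks below.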
+template <typename T1,
+          typename T2,
+          typename T3,
+          typename T4,
+          typename T5,
+          typename T6>
+class search_seq_contig_krn;
+
 template <typename T1,
           typename T2,
           typename T3,
@@ -4019,6 +4154,36 @@ sycl::event search_over_group_temps_strided_impl(
     const auto &sg_sizes = d.get_info<sycl::info::device::sub_group_sizes>();
     size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes);

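+    // If the reduction axis is shorter than the work-group size, a group
+    // reduction would leave most work-items idle; run one sequential
+    // work-item per output instead.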
+    if (reduction_nelems < wg) {
+        sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
+            cgh.depends_on(depends);
+
+            using InputOutputIterIndexerT =
+                dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
+            using ReductionIndexerT =
+                dpctl::tensor::offset_utils::StridedIndexer;
+
+            InputOutputIterIndexerT in_out_iter_indexer{
+                iter_nd, iter_arg_offset, iter_res_offset,
+                iter_shape_and_strides};
+            ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset,
+                                                reduction_shape_stride};
+
+            cgh.parallel_for<class search_seq_strided_krn<
+                argTy, resTy, ReductionOpT, IndexOpT, InputOutputIterIndexerT,
+                ReductionIndexerT>>(
+                sycl::range<1>(iter_nelems),
+                SequentialSearchReduction<argTy, resTy, ReductionOpT, IndexOpT,
+                                          InputOutputIterIndexerT,
+                                          ReductionIndexerT>(
+                    arg_tp, res_tp, ReductionOpT(), identity_val, IndexOpT(),
+                    idx_identity_val, in_out_iter_indexer, reduction_indexer,
+                    reduction_nelems));
+        });
+
+        return comp_ev;
+    }
+
     constexpr size_t preferred_reductions_per_wi = 4;
     // max_max_wg prevents running out of resources on CPU
     size_t max_wg =
@@ -4419,6 +4584,39 @@ sycl::event search_axis1_over_group_temps_contig_impl(
     const auto &sg_sizes = d.get_info<sycl::info::device::sub_group_sizes>();
     size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes);

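+    // Sequential fallback for short reduction axes. Each row is
+    // contiguous, so the reduction dimension needs no index arithmetic
+    // (NoOpIndexer); rows are stepped over with a Strided1DIndexer.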
+    if (reduction_nelems < wg) {
+        sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
+            cgh.depends_on(depends);
+
+            using InputIterIndexerT =
+                dpctl::tensor::offset_utils::Strided1DIndexer;
+            using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
+            using InputOutputIterIndexerT =
+                dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
+                    InputIterIndexerT, NoOpIndexerT>;
+            using ReductionIndexerT = NoOpIndexerT;
+
+            InputOutputIterIndexerT in_out_iter_indexer{
+                InputIterIndexerT{0, static_cast<py::ssize_t>(iter_nelems),
+                                  static_cast<py::ssize_t>(reduction_nelems)},
+                NoOpIndexerT{}};
+            ReductionIndexerT reduction_indexer{};
+
+            cgh.parallel_for<class search_seq_contig_krn<
+                argTy, resTy, ReductionOpT, IndexOpT, InputOutputIterIndexerT,
+                ReductionIndexerT>>(
+                sycl::range<1>(iter_nelems),
+                SequentialSearchReduction<argTy, resTy, ReductionOpT, IndexOpT,
+                                          InputOutputIterIndexerT,
+                                          ReductionIndexerT>(
+                    arg_tp, res_tp, ReductionOpT(), identity_val, IndexOpT(),
+                    idx_identity_val, in_out_iter_indexer, reduction_indexer,
+                    reduction_nelems));
+        });
+
+        return comp_ev;
+    }
+
     constexpr size_t preferred_reductions_per_wi = 8;
     // max_max_wg prevents running out of resources on CPU
     size_t max_wg =
@@ -4801,6 +4999,43 @@ sycl::event search_axis0_over_group_temps_contig_impl(
     const auto &sg_sizes = d.get_info<sycl::info::device::sub_group_sizes>();
     size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes);

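+    // Sequential fallback for short reduction axes. Reducing over axis 0
+    // of a C-contiguous array visits elements iter_nelems apart, handled
+    // here by the Strided1DIndexer reduction indexer.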
+    if (reduction_nelems < wg) {
+        sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
+            cgh.depends_on(depends);
+
+            using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
+            using InputOutputIterIndexerT =
+                dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
+                    NoOpIndexerT, NoOpIndexerT>;
+            using ReductionIndexerT =
+                dpctl::tensor::offset_utils::Strided1DIndexer;
+
+            InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{},
+                                                        NoOpIndexerT{}};
+            ReductionIndexerT reduction_indexer{
+                0, static_cast<py::ssize_t>(reduction_nelems),
+                static_cast<py::ssize_t>(iter_nelems)};
+
+            using KernelName =
+                class search_seq_contig_krn<argTy, resTy, ReductionOpT,
+                                            IndexOpT, InputOutputIterIndexerT,
+                                            ReductionIndexerT>;
+
+            sycl::range<1> iter_range{iter_nelems};
+
+            cgh.parallel_for<KernelName>(
+                iter_range,
+                SequentialSearchReduction<argTy, resTy, ReductionOpT, IndexOpT,
+                                          InputOutputIterIndexerT,
+                                          ReductionIndexerT>(
+                    arg_tp, res_tp, ReductionOpT(), identity_val, IndexOpT(),
+                    idx_identity_val, in_out_iter_indexer, reduction_indexer,
+                    reduction_nelems));
+        });
+
+        return comp_ev;
+    }
+
     constexpr size_t preferred_reductions_per_wi = 8;
     // max_max_wg prevents running out of resources on CPU
     size_t max_wg =