@@ -3400,6 +3400,125 @@ struct LogSumExpOverAxis0TempsContigFactory
 
 // Argmax and Argmin
 
+/* Sequential search reduction */
+
+template <typename argT,
+          typename outT,
+          typename ReductionOp,
+          typename IdxReductionOp,
+          typename InputOutputIterIndexerT,
+          typename InputRedIndexerT>
+struct SequentialSearchReduction
+{
+private:
+    const argT *inp_ = nullptr;
+    outT *out_ = nullptr;
+    ReductionOp reduction_op_;
+    argT identity_;
+    IdxReductionOp idx_reduction_op_;
+    outT idx_identity_;
+    InputOutputIterIndexerT inp_out_iter_indexer_;
+    InputRedIndexerT inp_reduced_dims_indexer_;
+    size_t reduction_max_gid_ = 0;
+
+public:
+    SequentialSearchReduction(const argT *inp,
+                              outT *res,
+                              ReductionOp reduction_op,
+                              const argT &identity_val,
+                              IdxReductionOp idx_reduction_op,
+                              const outT &idx_identity_val,
+                              InputOutputIterIndexerT arg_res_iter_indexer,
+                              InputRedIndexerT arg_reduced_dims_indexer,
+                              size_t reduction_size)
+        : inp_(inp), out_(res), reduction_op_(reduction_op),
+          identity_(identity_val), idx_reduction_op_(idx_reduction_op),
+          idx_identity_(idx_identity_val),
+          inp_out_iter_indexer_(arg_res_iter_indexer),
+          inp_reduced_dims_indexer_(arg_reduced_dims_indexer),
+          reduction_max_gid_(reduction_size)
+    {
+    }
+
+    void operator()(sycl::id<1> id) const
+    {
+
+        auto const &inp_out_iter_offsets_ = inp_out_iter_indexer_(id[0]);
+        const py::ssize_t &inp_iter_offset =
+            inp_out_iter_offsets_.get_first_offset();
+        const py::ssize_t &out_iter_offset =
+            inp_out_iter_offsets_.get_second_offset();
+
+        argT red_val(identity_);
+        outT idx_val(idx_identity_);
+        for (size_t m = 0; m < reduction_max_gid_; ++m) {
+            const py::ssize_t inp_reduction_offset =
+                inp_reduced_dims_indexer_(m);
+            const py::ssize_t inp_offset =
+                inp_iter_offset + inp_reduction_offset;
+
+            argT val = inp_[inp_offset];
+            if (val == red_val) {
+                idx_val = idx_reduction_op_(idx_val, static_cast<outT>(m));
+            }
+            else {
+                if constexpr (su_ns::IsMinimum<argT, ReductionOp>::value) {
+                    using dpctl::tensor::type_utils::is_complex;
+                    if constexpr (is_complex<argT>::value) {
+                        using dpctl::tensor::math_utils::less_complex;
+                        // less_complex always returns false for NaNs, so check
+                        if (less_complex<argT>(val, red_val) ||
+                            std::isnan(std::real(val)) ||
+                            std::isnan(std::imag(val)))
+                        {
+                            red_val = val;
+                            idx_val = static_cast<outT>(m);
+                        }
+                    }
+                    else if constexpr (std::is_floating_point_v<argT>) {
+                        if (val < red_val || std::isnan(val)) {
+                            red_val = val;
+                            idx_val = static_cast<outT>(m);
+                        }
+                    }
+                    else {
+                        if (val < red_val) {
+                            red_val = val;
+                            idx_val = static_cast<outT>(m);
+                        }
+                    }
+                }
+                else if constexpr (su_ns::IsMaximum<argT, ReductionOp>::value) {
+                    using dpctl::tensor::type_utils::is_complex;
+                    if constexpr (is_complex<argT>::value) {
+                        using dpctl::tensor::math_utils::greater_complex;
+                        if (greater_complex<argT>(val, red_val) ||
+                            std::isnan(std::real(val)) ||
+                            std::isnan(std::imag(val)))
+                        {
+                            red_val = val;
+                            idx_val = static_cast<outT>(m);
+                        }
+                    }
+                    else if constexpr (std::is_floating_point_v<argT>) {
+                        if (val > red_val || std::isnan(val)) {
+                            red_val = val;
+                            idx_val = static_cast<outT>(m);
+                        }
+                    }
+                    else {
+                        if (val > red_val) {
+                            red_val = val;
+                            idx_val = static_cast<outT>(m);
+                        }
+                    }
+                }
+            }
+        }
+        out_[out_iter_offset] = idx_val;
+    }
+};
+
 /* = Search reduction using reduce_over_group*/
 
 template <typename argT,
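Note: `SequentialSearchReduction` computes one argmax/argmin result per work-item by scanning its reduction lane once, carrying both the running extremal value and its index. NaN inputs are admitted explicitly (since comparisons with NaN are false), so the search stays NaN-propagating like `min`/`max`, and ties are folded through `idx_reduction_op_`, typically a minimum over indices yielding the first occurrence. A minimal host-side sketch of the equivalent scalar scan for `float` argmin, assuming the usual identities (+inf for the value minimum, the largest index for the index minimum); `sequential_argmin` is an illustrative name, not part of this change:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Host-side sketch of the per-output scan done by SequentialSearchReduction
std::size_t sequential_argmin(const std::vector<float> &v)
{
    // identity values: +inf for the minimum, max index for the index minimum
    float red_val = std::numeric_limits<float>::infinity();
    std::size_t idx_val = std::numeric_limits<std::size_t>::max();
    for (std::size_t m = 0; m < v.size(); ++m) {
        const float val = v[m];
        if (val == red_val) {
            // tie: keep the lowest index seen so far
            idx_val = std::min(idx_val, m);
        }
        else if (val < red_val || std::isnan(val)) {
            // NaN never satisfies operator<, so it is admitted explicitly;
            // this makes argmin NaN-propagating, matching min()
            red_val = val;
            idx_val = m;
        }
    }
    return idx_val;
}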
@@ -3798,6 +3917,14 @@ typedef sycl::event (*search_strided_impl_fn_ptr)(
     py::ssize_t,
     const std::vector<sycl::event> &);
 
+template <typename T1,
+          typename T2,
+          typename T3,
+          typename T4,
+          typename T5,
+          typename T6>
+class search_seq_strided_krn;
+
 template <typename T1,
           typename T2,
           typename T3,
@@ -3819,6 +3946,14 @@ template <typename T1,
           bool b2>
class custom_search_over_group_temps_strided_krn;
 
+template <typename T1,
+          typename T2,
+          typename T3,
+          typename T4,
+          typename T5,
+          typename T6>
+class search_seq_contig_krn;
+
 template <typename T1,
           typename T2,
           typename T3,
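Note: SYCL kernels submitted as named functors or lambdas need a unique C++ type as their kernel name. `search_seq_strided_krn` and `search_seq_contig_krn` are declaration-only name templates whose six parameters mirror the argument, result, reduction-op, index-op, and two indexer types of `SequentialSearchReduction`, so every distinct instantiation of the functor gets a distinct kernel identity. A minimal illustration of the pattern, with hypothetical `copy_krn`/`copy_kernel` names that are not part of this change:

#include <cstddef>
#include <sycl/sycl.hpp>

template <typename T> class copy_krn; // name only; never defined

template <typename T>
sycl::event copy_kernel(sycl::queue &q, const T *src, T *dst, std::size_t n)
{
    // each T instantiates copy_krn<T>, giving the kernel a unique name
    return q.parallel_for<copy_krn<T>>(
        sycl::range<1>(n), [=](sycl::id<1> i) { dst[i] = src[i]; });
}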
@@ -4018,6 +4153,36 @@ sycl::event search_over_group_temps_strided_impl(
     const auto &sg_sizes = d.get_info<sycl::info::device::sub_group_sizes>();
     size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes);
 
+    if (reduction_nelems < wg) {
+        sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
+            cgh.depends_on(depends);
+
+            using InputOutputIterIndexerT =
+                dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
+            using ReductionIndexerT =
+                dpctl::tensor::offset_utils::StridedIndexer;
+
+            InputOutputIterIndexerT in_out_iter_indexer{
+                iter_nd, iter_arg_offset, iter_res_offset,
+                iter_shape_and_strides};
+            ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset,
+                                                reduction_shape_stride};
+
+            cgh.parallel_for<class search_seq_strided_krn<
+                argTy, resTy, ReductionOpT, IndexOpT, InputOutputIterIndexerT,
+                ReductionIndexerT>>(
+                sycl::range<1>(iter_nelems),
+                SequentialSearchReduction<argTy, resTy, ReductionOpT, IndexOpT,
+                                          InputOutputIterIndexerT,
+                                          ReductionIndexerT>(
+                    arg_tp, res_tp, ReductionOpT(), identity_val, IndexOpT(),
+                    idx_identity_val, in_out_iter_indexer, reduction_indexer,
+                    reduction_nelems));
+        });
+
+        return comp_ev;
+    }
+
     constexpr size_t preferred_reductions_per_wi = 4;
     // max_max_wg prevents running out of resources on CPU
     size_t max_wg =
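Note: this fast path fires when the reduction axis has fewer elements than one work-group (`reduction_nelems < wg`). In that regime a group-wide tree reduction would leave most of a work-group idle and still pay for temporaries and multiple passes, so each of the `iter_nelems` outputs is instead computed by a single work-item scanning its reduction lane sequentially, and the event is returned immediately without the multi-kernel follow-up.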
@@ -4418,6 +4583,39 @@ sycl::event search_axis1_over_group_temps_contig_impl(
     const auto &sg_sizes = d.get_info<sycl::info::device::sub_group_sizes>();
     size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes);
 
+    if (reduction_nelems < wg) {
+        sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
+            cgh.depends_on(depends);
+
+            using InputIterIndexerT =
+                dpctl::tensor::offset_utils::Strided1DIndexer;
+            using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
+            using InputOutputIterIndexerT =
+                dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
+                    InputIterIndexerT, NoOpIndexerT>;
+            using ReductionIndexerT = NoOpIndexerT;
+
+            InputOutputIterIndexerT in_out_iter_indexer{
+                InputIterIndexerT{0, static_cast<py::ssize_t>(iter_nelems),
+                                  static_cast<py::ssize_t>(reduction_nelems)},
+                NoOpIndexerT{}};
+            ReductionIndexerT reduction_indexer{};
+
+            cgh.parallel_for<class search_seq_contig_krn<
+                argTy, resTy, ReductionOpT, IndexOpT, InputOutputIterIndexerT,
+                ReductionIndexerT>>(
+                sycl::range<1>(iter_nelems),
+                SequentialSearchReduction<argTy, resTy, ReductionOpT, IndexOpT,
+                                          InputOutputIterIndexerT,
+                                          ReductionIndexerT>(
+                    arg_tp, res_tp, ReductionOpT(), identity_val, IndexOpT(),
+                    idx_identity_val, in_out_iter_indexer, reduction_indexer,
+                    reduction_nelems));
+        });
+
+        return comp_ev;
+    }
+
     constexpr size_t preferred_reductions_per_wi = 8;
     // max_max_wg prevents running out of resources on CPU
     size_t max_wg =
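Note: in the axis-1 (row-wise) contiguous case the reduced elements are adjacent, so a `NoOpIndexer` walks them directly, while `Strided1DIndexer{0, iter_nelems, reduction_nelems}` places work-item `i` at the start of row `i`. Assuming `Strided1DIndexer` maps index `i` to `start + i * step`, the element read at reduction step `m` sits at flat offset `i * reduction_nelems + m`, i.e. a straight left-to-right scan of each row.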
@@ -4800,6 +4998,43 @@ sycl::event search_axis0_over_group_temps_contig_impl(
     const auto &sg_sizes = d.get_info<sycl::info::device::sub_group_sizes>();
     size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes);
 
+    if (reduction_nelems < wg) {
+        sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
+            cgh.depends_on(depends);
+
+            using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
+            using InputOutputIterIndexerT =
+                dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
+                    NoOpIndexerT, NoOpIndexerT>;
+            using ReductionIndexerT =
+                dpctl::tensor::offset_utils::Strided1DIndexer;
+
+            InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{},
+                                                        NoOpIndexerT{}};
+            ReductionIndexerT reduction_indexer{
+                0, static_cast<py::ssize_t>(reduction_nelems),
+                static_cast<py::ssize_t>(iter_nelems)};
+
+            using KernelName =
+                class search_seq_contig_krn<argTy, resTy, ReductionOpT,
+                                            IndexOpT, InputOutputIterIndexerT,
+                                            ReductionIndexerT>;
+
+            sycl::range<1> iter_range{iter_nelems};
+
+            cgh.parallel_for<KernelName>(
+                iter_range,
+                SequentialSearchReduction<argTy, resTy, ReductionOpT, IndexOpT,
+                                          InputOutputIterIndexerT,
+                                          ReductionIndexerT>(
+                    arg_tp, res_tp, ReductionOpT(), identity_val, IndexOpT(),
+                    idx_identity_val, in_out_iter_indexer, reduction_indexer,
+                    reduction_nelems));
+        });
+
+        return comp_ev;
+    }
+
     constexpr size_t preferred_reductions_per_wi = 8;
     // max_max_wg prevents running out of resources on CPU
     size_t max_wg =
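Note: the axis-0 variant transposes the roles of the two indexers: outputs are addressed directly (`NoOpIndexer` on both sides of the combined indexer) and the reduction dimension is strided with step `iter_nelems`, so step `m` of output `i` reads flat offset `i + m * iter_nelems`, i.e. column `i` of a C-contiguous matrix. Reusing the `search_seq_contig_krn` name template here remains safe because its indexer template arguments differ from the axis-1 instantiation, keeping the instantiated kernel names unique.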