34
34
#include " pybind11/pybind11.h"
35
35
36
36
#include " utils/offset_utils.hpp"
37
+ #include " utils/sycl_utils.hpp"
37
38
#include " utils/type_dispatch.hpp"
38
39
#include " utils/type_utils.hpp"
39
40
@@ -227,9 +228,8 @@ struct ContigBooleanReduction
227
228
228
229
void operator ()(sycl::nd_item<1 > it) const
229
230
{
230
- const size_t red_gws_ = it.get_global_range (0 ) / iter_gws_;
231
- const size_t reduction_id = it.get_global_id (0 ) / red_gws_;
232
- const size_t reduction_batch_id = get_reduction_batch_id (it);
231
+ const size_t reduction_id = it.get_group (0 ) % iter_gws_;
232
+ const size_t reduction_batch_id = it.get_group (0 ) / iter_gws_;
233
233
const size_t wg_size = it.get_local_range (0 );
234
234
235
235
const size_t base = reduction_id * reduction_max_gid_;
@@ -241,14 +241,6 @@ struct ContigBooleanReduction
241
241
// in group_op_
242
242
group_op_ (it, out_, reduction_id, inp_ + start, inp_ + end);
243
243
}
244
-
245
- private:
246
- size_t get_reduction_batch_id (sycl::nd_item<1 > const &it) const
247
- {
248
- const size_t n_reduction_groups = it.get_group_range (0 ) / iter_gws_;
249
- const size_t reduction_batch_id = it.get_group (0 ) % n_reduction_groups;
250
- return reduction_batch_id;
251
- }
252
244
};
253
245
254
246
typedef sycl::event (*boolean_reduction_contig_impl_fn_ptr)(
@@ -268,6 +260,8 @@ class boolean_reduction_contig_krn;
268
260
template <typename T1, typename T2, typename T3, typename T4, typename T5>
269
261
class boolean_reduction_seq_contig_krn ;
270
262
263
+ using dpctl::tensor::sycl_utils::choose_workgroup_size;
264
+
271
265
template <typename argTy, typename resTy, typename RedOpT, typename GroupOpT>
272
266
sycl::event
273
267
boolean_reduction_contig_impl (sycl::queue exec_q,
@@ -288,8 +282,7 @@ boolean_reduction_contig_impl(sycl::queue exec_q,
288
282
289
283
const sycl::device &d = exec_q.get_device ();
290
284
const auto &sg_sizes = d.get_info <sycl::info::device::sub_group_sizes>();
291
- size_t wg =
292
- 4 * (*std::max_element (std::begin (sg_sizes), std::end (sg_sizes)));
285
+ size_t wg = choose_workgroup_size<4 >(reduction_nelems, sg_sizes);
293
286
294
287
sycl::event red_ev;
295
288
if (reduction_nelems < wg) {
@@ -433,9 +426,9 @@ struct StridedBooleanReduction
433
426
434
427
void operator ()(sycl::nd_item<1 > it) const
435
428
{
436
- const size_t red_gws_ = it.get_global_range (0 ) / iter_gws_;
437
- const size_t reduction_id = it.get_global_id (0 ) / red_gws_ ;
438
- const size_t reduction_batch_id = get_reduction_batch_id (it);
429
+ const size_t reduction_id = it.get_group (0 ) % iter_gws_;
430
+ const size_t reduction_batch_id = it.get_group (0 ) / iter_gws_ ;
431
+
439
432
const size_t reduction_lid = it.get_local_id (0 );
440
433
const size_t wg_size = it.get_local_range (0 );
441
434
@@ -468,14 +461,6 @@ struct StridedBooleanReduction
468
461
// in group_op_
469
462
group_op_ (it, out_, out_iter_offset, local_red_val);
470
463
}
471
-
472
- private:
473
- size_t get_reduction_batch_id (sycl::nd_item<1 > const &it) const
474
- {
475
- const size_t n_reduction_groups = it.get_group_range (0 ) / iter_gws_;
476
- const size_t reduction_batch_id = it.get_group (0 ) % n_reduction_groups;
477
- return reduction_batch_id;
478
- }
479
464
};
480
465
481
466
template <typename T1,
@@ -527,8 +512,7 @@ boolean_reduction_strided_impl(sycl::queue exec_q,
527
512
528
513
const sycl::device &d = exec_q.get_device ();
529
514
const auto &sg_sizes = d.get_info <sycl::info::device::sub_group_sizes>();
530
- size_t wg =
531
- 4 * (*std::max_element (std::begin (sg_sizes), std::end (sg_sizes)));
515
+ size_t wg = choose_workgroup_size<4 >(reduction_nelems, sg_sizes);
532
516
533
517
sycl::event red_ev;
534
518
if (reduction_nelems < wg) {
0 commit comments