@@ -55,15 +55,12 @@ template <typename T> struct boolean_predicate
55
55
}
56
56
};
57
57
58
- template <typename inpT,
59
- typename outT,
60
- typename PredicateT,
61
- std::uint8_t wg_dim = 2 >
58
+ template <typename inpT, typename outT, typename PredicateT>
62
59
struct all_reduce_wg_contig
63
60
{
64
- void operator ()(sycl::nd_item<wg_dim > &ndit,
61
+ void operator ()(sycl::nd_item<1 > &ndit,
65
62
outT *out,
66
- size_t &out_idx,
63
+ const size_t &out_idx,
67
64
const inpT *start,
68
65
const inpT *end) const
69
66
{
@@ -82,15 +79,12 @@ struct all_reduce_wg_contig
82
79
}
83
80
};
84
81
85
- template <typename inpT,
86
- typename outT,
87
- typename PredicateT,
88
- std::uint8_t wg_dim = 2 >
82
+ template <typename inpT, typename outT, typename PredicateT>
89
83
struct any_reduce_wg_contig
90
84
{
91
- void operator ()(sycl::nd_item<wg_dim > &ndit,
85
+ void operator ()(sycl::nd_item<1 > &ndit,
92
86
outT *out,
93
- size_t &out_idx,
87
+ const size_t &out_idx,
94
88
const inpT *start,
95
89
const inpT *end) const
96
90
{
@@ -109,9 +103,9 @@ struct any_reduce_wg_contig
109
103
}
110
104
};
111
105
112
- template <typename T, std:: uint8_t wg_dim = 2 > struct all_reduce_wg_strided
106
+ template <typename T> struct all_reduce_wg_strided
113
107
{
114
- void operator ()(sycl::nd_item<wg_dim > &ndit,
108
+ void operator ()(sycl::nd_item<1 > &ndit,
115
109
T *out,
116
110
const size_t &out_idx,
117
111
const T &local_val) const
@@ -129,9 +123,9 @@ template <typename T, std::uint8_t wg_dim = 2> struct all_reduce_wg_strided
129
123
}
130
124
};
131
125
132
- template <typename T, std:: uint8_t wg_dim = 2 > struct any_reduce_wg_strided
126
+ template <typename T> struct any_reduce_wg_strided
133
127
{
134
- void operator ()(sycl::nd_item<wg_dim > &ndit,
128
+ void operator ()(sycl::nd_item<1 > &ndit,
135
129
T *out,
136
130
const size_t &out_idx,
137
131
const T &local_val) const
@@ -215,35 +209,46 @@ struct ContigBooleanReduction
215
209
outT *out_ = nullptr ;
216
210
GroupOp group_op_;
217
211
size_t reduction_max_gid_ = 0 ;
212
+ size_t iter_gws_ = 1 ;
218
213
size_t reductions_per_wi = 16 ;
219
214
220
215
public:
221
216
ContigBooleanReduction (const argT *inp,
222
217
outT *res,
223
218
GroupOp group_op,
224
219
size_t reduction_size,
220
+ size_t iteration_size,
225
221
size_t reduction_size_per_wi)
226
222
: inp_(inp), out_(res), group_op_(group_op),
227
- reduction_max_gid_ (reduction_size),
223
+ reduction_max_gid_ (reduction_size), iter_gws_(iteration_size),
228
224
reductions_per_wi(reduction_size_per_wi)
229
225
{
230
226
}
231
227
232
- void operator ()(sycl::nd_item<2 > it) const
228
+ void operator ()(sycl::nd_item<1 > it) const
233
229
{
234
-
235
- size_t reduction_id = it.get_group (0 );
236
- size_t reduction_batch_id = it.get_group (1 );
237
- size_t wg_size = it.get_local_range (1 );
238
-
239
- size_t base = reduction_id * reduction_max_gid_;
240
- size_t start = base + reduction_batch_id * wg_size * reductions_per_wi;
241
- size_t end = std::min ((start + (reductions_per_wi * wg_size)),
242
- base + reduction_max_gid_);
230
+ const size_t red_gws_ = it.get_global_range (0 ) / iter_gws_;
231
+ const size_t reduction_id = it.get_global_id (0 ) / red_gws_;
232
+ const size_t reduction_batch_id = get_reduction_batch_id (it);
233
+ const size_t wg_size = it.get_local_range (0 );
234
+
235
+ const size_t base = reduction_id * reduction_max_gid_;
236
+ const size_t start =
237
+ base + reduction_batch_id * wg_size * reductions_per_wi;
238
+ const size_t end = std::min ((start + (reductions_per_wi * wg_size)),
239
+ base + reduction_max_gid_);
243
240
// reduction and atomic operations are performed
244
241
// in group_op_
245
242
group_op_ (it, out_, reduction_id, inp_ + start, inp_ + end);
246
243
}
244
+
245
+ private:
246
+ size_t get_reduction_batch_id (sycl::nd_item<1 > const &it) const
247
+ {
248
+ const size_t n_reduction_groups = it.get_group_range (0 ) / iter_gws_;
249
+ const size_t reduction_batch_id = it.get_group (0 ) % n_reduction_groups;
250
+ return reduction_batch_id;
251
+ }
247
252
};
248
253
249
254
typedef sycl::event (*boolean_reduction_contig_impl_fn_ptr)(
@@ -332,7 +337,7 @@ boolean_reduction_contig_impl(sycl::queue exec_q,
332
337
red_ev = exec_q.submit ([&](sycl::handler &cgh) {
333
338
cgh.depends_on (init_ev);
334
339
335
- constexpr std::uint8_t group_dim = 2 ;
340
+ constexpr std::uint8_t dim = 1 ;
336
341
337
342
constexpr size_t preferred_reductions_per_wi = 4 ;
338
343
size_t reductions_per_wi =
@@ -344,15 +349,14 @@ boolean_reduction_contig_impl(sycl::queue exec_q,
344
349
(reduction_nelems + reductions_per_wi * wg - 1 ) /
345
350
(reductions_per_wi * wg);
346
351
347
- auto gws =
348
- sycl::range<group_dim>{iter_nelems, reduction_groups * wg};
349
- auto lws = sycl::range<group_dim>{1 , wg};
352
+ auto gws = sycl::range<dim>{iter_nelems * reduction_groups * wg};
353
+ auto lws = sycl::range<dim>{wg};
350
354
351
355
cgh.parallel_for <
352
356
class boolean_reduction_contig_krn <argTy, resTy, GroupOpT>>(
353
- sycl::nd_range<group_dim >(gws, lws),
357
+ sycl::nd_range<dim >(gws, lws),
354
358
ContigBooleanReduction<argTy, resTy, GroupOpT>(
355
- arg_tp, res_tp, GroupOpT (), reduction_nelems,
359
+ arg_tp, res_tp, GroupOpT (), reduction_nelems, iter_nelems,
356
360
reductions_per_wi));
357
361
});
358
362
}
@@ -404,6 +408,7 @@ struct StridedBooleanReduction
404
408
InputOutputIterIndexerT inp_out_iter_indexer_;
405
409
InputRedIndexerT inp_reduced_dims_indexer_;
406
410
size_t reduction_max_gid_ = 0 ;
411
+ size_t iter_gws_ = 1 ;
407
412
size_t reductions_per_wi = 16 ;
408
413
409
414
public:
@@ -415,23 +420,24 @@ struct StridedBooleanReduction
415
420
InputOutputIterIndexerT arg_res_iter_indexer,
416
421
InputRedIndexerT arg_reduced_dims_indexer,
417
422
size_t reduction_size,
423
+ size_t iteration_size,
418
424
size_t reduction_size_per_wi)
419
425
: inp_(inp), out_(res), reduction_op_(reduction_op),
420
426
group_op_ (group_op), identity_(identity_val),
421
427
inp_out_iter_indexer_(arg_res_iter_indexer),
422
428
inp_reduced_dims_indexer_(arg_reduced_dims_indexer),
423
- reduction_max_gid_(reduction_size),
429
+ reduction_max_gid_(reduction_size), iter_gws_(iteration_size),
424
430
reductions_per_wi(reduction_size_per_wi)
425
431
{
426
432
}
427
433
428
- void operator ()(sycl::nd_item<2 > it) const
434
+ void operator ()(sycl::nd_item<1 > it) const
429
435
{
430
-
431
- size_t reduction_id = it.get_group (0 );
432
- size_t reduction_batch_id = it. get_group ( 1 );
433
- size_t reduction_lid = it.get_local_id (1 );
434
- size_t wg_size = it.get_local_range (1 );
436
+ const size_t red_gws_ = it. get_global_range ( 0 ) / iter_gws_;
437
+ const size_t reduction_id = it.get_global_id (0 ) / red_gws_ ;
438
+ const size_t reduction_batch_id = get_reduction_batch_id (it );
439
+ const size_t reduction_lid = it.get_local_id (0 );
440
+ const size_t wg_size = it.get_local_range (0 );
435
441
436
442
auto inp_out_iter_offsets_ = inp_out_iter_indexer_ (reduction_id);
437
443
const py::ssize_t &inp_iter_offset =
@@ -442,26 +448,34 @@ struct StridedBooleanReduction
442
448
outT local_red_val (identity_);
443
449
size_t arg_reduce_gid0 =
444
450
reduction_lid + reduction_batch_id * wg_size * reductions_per_wi;
445
- for (size_t m = 0 ; m < reductions_per_wi; ++m) {
446
- size_t arg_reduce_gid = arg_reduce_gid0 + m * wg_size;
447
-
448
- if (arg_reduce_gid < reduction_max_gid_) {
449
- py::ssize_t inp_reduction_offset = static_cast <py::ssize_t >(
450
- inp_reduced_dims_indexer_ (arg_reduce_gid));
451
- py::ssize_t inp_offset = inp_iter_offset + inp_reduction_offset;
451
+ size_t arg_reduce_gid_max = std::min (
452
+ reduction_max_gid_, arg_reduce_gid0 + reductions_per_wi * wg_size);
453
+ for (size_t arg_reduce_gid = arg_reduce_gid0;
454
+ arg_reduce_gid < arg_reduce_gid_max; arg_reduce_gid += wg_size)
455
+ {
456
+ py::ssize_t inp_reduction_offset = static_cast <py::ssize_t >(
457
+ inp_reduced_dims_indexer_ (arg_reduce_gid));
458
+ py::ssize_t inp_offset = inp_iter_offset + inp_reduction_offset;
452
459
453
- // must convert to boolean first to handle nans
454
- using dpctl::tensor::type_utils::convert_impl;
455
- bool val = convert_impl<bool , argT>(inp_[inp_offset]);
456
- ReductionOp op = reduction_op_;
460
+ // must convert to boolean first to handle nans
461
+ using dpctl::tensor::type_utils::convert_impl;
462
+ bool val = convert_impl<bool , argT>(inp_[inp_offset]);
463
+ ReductionOp op = reduction_op_;
457
464
458
- local_red_val = op (local_red_val, static_cast <outT>(val));
459
- }
465
+ local_red_val = op (local_red_val, static_cast <outT>(val));
460
466
}
461
467
// reduction and atomic operations are performed
462
468
// in group_op_
463
469
group_op_ (it, out_, out_iter_offset, local_red_val);
464
470
}
471
+
472
+ private:
473
+ size_t get_reduction_batch_id (sycl::nd_item<1 > const &it) const
474
+ {
475
+ const size_t n_reduction_groups = it.get_group_range (0 ) / iter_gws_;
476
+ const size_t reduction_batch_id = it.get_group (0 ) % n_reduction_groups;
477
+ return reduction_batch_id;
478
+ }
465
479
};
466
480
467
481
template <typename T1,
@@ -564,7 +578,7 @@ boolean_reduction_strided_impl(sycl::queue exec_q,
564
578
red_ev = exec_q.submit ([&](sycl::handler &cgh) {
565
579
cgh.depends_on (res_init_ev);
566
580
567
- constexpr std::uint8_t group_dim = 2 ;
581
+ constexpr std::uint8_t dim = 1 ;
568
582
569
583
using InputOutputIterIndexerT =
570
584
dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
@@ -587,20 +601,19 @@ boolean_reduction_strided_impl(sycl::queue exec_q,
587
601
(reduction_nelems + reductions_per_wi * wg - 1 ) /
588
602
(reductions_per_wi * wg);
589
603
590
- auto gws =
591
- sycl::range<group_dim>{iter_nelems, reduction_groups * wg};
592
- auto lws = sycl::range<group_dim>{1 , wg};
604
+ auto gws = sycl::range<dim>{iter_nelems * reduction_groups * wg};
605
+ auto lws = sycl::range<dim>{wg};
593
606
594
607
cgh.parallel_for <class boolean_reduction_strided_krn <
595
608
argTy, resTy, RedOpT, GroupOpT, InputOutputIterIndexerT,
596
609
ReductionIndexerT>>(
597
- sycl::nd_range<group_dim >(gws, lws),
610
+ sycl::nd_range<dim >(gws, lws),
598
611
StridedBooleanReduction<argTy, resTy, RedOpT, GroupOpT,
599
612
InputOutputIterIndexerT,
600
613
ReductionIndexerT>(
601
614
arg_tp, res_tp, RedOpT (), GroupOpT (), identity_val,
602
615
in_out_iter_indexer, reduction_indexer, reduction_nelems,
603
- reductions_per_wi));
616
+ iter_nelems, reductions_per_wi));
604
617
});
605
618
}
606
619
return red_ev;
0 commit comments