Skip to content

Commit b008b8b

Browse files
Merge pull request #1303 from IntelPython/reduction-changes
Improvement to performance of tensor.sum
2 parents 2f3be1f + 9f54428 commit b008b8b

File tree

3 files changed, with 152 additions and 102 deletions.

dpctl/tensor/libtensor/include/kernels/boolean_reductions.hpp

Lines changed: 72 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -55,15 +55,12 @@ template <typename T> struct boolean_predicate
5555
}
5656
};
5757

58-
template <typename inpT,
59-
typename outT,
60-
typename PredicateT,
61-
std::uint8_t wg_dim = 2>
58+
template <typename inpT, typename outT, typename PredicateT>
6259
struct all_reduce_wg_contig
6360
{
64-
void operator()(sycl::nd_item<wg_dim> &ndit,
61+
void operator()(sycl::nd_item<1> &ndit,
6562
outT *out,
66-
size_t &out_idx,
63+
const size_t &out_idx,
6764
const inpT *start,
6865
const inpT *end) const
6966
{
@@ -82,15 +79,12 @@ struct all_reduce_wg_contig
8279
}
8380
};
8481

85-
template <typename inpT,
86-
typename outT,
87-
typename PredicateT,
88-
std::uint8_t wg_dim = 2>
82+
template <typename inpT, typename outT, typename PredicateT>
8983
struct any_reduce_wg_contig
9084
{
91-
void operator()(sycl::nd_item<wg_dim> &ndit,
85+
void operator()(sycl::nd_item<1> &ndit,
9286
outT *out,
93-
size_t &out_idx,
87+
const size_t &out_idx,
9488
const inpT *start,
9589
const inpT *end) const
9690
{
@@ -109,9 +103,9 @@ struct any_reduce_wg_contig
109103
}
110104
};
111105

112-
template <typename T, std::uint8_t wg_dim = 2> struct all_reduce_wg_strided
106+
template <typename T> struct all_reduce_wg_strided
113107
{
114-
void operator()(sycl::nd_item<wg_dim> &ndit,
108+
void operator()(sycl::nd_item<1> &ndit,
115109
T *out,
116110
const size_t &out_idx,
117111
const T &local_val) const
@@ -129,9 +123,9 @@ template <typename T, std::uint8_t wg_dim = 2> struct all_reduce_wg_strided
129123
}
130124
};
131125

132-
template <typename T, std::uint8_t wg_dim = 2> struct any_reduce_wg_strided
126+
template <typename T> struct any_reduce_wg_strided
133127
{
134-
void operator()(sycl::nd_item<wg_dim> &ndit,
128+
void operator()(sycl::nd_item<1> &ndit,
135129
T *out,
136130
const size_t &out_idx,
137131
const T &local_val) const
@@ -215,35 +209,46 @@ struct ContigBooleanReduction
215209
outT *out_ = nullptr;
216210
GroupOp group_op_;
217211
size_t reduction_max_gid_ = 0;
212+
size_t iter_gws_ = 1;
218213
size_t reductions_per_wi = 16;
219214

220215
public:
221216
ContigBooleanReduction(const argT *inp,
222217
outT *res,
223218
GroupOp group_op,
224219
size_t reduction_size,
220+
size_t iteration_size,
225221
size_t reduction_size_per_wi)
226222
: inp_(inp), out_(res), group_op_(group_op),
227-
reduction_max_gid_(reduction_size),
223+
reduction_max_gid_(reduction_size), iter_gws_(iteration_size),
228224
reductions_per_wi(reduction_size_per_wi)
229225
{
230226
}
231227

232-
void operator()(sycl::nd_item<2> it) const
228+
void operator()(sycl::nd_item<1> it) const
233229
{
234-
235-
size_t reduction_id = it.get_group(0);
236-
size_t reduction_batch_id = it.get_group(1);
237-
size_t wg_size = it.get_local_range(1);
238-
239-
size_t base = reduction_id * reduction_max_gid_;
240-
size_t start = base + reduction_batch_id * wg_size * reductions_per_wi;
241-
size_t end = std::min((start + (reductions_per_wi * wg_size)),
242-
base + reduction_max_gid_);
230+
const size_t red_gws_ = it.get_global_range(0) / iter_gws_;
231+
const size_t reduction_id = it.get_global_id(0) / red_gws_;
232+
const size_t reduction_batch_id = get_reduction_batch_id(it);
233+
const size_t wg_size = it.get_local_range(0);
234+
235+
const size_t base = reduction_id * reduction_max_gid_;
236+
const size_t start =
237+
base + reduction_batch_id * wg_size * reductions_per_wi;
238+
const size_t end = std::min((start + (reductions_per_wi * wg_size)),
239+
base + reduction_max_gid_);
243240
// reduction and atomic operations are performed
244241
// in group_op_
245242
group_op_(it, out_, reduction_id, inp_ + start, inp_ + end);
246243
}
244+
245+
private:
246+
size_t get_reduction_batch_id(sycl::nd_item<1> const &it) const
247+
{
248+
const size_t n_reduction_groups = it.get_group_range(0) / iter_gws_;
249+
const size_t reduction_batch_id = it.get_group(0) % n_reduction_groups;
250+
return reduction_batch_id;
251+
}
247252
};
248253

249254
typedef sycl::event (*boolean_reduction_contig_impl_fn_ptr)(
@@ -332,7 +337,7 @@ boolean_reduction_contig_impl(sycl::queue exec_q,
332337
red_ev = exec_q.submit([&](sycl::handler &cgh) {
333338
cgh.depends_on(init_ev);
334339

335-
constexpr std::uint8_t group_dim = 2;
340+
constexpr std::uint8_t dim = 1;
336341

337342
constexpr size_t preferred_reductions_per_wi = 4;
338343
size_t reductions_per_wi =
@@ -344,15 +349,14 @@ boolean_reduction_contig_impl(sycl::queue exec_q,
344349
(reduction_nelems + reductions_per_wi * wg - 1) /
345350
(reductions_per_wi * wg);
346351

347-
auto gws =
348-
sycl::range<group_dim>{iter_nelems, reduction_groups * wg};
349-
auto lws = sycl::range<group_dim>{1, wg};
352+
auto gws = sycl::range<dim>{iter_nelems * reduction_groups * wg};
353+
auto lws = sycl::range<dim>{wg};
350354

351355
cgh.parallel_for<
352356
class boolean_reduction_contig_krn<argTy, resTy, GroupOpT>>(
353-
sycl::nd_range<group_dim>(gws, lws),
357+
sycl::nd_range<dim>(gws, lws),
354358
ContigBooleanReduction<argTy, resTy, GroupOpT>(
355-
arg_tp, res_tp, GroupOpT(), reduction_nelems,
359+
arg_tp, res_tp, GroupOpT(), reduction_nelems, iter_nelems,
356360
reductions_per_wi));
357361
});
358362
}
@@ -404,6 +408,7 @@ struct StridedBooleanReduction
404408
InputOutputIterIndexerT inp_out_iter_indexer_;
405409
InputRedIndexerT inp_reduced_dims_indexer_;
406410
size_t reduction_max_gid_ = 0;
411+
size_t iter_gws_ = 1;
407412
size_t reductions_per_wi = 16;
408413

409414
public:
@@ -415,23 +420,24 @@ struct StridedBooleanReduction
415420
InputOutputIterIndexerT arg_res_iter_indexer,
416421
InputRedIndexerT arg_reduced_dims_indexer,
417422
size_t reduction_size,
423+
size_t iteration_size,
418424
size_t reduction_size_per_wi)
419425
: inp_(inp), out_(res), reduction_op_(reduction_op),
420426
group_op_(group_op), identity_(identity_val),
421427
inp_out_iter_indexer_(arg_res_iter_indexer),
422428
inp_reduced_dims_indexer_(arg_reduced_dims_indexer),
423-
reduction_max_gid_(reduction_size),
429+
reduction_max_gid_(reduction_size), iter_gws_(iteration_size),
424430
reductions_per_wi(reduction_size_per_wi)
425431
{
426432
}
427433

428-
void operator()(sycl::nd_item<2> it) const
434+
void operator()(sycl::nd_item<1> it) const
429435
{
430-
431-
size_t reduction_id = it.get_group(0);
432-
size_t reduction_batch_id = it.get_group(1);
433-
size_t reduction_lid = it.get_local_id(1);
434-
size_t wg_size = it.get_local_range(1);
436+
const size_t red_gws_ = it.get_global_range(0) / iter_gws_;
437+
const size_t reduction_id = it.get_global_id(0) / red_gws_;
438+
const size_t reduction_batch_id = get_reduction_batch_id(it);
439+
const size_t reduction_lid = it.get_local_id(0);
440+
const size_t wg_size = it.get_local_range(0);
435441

436442
auto inp_out_iter_offsets_ = inp_out_iter_indexer_(reduction_id);
437443
const py::ssize_t &inp_iter_offset =
@@ -442,26 +448,34 @@ struct StridedBooleanReduction
442448
outT local_red_val(identity_);
443449
size_t arg_reduce_gid0 =
444450
reduction_lid + reduction_batch_id * wg_size * reductions_per_wi;
445-
for (size_t m = 0; m < reductions_per_wi; ++m) {
446-
size_t arg_reduce_gid = arg_reduce_gid0 + m * wg_size;
447-
448-
if (arg_reduce_gid < reduction_max_gid_) {
449-
py::ssize_t inp_reduction_offset = static_cast<py::ssize_t>(
450-
inp_reduced_dims_indexer_(arg_reduce_gid));
451-
py::ssize_t inp_offset = inp_iter_offset + inp_reduction_offset;
451+
size_t arg_reduce_gid_max = std::min(
452+
reduction_max_gid_, arg_reduce_gid0 + reductions_per_wi * wg_size);
453+
for (size_t arg_reduce_gid = arg_reduce_gid0;
454+
arg_reduce_gid < arg_reduce_gid_max; arg_reduce_gid += wg_size)
455+
{
456+
py::ssize_t inp_reduction_offset = static_cast<py::ssize_t>(
457+
inp_reduced_dims_indexer_(arg_reduce_gid));
458+
py::ssize_t inp_offset = inp_iter_offset + inp_reduction_offset;
452459

453-
// must convert to boolean first to handle nans
454-
using dpctl::tensor::type_utils::convert_impl;
455-
bool val = convert_impl<bool, argT>(inp_[inp_offset]);
456-
ReductionOp op = reduction_op_;
460+
// must convert to boolean first to handle nans
461+
using dpctl::tensor::type_utils::convert_impl;
462+
bool val = convert_impl<bool, argT>(inp_[inp_offset]);
463+
ReductionOp op = reduction_op_;
457464

458-
local_red_val = op(local_red_val, static_cast<outT>(val));
459-
}
465+
local_red_val = op(local_red_val, static_cast<outT>(val));
460466
}
461467
// reduction and atomic operations are performed
462468
// in group_op_
463469
group_op_(it, out_, out_iter_offset, local_red_val);
464470
}
471+
472+
private:
473+
size_t get_reduction_batch_id(sycl::nd_item<1> const &it) const
474+
{
475+
const size_t n_reduction_groups = it.get_group_range(0) / iter_gws_;
476+
const size_t reduction_batch_id = it.get_group(0) % n_reduction_groups;
477+
return reduction_batch_id;
478+
}
465479
};
466480

467481
template <typename T1,
@@ -564,7 +578,7 @@ boolean_reduction_strided_impl(sycl::queue exec_q,
564578
red_ev = exec_q.submit([&](sycl::handler &cgh) {
565579
cgh.depends_on(res_init_ev);
566580

567-
constexpr std::uint8_t group_dim = 2;
581+
constexpr std::uint8_t dim = 1;
568582

569583
using InputOutputIterIndexerT =
570584
dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
@@ -587,20 +601,19 @@ boolean_reduction_strided_impl(sycl::queue exec_q,
587601
(reduction_nelems + reductions_per_wi * wg - 1) /
588602
(reductions_per_wi * wg);
589603

590-
auto gws =
591-
sycl::range<group_dim>{iter_nelems, reduction_groups * wg};
592-
auto lws = sycl::range<group_dim>{1, wg};
604+
auto gws = sycl::range<dim>{iter_nelems * reduction_groups * wg};
605+
auto lws = sycl::range<dim>{wg};
593606

594607
cgh.parallel_for<class boolean_reduction_strided_krn<
595608
argTy, resTy, RedOpT, GroupOpT, InputOutputIterIndexerT,
596609
ReductionIndexerT>>(
597-
sycl::nd_range<group_dim>(gws, lws),
610+
sycl::nd_range<dim>(gws, lws),
598611
StridedBooleanReduction<argTy, resTy, RedOpT, GroupOpT,
599612
InputOutputIterIndexerT,
600613
ReductionIndexerT>(
601614
arg_tp, res_tp, RedOpT(), GroupOpT(), identity_val,
602615
in_out_iter_indexer, reduction_indexer, reduction_nelems,
603-
reductions_per_wi));
616+
iter_nelems, reductions_per_wi));
604617
});
605618
}
606619
return red_ev;

0 commit comments

Comments
 (0)