Skip to content

Commit c264567

Browse files
committed
Resolves gh-1456
Tree reductions now populate the destination with the identity value when reducing over zero-size axes. As a result, the logic for handling zero-size axes was removed. ``argmax``, ``argmin``, ``max``, and ``min`` still raise an error for zero-size axes. Reductions now return a copy when provided an empty axis tuple. Adds additional supported dtype combinations to ``prod`` and ``sum``, specifically for boolean and integer input types with inexact (floating-point) output types.
1 parent f686102 commit c264567

File tree

3 files changed

+154
-28
lines changed

3 files changed

+154
-28
lines changed

dpctl/tensor/_reduction.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,6 @@ def _reduction_over_axis(
8383
_reduction_fn,
8484
_dtype_supported,
8585
_default_reduction_type_fn,
86-
_identity=None,
8786
):
8887
if not isinstance(x, dpt.usm_ndarray):
8988
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
@@ -106,23 +105,8 @@ def _reduction_over_axis(
106105
res_dt = _to_device_supported_dtype(res_dt, q.sycl_device)
107106

108107
res_usm_type = x.usm_type
109-
if x.size == 0:
110-
if _identity is None:
111-
raise ValueError("reduction does not support zero-size arrays")
112-
else:
113-
if keepdims:
114-
res_shape = res_shape + (1,) * red_nd
115-
inv_perm = sorted(range(nd), key=lambda d: perm[d])
116-
res_shape = tuple(res_shape[i] for i in inv_perm)
117-
return dpt.full(
118-
res_shape,
119-
_identity,
120-
dtype=res_dt,
121-
usm_type=res_usm_type,
122-
sycl_queue=q,
123-
)
124108
if red_nd == 0:
125-
return dpt.astype(x, res_dt, copy=False)
109+
return dpt.astype(x, res_dt, copy=True)
126110

127111
host_tasks_list = []
128112
if _dtype_supported(inp_dt, res_dt, res_usm_type, q):
@@ -251,7 +235,6 @@ def sum(x, axis=None, dtype=None, keepdims=False):
251235
tri._sum_over_axis,
252236
tri._sum_over_axis_dtype_supported,
253237
_default_reduction_dtype,
254-
_identity=0,
255238
)
256239

257240

@@ -312,7 +295,6 @@ def prod(x, axis=None, dtype=None, keepdims=False):
312295
tri._prod_over_axis,
313296
tri._prod_over_axis_dtype_supported,
314297
_default_reduction_dtype,
315-
_identity=1,
316298
)
317299

318300

@@ -368,7 +350,6 @@ def logsumexp(x, axis=None, dtype=None, keepdims=False):
368350
inp_dt, res_dt
369351
),
370352
_default_reduction_dtype_fp_types,
371-
_identity=-dpt.inf,
372353
)
373354

374355

@@ -424,7 +405,6 @@ def reduce_hypot(x, axis=None, dtype=None, keepdims=False):
424405
inp_dt, res_dt
425406
),
426407
_default_reduction_dtype_fp_types,
427-
_identity=0,
428408
)
429409

430410

@@ -446,9 +426,19 @@ def _comparison_over_axis(x, axis, keepdims, _reduction_fn):
446426
res_dt = x.dtype
447427
res_usm_type = x.usm_type
448428
if x.size == 0:
449-
raise ValueError("reduction does not support zero-size arrays")
429+
if any([x.shape[i] == 0 for i in axis]):
430+
raise ValueError(
431+
"reduction cannot be performed over zero-size axes"
432+
)
433+
else:
434+
return dpt.empty(
435+
res_shape,
436+
dtype=res_dt,
437+
usm_type=res_usm_type,
438+
sycl_queue=exec_q,
439+
)
450440
if red_nd == 0:
451-
return x
441+
return dpt.copy(x)
452442

453443
res = dpt.empty(
454444
res_shape,
@@ -549,7 +539,17 @@ def _search_over_axis(x, axis, keepdims, _reduction_fn):
549539
res_dt = ti.default_device_index_type(exec_q.sycl_device)
550540
res_usm_type = x.usm_type
551541
if x.size == 0:
552-
raise ValueError("reduction does not support zero-size arrays")
542+
if any([x.shape[i] == 0 for i in axis]):
543+
raise ValueError(
544+
"reduction cannot be performed over zero-size axes"
545+
)
546+
else:
547+
return dpt.empty(
548+
res_shape,
549+
dtype=res_dt,
550+
usm_type=res_usm_type,
551+
sycl_queue=exec_q,
552+
)
553553
if red_nd == 0:
554554
return dpt.zeros(
555555
res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=exec_q

dpctl/tensor/libtensor/include/kernels/reductions.hpp

Lines changed: 118 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1009,6 +1009,9 @@ template <typename T1,
10091009
typename T6>
10101010
class custom_reduction_over_group_temps_strided_krn;
10111011

1012+
template <typename T1, typename T2, typename T3>
1013+
class reduction_over_group_temps_empty_krn;
1014+
10121015
template <typename T1, typename T2, typename T3, typename T4, typename T5>
10131016
class single_reduction_axis0_temps_contig_krn;
10141017

@@ -1120,6 +1123,31 @@ sycl::event reduction_over_group_temps_strided_impl(
11201123

11211124
constexpr resTy identity_val = su_ns::Identity<ReductionOpT, resTy>::value;
11221125

1126+
if (reduction_nelems == 0) {
1127+
sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) {
1128+
using IndexerT =
1129+
dpctl::tensor::offset_utils::UnpackedStridedIndexer;
1130+
1131+
const py::ssize_t *const &res_shape = iter_shape_and_strides;
1132+
const py::ssize_t *const &res_strides =
1133+
iter_shape_and_strides + 2 * iter_nd;
1134+
IndexerT res_indexer(iter_nd, iter_res_offset, res_shape,
1135+
res_strides);
1136+
using InitKernelName =
1137+
class reduction_over_group_temps_empty_krn<resTy, argTy,
1138+
ReductionOpT>;
1139+
cgh.depends_on(depends);
1140+
1141+
cgh.parallel_for<InitKernelName>(
1142+
sycl::range<1>(iter_nelems), [=](sycl::id<1> id) {
1143+
auto res_offset = res_indexer(id[0]);
1144+
res_tp[res_offset] = identity_val;
1145+
});
1146+
});
1147+
1148+
return res_init_ev;
1149+
}
1150+
11231151
const sycl::device &d = exec_q.get_device();
11241152
const auto &sg_sizes = d.get_info<sycl::info::device::sub_group_sizes>();
11251153
size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes);
@@ -1244,7 +1272,7 @@ sycl::event reduction_over_group_temps_strided_impl(
12441272
resTy *partially_reduced_tmp2 = nullptr;
12451273

12461274
if (partially_reduced_tmp == nullptr) {
1247-
throw std::runtime_error("Unabled to allocate device_memory");
1275+
throw std::runtime_error("Unable to allocate device_memory");
12481276
}
12491277
else {
12501278
partially_reduced_tmp2 =
@@ -1501,6 +1529,13 @@ sycl::event reduction_axis1_over_group_temps_contig_impl(
15011529

15021530
constexpr resTy identity_val = su_ns::Identity<ReductionOpT, resTy>::value;
15031531

1532+
if (reduction_nelems == 0) {
1533+
sycl::event res_init_ev = exec_q.fill<resTy>(
1534+
res_tp, resTy(identity_val), iter_nelems, depends);
1535+
1536+
return res_init_ev;
1537+
}
1538+
15041539
const sycl::device &d = exec_q.get_device();
15051540
const auto &sg_sizes = d.get_info<sycl::info::device::sub_group_sizes>();
15061541
size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes);
@@ -1632,7 +1667,7 @@ sycl::event reduction_axis1_over_group_temps_contig_impl(
16321667
resTy *partially_reduced_tmp2 = nullptr;
16331668

16341669
if (partially_reduced_tmp == nullptr) {
1635-
throw std::runtime_error("Unabled to allocate device_memory");
1670+
throw std::runtime_error("Unable to allocate device_memory");
16361671
}
16371672
else {
16381673
partially_reduced_tmp2 =
@@ -1879,6 +1914,13 @@ sycl::event reduction_axis0_over_group_temps_contig_impl(
18791914

18801915
constexpr resTy identity_val = su_ns::Identity<ReductionOpT, resTy>::value;
18811916

1917+
if (reduction_nelems == 0) {
1918+
sycl::event res_init_ev = exec_q.fill<resTy>(
1919+
res_tp, resTy(identity_val), iter_nelems, depends);
1920+
1921+
return res_init_ev;
1922+
}
1923+
18821924
const sycl::device &d = exec_q.get_device();
18831925
const auto &sg_sizes = d.get_info<sycl::info::device::sub_group_sizes>();
18841926
size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes);
@@ -2015,7 +2057,7 @@ sycl::event reduction_axis0_over_group_temps_contig_impl(
20152057
resTy *partially_reduced_tmp2 = nullptr;
20162058

20172059
if (partially_reduced_tmp == nullptr) {
2018-
throw std::runtime_error("Unabled to allocate device_memory");
2060+
throw std::runtime_error("Unable to allocate device_memory");
20192061
}
20202062
else {
20212063
partially_reduced_tmp2 =
@@ -2712,12 +2754,16 @@ struct TypePairSupportDataForSumReductionTemps
27122754
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::uint32_t>,
27132755
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::int64_t>,
27142756
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::uint64_t>,
2757+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, float>,
2758+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, double>,
27152759

27162760
// input int8_t
27172761
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int8_t>,
27182762
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int16_t>,
27192763
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int32_t>,
27202764
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int64_t>,
2765+
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, float>,
2766+
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, double>,
27212767

27222768
// input uint8_t
27232769
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::uint8_t>,
@@ -2727,32 +2773,44 @@ struct TypePairSupportDataForSumReductionTemps
27272773
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::uint32_t>,
27282774
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::int64_t>,
27292775
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::uint64_t>,
2776+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, float>,
2777+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, double>,
27302778

27312779
// input int16_t
27322780
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, std::int16_t>,
27332781
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, std::int32_t>,
27342782
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, std::int64_t>,
2783+
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, float>,
2784+
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, double>,
27352785

27362786
// input uint16_t
27372787
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::uint16_t>,
27382788
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::int32_t>,
27392789
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::uint32_t>,
27402790
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::int64_t>,
27412791
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::uint64_t>,
2792+
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, float>,
2793+
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, double>,
27422794

27432795
// input int32_t
27442796
td_ns::TypePairDefinedEntry<argTy, std::int32_t, outTy, std::int32_t>,
27452797
td_ns::TypePairDefinedEntry<argTy, std::int32_t, outTy, std::int64_t>,
2798+
td_ns::TypePairDefinedEntry<argTy, std::int32_t, outTy, float>,
2799+
td_ns::TypePairDefinedEntry<argTy, std::int32_t, outTy, double>,
27462800

27472801
// input uint32_t
27482802
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, std::uint32_t>,
27492803
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, std::uint64_t>,
2804+
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, float>,
2805+
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, double>,
27502806

27512807
// input int64_t
27522808
td_ns::TypePairDefinedEntry<argTy, std::int64_t, outTy, std::int64_t>,
2809+
td_ns::TypePairDefinedEntry<argTy, std::int64_t, outTy, double>,
27532810

2754-
// input uint32_t
2811+
// input uint64_t
27552812
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, outTy, std::uint64_t>,
2813+
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, outTy, double>,
27562814

27572815
// input half
27582816
td_ns::TypePairDefinedEntry<argTy, sycl::half, outTy, sycl::half>,
@@ -2967,12 +3025,16 @@ struct TypePairSupportDataForProductReductionTemps
29673025
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::uint32_t>,
29683026
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::int64_t>,
29693027
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::uint64_t>,
3028+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, float>,
3029+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, double>,
29703030

29713031
// input int8_t
29723032
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int8_t>,
29733033
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int16_t>,
29743034
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int32_t>,
29753035
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int64_t>,
3036+
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, float>,
3037+
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, double>,
29763038

29773039
// input uint8_t
29783040
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::uint8_t>,
@@ -2982,32 +3044,44 @@ struct TypePairSupportDataForProductReductionTemps
29823044
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::uint32_t>,
29833045
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::int64_t>,
29843046
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::uint64_t>,
3047+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, float>,
3048+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, double>,
29853049

29863050
// input int16_t
29873051
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, std::int16_t>,
29883052
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, std::int32_t>,
29893053
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, std::int64_t>,
3054+
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, float>,
3055+
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, double>,
29903056

29913057
// input uint16_t
29923058
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::uint16_t>,
29933059
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::int32_t>,
29943060
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::uint32_t>,
29953061
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::int64_t>,
29963062
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::uint64_t>,
3063+
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, float>,
3064+
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, double>,
29973065

29983066
// input int32_t
29993067
td_ns::TypePairDefinedEntry<argTy, std::int32_t, outTy, std::int32_t>,
30003068
td_ns::TypePairDefinedEntry<argTy, std::int32_t, outTy, std::int64_t>,
3069+
td_ns::TypePairDefinedEntry<argTy, std::int32_t, outTy, float>,
3070+
td_ns::TypePairDefinedEntry<argTy, std::int32_t, outTy, double>,
30013071

30023072
// input uint32_t
30033073
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, std::uint32_t>,
30043074
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, std::uint64_t>,
3075+
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, float>,
3076+
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, double>,
30053077

30063078
// input int64_t
30073079
td_ns::TypePairDefinedEntry<argTy, std::int64_t, outTy, std::int64_t>,
3080+
td_ns::TypePairDefinedEntry<argTy, std::int64_t, outTy, double>,
30083081

30093082
// input uint64_t
30103083
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, outTy, std::uint64_t>,
3084+
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, outTy, double>,
30113085

30123086
// input half
30133087
td_ns::TypePairDefinedEntry<argTy, sycl::half, outTy, sycl::half>,
@@ -3957,6 +4031,8 @@ template <typename T1,
39574031
bool b2>
39584032
class custom_search_over_group_temps_strided_krn;
39594033

4034+
template <typename T1, typename T2, typename T3> class search_empty_krn;
4035+
39604036
template <typename T1,
39614037
typename T2,
39624038
typename T3,
@@ -4160,6 +4236,30 @@ sycl::event search_over_group_temps_strided_impl(
41604236
constexpr argTy identity_val = su_ns::Identity<ReductionOpT, argTy>::value;
41614237
constexpr resTy idx_identity_val = su_ns::Identity<IndexOpT, resTy>::value;
41624238

4239+
if (reduction_nelems == 0) {
4240+
sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) {
4241+
using IndexerT =
4242+
dpctl::tensor::offset_utils::UnpackedStridedIndexer;
4243+
4244+
const py::ssize_t *const &res_shape = iter_shape_and_strides;
4245+
const py::ssize_t *const &res_strides =
4246+
iter_shape_and_strides + 2 * iter_nd;
4247+
IndexerT res_indexer(iter_nd, iter_res_offset, res_shape,
4248+
res_strides);
4249+
using InitKernelName =
4250+
class search_empty_krn<resTy, argTy, ReductionOpT>;
4251+
cgh.depends_on(depends);
4252+
4253+
cgh.parallel_for<InitKernelName>(
4254+
sycl::range<1>(iter_nelems), [=](sycl::id<1> id) {
4255+
auto res_offset = res_indexer(id[0]);
4256+
res_tp[res_offset] = idx_identity_val;
4257+
});
4258+
});
4259+
4260+
return res_init_ev;
4261+
}
4262+
41634263
const sycl::device &d = exec_q.get_device();
41644264
const auto &sg_sizes = d.get_info<sycl::info::device::sub_group_sizes>();
41654265
size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes);
@@ -4590,6 +4690,13 @@ sycl::event search_axis1_over_group_temps_contig_impl(
45904690
constexpr argTy identity_val = su_ns::Identity<ReductionOpT, argTy>::value;
45914691
constexpr resTy idx_identity_val = su_ns::Identity<IndexOpT, resTy>::value;
45924692

4693+
if (reduction_nelems == 0) {
4694+
sycl::event res_init_ev = exec_q.fill<resTy>(
4695+
res_tp, resTy(idx_identity_val), iter_nelems, depends);
4696+
4697+
return res_init_ev;
4698+
}
4699+
45934700
const sycl::device &d = exec_q.get_device();
45944701
const auto &sg_sizes = d.get_info<sycl::info::device::sub_group_sizes>();
45954702
size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes);
@@ -5005,6 +5112,13 @@ sycl::event search_axis0_over_group_temps_contig_impl(
50055112
constexpr argTy identity_val = su_ns::Identity<ReductionOpT, argTy>::value;
50065113
constexpr resTy idx_identity_val = su_ns::Identity<IndexOpT, resTy>::value;
50075114

5115+
if (reduction_nelems == 0) {
5116+
sycl::event res_init_ev = exec_q.fill<resTy>(
5117+
res_tp, resTy(idx_identity_val), iter_nelems, depends);
5118+
5119+
return res_init_ev;
5120+
}
5121+
50085122
const sycl::device &d = exec_q.get_device();
50095123
const auto &sg_sizes = d.get_info<sycl::info::device::sub_group_sizes>();
50105124
size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes);

0 commit comments

Comments
 (0)