Skip to content

Commit c643acb

Browse files
committed
Rearranges _comparison_over_axis and _search_over_axis to remove fast paths for zero-size arrays
This prevents possible edge cases where an array with a non-zero number of elements could be allocated prior to an error being thrown for a reduction being performed over a size-zero axis
1 parent dd78f36 commit c643acb

File tree

1 file changed

+4
-16
lines changed

1 file changed

+4
-16
lines changed

dpctl/tensor/_reduction.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,8 @@ def _comparison_over_axis(x, axis, keepdims, out, _reduction_fn):
443443
axis = tuple(range(nd))
444444
if not isinstance(axis, (tuple, list)):
445445
axis = (axis,)
446+
if any([x.shape[i] == 0 for i in axis]):
447+
raise ValueError("reduction cannot be performed over zero-size axes")
446448
axis = normalize_axis_tuple(axis, nd, "axis")
447449
red_nd = len(axis)
448450
perm = [i for i in range(nd) if i not in axis] + list(axis)
@@ -490,14 +492,6 @@ def _comparison_over_axis(x, axis, keepdims, out, _reduction_fn):
490492
res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=exec_q
491493
)
492494

493-
if x.size == 0:
494-
if any([x.shape[i] == 0 for i in axis]):
495-
raise ValueError(
496-
"reduction cannot be performed over zero-size axes"
497-
)
498-
else:
499-
return out
500-
501495
host_tasks_list = []
502496
if red_nd == 0:
503497
ht_e_cpy, cpy_e = ti._copy_usm_ndarray_into_usm_ndarray(
@@ -615,6 +609,8 @@ def _search_over_axis(x, axis, keepdims, out, _reduction_fn):
615609
f"`axis` argument expected `int` or `None`, got {type(axis)}"
616610
)
617611
axis = normalize_axis_tuple(axis, nd, "axis")
612+
if any([x.shape[i] == 0 for i in axis]):
613+
raise ValueError("reduction cannot be performed over zero-size axes")
618614
red_nd = len(axis)
619615
perm = [i for i in range(nd) if i not in axis] + list(axis)
620616
x_tmp = dpt.permute_dims(x, perm)
@@ -661,14 +657,6 @@ def _search_over_axis(x, axis, keepdims, out, _reduction_fn):
661657
res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=exec_q
662658
)
663659

664-
if x.size == 0:
665-
if any([x.shape[i] == 0 for i in axis]):
666-
raise ValueError(
667-
"reduction cannot be performed over zero-size axes"
668-
)
669-
else:
670-
return out
671-
672660
if red_nd == 0:
673661
ht_e_fill, _ = ti._full_usm_ndarray(
674662
fill_value=0, dst=out, sycl_queue=exec_q

0 commit comments

Comments
 (0)