Skip to content

Commit 16b631f

Browse files
committed
Applies suggestions per PR review
Avoids a call to permute_dims in reductions when `axis=None` Make reduction tests more efficient by reusing an array of zeros rather than reconstructing
1 parent 8ab37a1 commit 16b631f

File tree

2 files changed

+37
-26
lines changed

2 files changed

+37
-26
lines changed

dpctl/tensor/_reduction.py

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,15 @@ def _reduction_over_axis(
4444
nd = x.ndim
4545
if axis is None:
4646
axis = tuple(range(nd))
47-
if not isinstance(axis, (tuple, list)):
48-
axis = (axis,)
49-
axis = normalize_axis_tuple(axis, nd, "axis")
47+
perm = list(axis)
48+
arr = x
49+
else:
50+
if not isinstance(axis, (tuple, list)):
51+
axis = (axis,)
52+
axis = normalize_axis_tuple(axis, nd, "axis")
53+
perm = [i for i in range(nd) if i not in axis] + list(axis)
54+
arr = dpt.permute_dims(x, perm)
5055
red_nd = len(axis)
51-
perm = [i for i in range(nd) if i not in axis] + list(axis)
52-
arr = dpt.permute_dims(x, perm)
5356
res_shape = arr.shape[: nd - red_nd]
5457
q = x.sycl_queue
5558
inp_dt = x.dtype
@@ -89,7 +92,7 @@ def _reduction_over_axis(
8992
)
9093
if res_dt != out.dtype:
9194
raise ValueError(
92-
f"Output array of type {res_dt} is needed, " f"got {out.dtype}"
95+
f"Output array of type {res_dt} is needed, got {out.dtype}"
9396
)
9497
if dpctl.utils.get_execution_queue((q, out.sycl_queue)) is None:
9598
raise ExecutionPlacementError(
@@ -441,14 +444,17 @@ def _comparison_over_axis(x, axis, keepdims, out, _reduction_fn):
441444
nd = x.ndim
442445
if axis is None:
443446
axis = tuple(range(nd))
444-
if not isinstance(axis, (tuple, list)):
445-
axis = (axis,)
446-
if any([x.shape[i] == 0 for i in axis]):
447-
raise ValueError("reduction cannot be performed over zero-size axes")
448-
axis = normalize_axis_tuple(axis, nd, "axis")
447+
perm = list(axis)
448+
x_tmp = x
449+
else:
450+
if not isinstance(axis, (tuple, list)):
451+
axis = (axis,)
452+
axis = normalize_axis_tuple(axis, nd, "axis")
453+
perm = [i for i in range(nd) if i not in axis] + list(axis)
454+
x_tmp = dpt.permute_dims(x, perm)
449455
red_nd = len(axis)
450-
perm = [i for i in range(nd) if i not in axis] + list(axis)
451-
x_tmp = dpt.permute_dims(x, perm)
456+
if any([x_tmp.shape[i] == 0 for i in range(-red_nd, 0)]):
457+
raise ValueError("reduction cannot be performed over zero-size axes")
452458
res_shape = x_tmp.shape[: nd - red_nd]
453459
exec_q = x.sycl_queue
454460
res_dt = x.dtype
@@ -476,7 +482,7 @@ def _comparison_over_axis(x, axis, keepdims, out, _reduction_fn):
476482
)
477483
if res_dt != out.dtype:
478484
raise ValueError(
479-
f"Output array of type {res_dt} is needed, " f"got {out.dtype}"
485+
f"Output array of type {res_dt} is needed, got {out.dtype}"
480486
)
481487
if dpctl.utils.get_execution_queue((exec_q, out.sycl_queue)) is None:
482488
raise ExecutionPlacementError(
@@ -602,18 +608,22 @@ def _search_over_axis(x, axis, keepdims, out, _reduction_fn):
602608
nd = x.ndim
603609
if axis is None:
604610
axis = tuple(range(nd))
605-
elif isinstance(axis, int):
606-
axis = (axis,)
611+
perm = list(axis)
612+
x_tmp = x
607613
else:
608-
raise TypeError(
609-
f"`axis` argument expected `int` or `None`, got {type(axis)}"
610-
)
614+
if isinstance(axis, int):
615+
axis = (axis,)
616+
else:
617+
raise TypeError(
618+
f"`axis` argument expected `int` or `None`, got {type(axis)}"
619+
)
620+
axis = normalize_axis_tuple(axis, nd, "axis")
621+
perm = [i for i in range(nd) if i not in axis] + list(axis)
622+
x_tmp = dpt.permute_dims(x, perm)
611623
axis = normalize_axis_tuple(axis, nd, "axis")
612-
if any([x.shape[i] == 0 for i in axis]):
613-
raise ValueError("reduction cannot be performed over zero-size axes")
614624
red_nd = len(axis)
615-
perm = [i for i in range(nd) if i not in axis] + list(axis)
616-
x_tmp = dpt.permute_dims(x, perm)
625+
if any([x_tmp.shape[i] == 0 for i in range(-red_nd, 0)]):
626+
raise ValueError("reduction cannot be performed over zero-size axes")
617627
res_shape = x_tmp.shape[: nd - red_nd]
618628
exec_q = x.sycl_queue
619629
res_dt = ti.default_device_index_type(exec_q.sycl_device)
@@ -641,7 +651,7 @@ def _search_over_axis(x, axis, keepdims, out, _reduction_fn):
641651
)
642652
if res_dt != out.dtype:
643653
raise ValueError(
644-
f"Output array of type {res_dt} is needed, " f"got {out.dtype}"
654+
f"Output array of type {res_dt} is needed, got {out.dtype}"
645655
)
646656
if dpctl.utils.get_execution_queue((exec_q, out.sycl_queue)) is None:
647657
raise ExecutionPlacementError(

dpctl/tests/test_usm_ndarray_reductions.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -537,8 +537,9 @@ def test_numeric_reduction_out_kwarg():
537537
x = dpt.ones((n1, n2, n3), dtype="i4")
538538
out = dpt.zeros((2 * n1, 3 * n2), dtype="f4")
539539
res = dpt.sum(x, axis=-1, dtype="f4", out=out[::-2, 1::3])
540-
assert dpt.allclose(out[::-2, 0::3], dpt.zeros_like(res))
541-
assert dpt.allclose(out[::-2, 2::3], dpt.zeros_like(res))
540+
zero_res = dpt.zeros_like(res)
541+
assert dpt.allclose(out[::-2, 0::3], zero_res)
542+
assert dpt.allclose(out[::-2, 2::3], zero_res)
542543
assert dpt.allclose(out[::-2, 1::3], res)
543544
assert dpt.allclose(out[::-2, 1::3], dpt.full_like(res, 5, dtype="f4"))
544545

0 commit comments

Comments
 (0)