Skip to content

Commit 94bea60

Browse files
authored
Add nd-support to dpnp.trim_zeros (#2241)
The PR extends `dpnp.trim_zeros` implement to align with changes introduced in NumPy 2.2. It adds support for trimming nd-arrays while preserving the old behavior for 1-d input. The new parameter `axis` can specify a single dimension to be trimmed (reducing all other dimensions to the envelope of absolute values). By default (when `None` is specified), all dimensions are trimmed iteratively.
1 parent fdde3d0 commit 94bea60

File tree

3 files changed

+100
-42
lines changed

3 files changed

+100
-42
lines changed

dpnp/dpnp_iface_manipulation.py

Lines changed: 66 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3900,25 +3900,40 @@ def transpose(a, axes=None):
39003900
permute_dims = transpose # permute_dims is an alias for transpose
39013901

39023902

3903-
def trim_zeros(filt, trim="fb"):
3903+
def trim_zeros(filt, trim="fb", axis=None):
39043904
"""
3905-
Trim the leading and/or trailing zeros from a 1-D array.
3905+
Remove values along a dimension which are zero along all other.
39063906
39073907
For full documentation refer to :obj:`numpy.trim_zeros`.
39083908
39093909
Parameters
39103910
----------
39113911
filt : {dpnp.ndarray, usm_ndarray}
3912-
Input 1-D array.
3913-
trim : str, optional
3914-
A string with 'f' representing trim from front and 'b' to trim from
3915-
back. By defaults, trim zeros from both front and back of the array.
3912+
Input array.
3913+
trim : {"fb", "f", "b"}, optional
3914+
A string with `"f"` representing trim from front and `"b"` to trim from
3915+
back. By default, zeros are trimmed on both sides. Front and back refer
3916+
to the edges of a dimension, with "front" referring to the side with
3917+
the lowest index 0, and "back" referring to the highest index
3918+
(or index -1).
39163919
Default: ``"fb"``.
3920+
axis : {None, int}, optional
3921+
If ``None``, `filt` is cropped such, that the smallest bounding box is
3922+
returned that still contains all values which are not zero.
3923+
If an `axis` is specified, `filt` will be sliced in that dimension only
3924+
on the sides specified by `trim`. The remaining area will be the
3925+
smallest that still contains all values which are not zero.
3926+
Default: ``None``.
39173927
39183928
Returns
39193929
-------
39203930
out : dpnp.ndarray
3921-
The result of trimming the input.
3931+
The result of trimming the input. The number of dimensions and the
3932+
input data type are preserved.
3933+
3934+
Notes
3935+
-----
3936+
For all-zero arrays, the first axis is trimmed first.
39223937
39233938
Examples
39243939
--------
@@ -3927,42 +3942,66 @@ def trim_zeros(filt, trim="fb"):
39273942
>>> np.trim_zeros(a)
39283943
array([1, 2, 3, 0, 2, 1])
39293944
3930-
>>> np.trim_zeros(a, 'b')
3945+
>>> np.trim_zeros(a, trim='b')
39313946
array([0, 0, 0, 1, 2, 3, 0, 2, 1])
39323947
3948+
Multiple dimensions are supported:
3949+
3950+
>>> b = np.array([[0, 0, 2, 3, 0, 0],
3951+
... [0, 1, 0, 3, 0, 0],
3952+
... [0, 0, 0, 0, 0, 0]])
3953+
>>> np.trim_zeros(b)
3954+
array([[0, 2, 3],
3955+
[1, 0, 3]])
3956+
3957+
>>> np.trim_zeros(b, axis=-1)
3958+
array([[0, 2, 3],
3959+
[1, 0, 3],
3960+
[0, 0, 0]])
3961+
39333962
"""
39343963

39353964
dpnp.check_supported_arrays_type(filt)
3936-
if filt.ndim == 0:
3937-
raise TypeError("0-d array cannot be trimmed")
3938-
if filt.ndim > 1:
3939-
raise ValueError("Multi-dimensional trim is not supported")
39403965

39413966
if not isinstance(trim, str):
39423967
raise TypeError("only string trim is supported")
39433968

3944-
trim = trim.upper()
3945-
if not any(x in trim for x in "FB"):
3946-
return filt # no trim rule is specified
3969+
trim = trim.lower()
3970+
if trim not in ["fb", "bf", "f", "b"]:
3971+
raise ValueError(f"unexpected character(s) in `trim`: {trim!r}")
3972+
3973+
nd = filt.ndim
3974+
if axis is not None:
3975+
axis = normalize_axis_index(axis, nd)
39473976

39483977
if filt.size == 0:
39493978
return filt # no trailing zeros in empty array
39503979

3951-
a = dpnp.nonzero(filt)[0]
3952-
a_size = a.size
3953-
if a_size == 0:
3954-
# 'filt' is array of zeros
3955-
return dpnp.empty_like(filt, shape=(0,))
3980+
non_zero = dpnp.argwhere(filt)
3981+
if non_zero.size == 0:
3982+
# `filt` has all zeros, so assign `start` and `stop` to the same value,
3983+
# then the resulting slice will be empty
3984+
start = stop = dpnp.zeros_like(filt, shape=nd, dtype=dpnp.intp)
3985+
else:
3986+
if "f" in trim:
3987+
start = non_zero.min(axis=0)
3988+
else:
3989+
start = (None,) * nd
39563990

3957-
first = 0
3958-
if "F" in trim:
3959-
first = a[0]
3991+
if "b" in trim:
3992+
stop = non_zero.max(axis=0)
3993+
stop += 1 # Adjust for slicing
3994+
else:
3995+
stop = (None,) * nd
39603996

3961-
last = filt.size
3962-
if "B" in trim:
3963-
last = a[-1] + 1
3997+
if axis is None:
3998+
# trim all axes
3999+
sl = tuple(slice(*x) for x in zip(start, stop))
4000+
else:
4001+
# only trim single axis
4002+
sl = (slice(None),) * axis + (slice(start[axis], stop[axis]),) + (...,)
39644003

3965-
return filt[first:last]
4004+
return filt[sl]
39664005

39674006

39684007
def unique(

dpnp/tests/test_manipulation.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1378,6 +1378,20 @@ def test_basic(self, dtype):
13781378
expected = numpy.trim_zeros(a)
13791379
assert_array_equal(result, expected)
13801380

1381+
@testing.with_requires("numpy>=2.2")
1382+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
1383+
@pytest.mark.parametrize("trim", ["F", "B", "fb"])
1384+
@pytest.mark.parametrize("ndim", [0, 1, 2, 3])
1385+
def test_basic_nd(self, dtype, trim, ndim):
1386+
a = numpy.ones((2,) * ndim, dtype=dtype)
1387+
a = numpy.pad(a, (2, 1), mode="constant", constant_values=0)
1388+
ia = dpnp.array(a)
1389+
1390+
for axis in list(range(ndim)) + [None]:
1391+
result = dpnp.trim_zeros(ia, trim=trim, axis=axis)
1392+
expected = numpy.trim_zeros(a, trim=trim, axis=axis)
1393+
assert_array_equal(result, expected)
1394+
13811395
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
13821396
@pytest.mark.parametrize("trim", ["F", "B"])
13831397
def test_trim(self, dtype, trim):
@@ -1398,6 +1412,19 @@ def test_all_zero(self, dtype, trim):
13981412
expected = numpy.trim_zeros(a, trim)
13991413
assert_array_equal(result, expected)
14001414

1415+
@testing.with_requires("numpy>=2.2")
1416+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
1417+
@pytest.mark.parametrize("trim", ["F", "B", "fb"])
1418+
@pytest.mark.parametrize("ndim", [0, 1, 2, 3])
1419+
def test_all_zero_nd(self, dtype, trim, ndim):
1420+
a = numpy.zeros((3,) * ndim, dtype=dtype)
1421+
ia = dpnp.array(a)
1422+
1423+
for axis in list(range(ndim)) + [None]:
1424+
result = dpnp.trim_zeros(ia, trim=trim, axis=axis)
1425+
expected = numpy.trim_zeros(a, trim=trim, axis=axis)
1426+
assert_array_equal(result, expected)
1427+
14011428
def test_size_zero(self):
14021429
a = numpy.zeros(0)
14031430
ia = dpnp.array(a)
@@ -1416,17 +1443,11 @@ def test_overflow(self, a):
14161443
expected = numpy.trim_zeros(a)
14171444
assert_array_equal(result, expected)
14181445

1419-
# TODO: modify once SAT-7616
1420-
# numpy 2.2 validates trim rules
1421-
@testing.with_requires("numpy<2.2")
1422-
def test_trim_no_rule(self):
1423-
a = numpy.array([0, 0, 1, 0, 2, 3, 4, 0])
1424-
ia = dpnp.array(a)
1425-
trim = "ADE" # no "F" or "B" in trim string
1426-
1427-
result = dpnp.trim_zeros(ia, trim)
1428-
expected = numpy.trim_zeros(a, trim)
1429-
assert_array_equal(result, expected)
1446+
@testing.with_requires("numpy>=2.2")
1447+
@pytest.mark.parametrize("xp", [numpy, dpnp])
1448+
def test_trim_no_fb_in_rule(self, xp):
1449+
a = xp.array([0, 0, 1, 0, 2, 3, 4, 0])
1450+
assert_raises(ValueError, xp.trim_zeros, a, "ADE")
14301451

14311452
def test_list_array(self):
14321453
assert_raises(TypeError, dpnp.trim_zeros, [0, 0, 1, 0, 2, 3, 4, 0])

dpnp/tests/third_party/cupy/manipulation_tests/test_add_remove.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -387,17 +387,15 @@ def test_trim_back_zeros(self, xp, dtype):
387387
a = xp.array([1, 0, 2, 3, 0, 5, 0, 0, 0], dtype=dtype)
388388
return xp.trim_zeros(a, trim=self.trim)
389389

390-
# TODO: remove once SAT-7616
391-
@testing.with_requires("numpy<2.2")
390+
@pytest.mark.skip("0-d array is supported")
392391
@testing.for_all_dtypes()
393392
def test_trim_zero_dim(self, dtype):
394393
for xp in (numpy, cupy):
395394
a = testing.shaped_arange((), xp, dtype)
396395
with pytest.raises(TypeError):
397396
xp.trim_zeros(a, trim=self.trim)
398397

399-
# TODO: remove once SAT-7616
400-
@testing.with_requires("numpy<2.2")
398+
@pytest.mark.skip("nd array is supported")
401399
@testing.for_all_dtypes()
402400
def test_trim_ndim(self, dtype):
403401
for xp in (numpy, cupy):

0 commit comments

Comments
 (0)