@@ -3900,25 +3900,40 @@ def transpose(a, axes=None):
3900
3900
permute_dims = transpose # permute_dims is an alias for transpose
3901
3901
3902
3902
3903
- def trim_zeros (filt , trim = "fb" ):
3903
+ def trim_zeros (filt , trim = "fb" , axis = None ):
3904
3904
"""
3905
- Trim the leading and/or trailing zeros from a 1-D array .
3905
+ Remove values along a dimension which are zero along all other .
3906
3906
3907
3907
For full documentation refer to :obj:`numpy.trim_zeros`.
3908
3908
3909
3909
Parameters
3910
3910
----------
3911
3911
filt : {dpnp.ndarray, usm_ndarray}
3912
- Input 1-D array.
3913
- trim : str, optional
3914
- A string with 'f' representing trim from front and 'b' to trim from
3915
- back. By defaults, trim zeros from both front and back of the array.
3912
+ Input array.
3913
+ trim : {"fb", "f", "b"}, optional
3914
+ A string with `"f"` representing trim from front and `"b"` to trim from
3915
+ back. By default, zeros are trimmed on both sides. Front and back refer
3916
+ to the edges of a dimension, with "front" referring to the side with
3917
+ the lowest index 0, and "back" referring to the highest index
3918
+ (or index -1).
3916
3919
Default: ``"fb"``.
3920
+ axis : {None, int}, optional
3921
+ If ``None``, `filt` is cropped such, that the smallest bounding box is
3922
+ returned that still contains all values which are not zero.
3923
+ If an `axis` is specified, `filt` will be sliced in that dimension only
3924
+ on the sides specified by `trim`. The remaining area will be the
3925
+ smallest that still contains all values which are not zero.
3926
+ Default: ``None``.
3917
3927
3918
3928
Returns
3919
3929
-------
3920
3930
out : dpnp.ndarray
3921
- The result of trimming the input.
3931
+ The result of trimming the input. The number of dimensions and the
3932
+ input data type are preserved.
3933
+
3934
+ Notes
3935
+ -----
3936
+ For all-zero arrays, the first axis is trimmed first.
3922
3937
3923
3938
Examples
3924
3939
--------
@@ -3927,42 +3942,66 @@ def trim_zeros(filt, trim="fb"):
3927
3942
>>> np.trim_zeros(a)
3928
3943
array([1, 2, 3, 0, 2, 1])
3929
3944
3930
- >>> np.trim_zeros(a, 'b')
3945
+ >>> np.trim_zeros(a, trim= 'b')
3931
3946
array([0, 0, 0, 1, 2, 3, 0, 2, 1])
3932
3947
3948
+ Multiple dimensions are supported:
3949
+
3950
+ >>> b = np.array([[0, 0, 2, 3, 0, 0],
3951
+ ... [0, 1, 0, 3, 0, 0],
3952
+ ... [0, 0, 0, 0, 0, 0]])
3953
+ >>> np.trim_zeros(b)
3954
+ array([[0, 2, 3],
3955
+ [1, 0, 3]])
3956
+
3957
+ >>> np.trim_zeros(b, axis=-1)
3958
+ array([[0, 2, 3],
3959
+ [1, 0, 3],
3960
+ [0, 0, 0]])
3961
+
3933
3962
"""
3934
3963
3935
3964
dpnp .check_supported_arrays_type (filt )
3936
- if filt .ndim == 0 :
3937
- raise TypeError ("0-d array cannot be trimmed" )
3938
- if filt .ndim > 1 :
3939
- raise ValueError ("Multi-dimensional trim is not supported" )
3940
3965
3941
3966
if not isinstance (trim , str ):
3942
3967
raise TypeError ("only string trim is supported" )
3943
3968
3944
- trim = trim .upper ()
3945
- if not any (x in trim for x in "FB" ):
3946
- return filt # no trim rule is specified
3969
+ trim = trim .lower ()
3970
+ if trim not in ["fb" , "bf" , "f" , "b" ]:
3971
+ raise ValueError (f"unexpected character(s) in `trim`: { trim !r} " )
3972
+
3973
+ nd = filt .ndim
3974
+ if axis is not None :
3975
+ axis = normalize_axis_index (axis , nd )
3947
3976
3948
3977
if filt .size == 0 :
3949
3978
return filt # no trailing zeros in empty array
3950
3979
3951
- a = dpnp .nonzero (filt )[0 ]
3952
- a_size = a .size
3953
- if a_size == 0 :
3954
- # 'filt' is array of zeros
3955
- return dpnp .empty_like (filt , shape = (0 ,))
3980
+ non_zero = dpnp .argwhere (filt )
3981
+ if non_zero .size == 0 :
3982
+ # `filt` has all zeros, so assign `start` and `stop` to the same value,
3983
+ # then the resulting slice will be empty
3984
+ start = stop = dpnp .zeros_like (filt , shape = nd , dtype = dpnp .intp )
3985
+ else :
3986
+ if "f" in trim :
3987
+ start = non_zero .min (axis = 0 )
3988
+ else :
3989
+ start = (None ,) * nd
3956
3990
3957
- first = 0
3958
- if "F" in trim :
3959
- first = a [0 ]
3991
+ if "b" in trim :
3992
+ stop = non_zero .max (axis = 0 )
3993
+ stop += 1 # Adjust for slicing
3994
+ else :
3995
+ stop = (None ,) * nd
3960
3996
3961
- last = filt .size
3962
- if "B" in trim :
3963
- last = a [- 1 ] + 1
3997
+ if axis is None :
3998
+ # trim all axes
3999
+ sl = tuple (slice (* x ) for x in zip (start , stop ))
4000
+ else :
4001
+ # only trim single axis
4002
+ sl = (slice (None ),) * axis + (slice (start [axis ], stop [axis ]),) + (...,)
3964
4003
3965
- return filt [first : last ]
4004
+ return filt [sl ]
3966
4005
3967
4006
3968
4007
def unique (
0 commit comments