Skip to content

Commit 0f403aa

Browse files
vlad-perevezentsevvtavana
authored andcommitted
Add implementation of dpnp.unstack() (#2106)
* Add implementation of dpnp.unstack() * Update manipulation.rst * Add TestUnstack to test_arraymanipulation.py
1 parent 5dd6a62 commit 0f403aa

File tree

2 files changed

+164
-0
lines changed

2 files changed

+164
-0
lines changed

dpnp/dpnp_iface_manipulation.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@
9797
"transpose",
9898
"trim_zeros",
9999
"unique",
100+
"unstack",
100101
"vsplit",
101102
"vstack",
102103
]
@@ -1723,6 +1724,8 @@ def hstack(tup, *, dtype=None, casting="same_kind"):
17231724
:obj:`dpnp.block` : Assemble an ndarray from nested lists of blocks.
17241725
:obj:`dpnp.split` : Split array into a list of multiple sub-arrays of equal
17251726
size.
1727+
:obj:`dpnp.unstack` : Split an array into a tuple of sub-arrays along
1728+
an axis.
17261729
17271730
Examples
17281731
--------
@@ -2913,6 +2916,8 @@ def stack(arrays, /, *, axis=0, out=None, dtype=None, casting="same_kind"):
29132916
:obj:`dpnp.block` : Assemble an ndarray from nested lists of blocks.
29142917
:obj:`dpnp.split` : Split array into a list of multiple sub-arrays of equal
29152918
size.
2919+
:obj:`dpnp.unstack` : Split an array into a tuple of sub-arrays along
2920+
an axis.
29162921
29172922
Examples
29182923
--------
@@ -3413,6 +3418,84 @@ def unique(
34133418
return _unpack_tuple(result)
34143419

34153420

3421+
def unstack(x, /, *, axis=0):
3422+
"""
3423+
Split an array into a sequence of arrays along the given axis.
3424+
3425+
The `axis` parameter specifies the dimension along which the array will
3426+
be split. For example, if ``axis=0`` (the default) it will be the first
3427+
dimension and if ``axis=-1`` it will be the last dimension.
3428+
3429+
The result is a tuple of arrays split along `axis`.
3430+
3431+
For full documentation refer to :obj:`numpy.unstack`.
3432+
3433+
Parameters
3434+
----------
3435+
x : {dpnp.ndarray, usm_ndarray}
3436+
The array to be unstacked.
3437+
axis : int, optional
3438+
Axis along which the array will be split.
3439+
Default: ``0``.
3440+
3441+
Returns
3442+
-------
3443+
unstacked : tuple of dpnp.ndarray
3444+
The unstacked arrays.
3445+
3446+
See Also
3447+
--------
3448+
:obj:`dpnp.stack` : Join a sequence of arrays along a new axis.
3449+
:obj:`dpnp.concatenate` : Join a sequence of arrays along an existing axis.
3450+
:obj:`dpnp.block` : Assemble an ndarray from nested lists of blocks.
3451+
:obj:`dpnp.split` : Split array into a list of multiple sub-arrays of equal
3452+
size.
3453+
3454+
Notes
3455+
-----
3456+
:obj:`dpnp.unstack` serves as the reverse operation of :obj:`dpnp.stack`,
3457+
i.e., ``dpnp.stack(dpnp.unstack(x, axis=axis), axis=axis) == x``.
3458+
3459+
This function is equivalent to ``tuple(dpnp.moveaxis(x, axis, 0))``, since
3460+
iterating on an array iterates along the first axis.
3461+
3462+
Examples
3463+
--------
3464+
>>> import dpnp as np
3465+
>>> arr = np.arange(24).reshape((2, 3, 4))
3466+
>>> np.unstack(arr)
3467+
(array([[ 0, 1, 2, 3],
3468+
[ 4, 5, 6, 7],
3469+
[ 8, 9, 10, 11]]),
3470+
array([[12, 13, 14, 15],
3471+
[16, 17, 18, 19],
3472+
[20, 21, 22, 23]]))
3473+
3474+
>>> np.unstack(arr, axis=1)
3475+
(array([[ 0, 1, 2, 3],
3476+
[12, 13, 14, 15]]),
3477+
array([[ 4, 5, 6, 7],
3478+
[16, 17, 18, 19]]),
3479+
array([[ 8, 9, 10, 11],
3480+
[20, 21, 22, 23]]))
3481+
3482+
>>> arr2 = np.stack(np.unstack(arr, axis=1), axis=1)
3483+
>>> arr2.shape
3484+
(2, 3, 4)
3485+
>>> np.all(arr == arr2)
3486+
array(True)
3487+
3488+
"""
3489+
3490+
usm_x = dpnp.get_usm_ndarray(x)
3491+
3492+
if usm_x.ndim == 0:
3493+
raise ValueError("Input array must be at least 1-d.")
3494+
3495+
res = dpt.unstack(usm_x, axis=axis)
3496+
return tuple(dpnp_array._create_from_usm_ndarray(a) for a in res)
3497+
3498+
34163499
def vsplit(ary, indices_or_sections):
34173500
"""
34183501
Split an array into multiple sub-arrays vertically (row-wise).
@@ -3521,6 +3604,8 @@ def vstack(tup, *, dtype=None, casting="same_kind"):
35213604
:obj:`dpnp.block` : Assemble an ndarray from nested lists of blocks.
35223605
:obj:`dpnp.split` : Split array into a list of multiple sub-arrays of equal
35233606
size.
3607+
:obj:`dpnp.unstack` : Split an array into a tuple of sub-arrays along
3608+
an axis.
35243609
35253610
Examples
35263611
--------

tests/test_arraymanipulation.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -866,6 +866,85 @@ def test_generator(self):
866866
dpnp.stack(map(lambda x: x, dpnp.ones((3, 2))))
867867

868868

869+
# numpy.unstack() is available since numpy >= 2.1
870+
@testing.with_requires("numpy>=2.1")
871+
class TestUnstack:
872+
def test_non_array_input(self):
873+
with pytest.raises(TypeError):
874+
dpnp.unstack(1)
875+
876+
@pytest.mark.parametrize(
877+
"input", [([1, 2, 3],), [dpnp.int32(1), dpnp.int32(2), dpnp.int32(3)]]
878+
)
879+
def test_scalar_input(self, input):
880+
with pytest.raises(TypeError):
881+
dpnp.unstack(input)
882+
883+
@pytest.mark.parametrize("dtype", get_all_dtypes())
884+
def test_0d_array_input(self, dtype):
885+
np_a = numpy.array(1, dtype=dtype)
886+
dp_a = dpnp.array(np_a, dtype=dtype)
887+
888+
with pytest.raises(ValueError):
889+
numpy.unstack(np_a)
890+
with pytest.raises(ValueError):
891+
dpnp.unstack(dp_a)
892+
893+
@pytest.mark.parametrize("dtype", get_all_dtypes())
894+
def test_1d_array(self, dtype):
895+
np_a = numpy.array([1, 2, 3], dtype=dtype)
896+
dp_a = dpnp.array(np_a, dtype=dtype)
897+
898+
np_res = numpy.unstack(np_a)
899+
dp_res = dpnp.unstack(dp_a)
900+
assert len(dp_res) == len(np_res)
901+
for dp_arr, np_arr in zip(dp_res, np_res):
902+
assert_array_equal(dp_arr.asnumpy(), np_arr)
903+
904+
@pytest.mark.parametrize("dtype", get_all_dtypes())
905+
def test_2d_array(self, dtype):
906+
np_a = numpy.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)
907+
dp_a = dpnp.array(np_a, dtype=dtype)
908+
909+
np_res = numpy.unstack(np_a, axis=0)
910+
dp_res = dpnp.unstack(dp_a, axis=0)
911+
assert len(dp_res) == len(np_res)
912+
for dp_arr, np_arr in zip(dp_res, np_res):
913+
assert_array_equal(dp_arr.asnumpy(), np_arr)
914+
915+
@pytest.mark.parametrize("axis", [0, 1, -1])
916+
@pytest.mark.parametrize("dtype", get_all_dtypes())
917+
def test_2d_array_axis(self, axis, dtype):
918+
np_a = numpy.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)
919+
dp_a = dpnp.array(np_a, dtype=dtype)
920+
921+
np_res = numpy.unstack(np_a, axis=axis)
922+
dp_res = dpnp.unstack(dp_a, axis=axis)
923+
assert len(dp_res) == len(np_res)
924+
for dp_arr, np_arr in zip(dp_res, np_res):
925+
assert_array_equal(dp_arr.asnumpy(), np_arr)
926+
927+
@pytest.mark.parametrize("axis", [2, -3])
928+
@pytest.mark.parametrize("dtype", get_all_dtypes())
929+
def test_invalid_axis(self, axis, dtype):
930+
np_a = numpy.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)
931+
dp_a = dpnp.array(np_a, dtype=dtype)
932+
933+
with pytest.raises(AxisError):
934+
numpy.unstack(np_a, axis=axis)
935+
with pytest.raises(AxisError):
936+
dpnp.unstack(dp_a, axis=axis)
937+
938+
@pytest.mark.parametrize("dtype", get_all_dtypes())
939+
def test_empty_array(self, dtype):
940+
np_a = numpy.array([], dtype=dtype)
941+
dp_a = dpnp.array(np_a, dtype=dtype)
942+
943+
np_res = numpy.unstack(np_a)
944+
dp_res = dpnp.unstack(dp_a)
945+
assert len(dp_res) == len(np_res)
946+
947+
869948
class TestVstack:
870949
def test_non_iterable(self):
871950
assert_raises(TypeError, dpnp.vstack, 1)

0 commit comments

Comments
 (0)