Skip to content

Commit 7e0a38c

Browse files
Add tests for .mT arrribute
1 parent 0c5db41 commit 7e0a38c

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

tests/test_ndarray.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from numpy.testing import assert_allclose, assert_array_equal
55

66
import dpnp
7+
from tests.third_party.cupy import testing
78

89
from .helper import (
910
get_all_dtypes,
@@ -258,6 +259,38 @@ def test_array_as_index(shape, index_dtype):
258259
assert a[tuple(ind_arr)] == a[1]
259260

260261

262+
# numpy.matrix_transpose() is available since numpy >= 2.0
263+
@testing.with_requires("numpy>=2.0")
264+
@pytest.mark.parametrize(
265+
"shape",
266+
[(3, 5), (2, 5, 2), (2, 3, 3, 6)],
267+
ids=["(3,5)", "(2,5,2)", "(2,3,3,6)"],
268+
)
269+
def test_matrix_transpose(shape):
270+
a = numpy.arange(numpy.prod(shape)).reshape(shape)
271+
dp_a = dpnp.array(a)
272+
273+
expected = a.mT
274+
result = dp_a.mT
275+
276+
assert_allclose(result, expected)
277+
278+
# result is a view of dp_a:
279+
# changing result, modifies dp_a
280+
first_elem = (0,) * dp_a.ndim
281+
282+
result[first_elem] = -1.0
283+
assert dp_a[first_elem] == -1.0
284+
285+
286+
@testing.with_requires("numpy>=2.0")
287+
def test_matrix_transpose_error():
288+
# 1D array
289+
dp_a = dpnp.arange(6)
290+
with pytest.raises(ValueError):
291+
dp_a.mT
292+
293+
261294
def test_ravel():
262295
a = dpnp.ones((2, 2))
263296
b = a.ravel()

0 commit comments

Comments
 (0)