Skip to content

Commit b737d1c

Browse files
Add TestMatrixtranspose to dpnp tests
1 parent 4b640de commit b737d1c

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

tests/test_arraymanipulation.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,50 @@ def test_one_element(self):
556556
assert_array_equal(res, a)
557557

558558

559+
# numpy.matrix_transpose() is available since numpy >= 2.0
560+
@testing.with_requires("numpy>=2.0")
561+
class TestMatrixtranspose:
562+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
563+
@pytest.mark.parametrize(
564+
"shape",
565+
[(3, 5), (4, 2), (2, 5, 2), (2, 3, 3, 6)],
566+
ids=["(3,5)", "(4,2)", "(2,5,2)", "(2,3,3,6)"],
567+
)
568+
def test_matrix_transpose(self, dtype, shape):
569+
a = numpy.arange(numpy.prod(shape), dtype=dtype).reshape(shape)
570+
dp_a = dpnp.array(a)
571+
572+
expected = numpy.matrix_transpose(a)
573+
result = dpnp.matrix_transpose(dp_a)
574+
575+
assert_allclose(result, expected)
576+
577+
@pytest.mark.parametrize(
578+
"shape",
579+
[(0, 0), (1, 0, 0), (0, 2, 2), (0, 1, 0, 4)],
580+
ids=["(0,0)", "(1,0,0)", "(0,2,2)", "(0, 1, 0, 4)"],
581+
)
582+
def test_matrix_transpose_empty(self, shape):
583+
a = numpy.empty(shape, dtype=dpnp.default_float_type())
584+
dp_a = dpnp.array(a)
585+
586+
expected = numpy.matrix_transpose(a)
587+
result = dpnp.matrix_transpose(dp_a)
588+
589+
assert_allclose(result, expected)
590+
591+
def test_matrix_transpose_errors(self):
592+
a_dp = dpnp.array([[1, 2], [3, 4]], dtype="float32")
593+
594+
# unsupported type
595+
a_np = dpnp.asnumpy(a_dp)
596+
assert_raises(TypeError, dpnp.matrix_transpose, a_np)
597+
598+
# a.ndim < 2
599+
a_dp_ndim_1 = a_dp.flatten()
600+
assert_raises(ValueError, dpnp.matrix_transpose, a_dp_ndim_1)
601+
602+
559603
class TestRollaxis:
560604
data = [
561605
(0, 0),

0 commit comments

Comments
 (0)