Skip to content

Commit 5fb2311

Browse files
Add matrix_transpose() for dpnp.linalg module
1 parent 7e0a38c commit 5fb2311

File tree

1 file changed

+43
-0
lines changed

1 file changed

+43
-0
lines changed

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
"matmul",
7878
"matrix_power",
7979
"matrix_rank",
80+
"matrix_transpose",
8081
"multi_dot",
8182
"norm",
8283
"pinv",
@@ -871,6 +872,48 @@ def matrix_rank(A, tol=None, hermitian=False):
871872
return dpnp_matrix_rank(A, tol=tol, hermitian=hermitian)
872873

873874

875+
def matrix_transpose(x, /):
876+
"""
877+
Transposes a matrix (or a stack of matrices) ``x``.
878+
879+
For full documentation refer to :obj:`numpy.linalg.matrix_transpose`.
880+
881+
Parameters
882+
----------
883+
x : (..., M, N) {dpnp.ndarray, usm_ndarray}
884+
Input array with ``x.ndim >= 2` and whose two innermost
885+
dimensions form ``MxN`` matrices.
886+
887+
Returns
888+
-------
889+
out : dpnp.ndarray
890+
An array containing the transpose for each matrix and having shape
891+
(..., N, M).
892+
893+
See Also
894+
--------
895+
:obj:`dpnp.transpose` : Returns an array with axes transposed.
896+
897+
Examples
898+
--------
899+
>>> import dpnp as np
900+
>>> a = np.array([[1, 2], [3, 4]])
901+
>>> np.linalg.matrix_transpose(a)
902+
array([[1, 3],
903+
[2, 4]])
904+
905+
>>> b = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
906+
>>> np.linalg.matrix_transpose(b)
907+
array([[[1, 3],
908+
[2, 4]],
909+
[[5, 7],
910+
[6, 8]]])
911+
912+
"""
913+
914+
return dpnp.matrix_transpose(x)
915+
916+
874917
def multi_dot(arrays, *, out=None):
875918
"""
876919
Compute the dot product of two or more arrays in a single function call.

0 commit comments

Comments
 (0)