Skip to content

Commit a009ebd

Browse files
add dpnp.linalg.tensorsolve impl
1 parent b51100f commit a009ebd

File tree

1 file changed

+71
-0
lines changed

1 file changed

+71
-0
lines changed

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
"solve",
7878
"svd",
7979
"slogdet",
80+
"tensorsolve",
8081
]
8182

8283

@@ -897,3 +898,73 @@ def slogdet(a):
897898
check_stacked_square(a)
898899

899900
return dpnp_slogdet(a)
901+
902+
903+
def tensorsolve(a, b, axes=None):
904+
"""
905+
Solve the tensor equation ``a x = b`` for x.
906+
907+
For full documentation refer to :obj:`numpy.linalg.tensorsolve`.
908+
909+
Parameters
910+
----------
911+
a : {dpnp.ndarray, usm_ndarray}
912+
Coefficient tensor, of shape ``b.shape + Q``. `Q`, a tuple, equals
913+
the shape of that sub-tensor of `a` consisting of the appropriate
914+
number of its rightmost indices, and must be such that
915+
``prod(Q) == prod(b.shape)`` (in which sense `a` is said to be
916+
'square').
917+
b : {dpnp.ndarray, usm_ndarray}
918+
Right-hand tensor, which can be of any shape.
919+
axes : tuple of ints, optional
920+
Axes in `a` to reorder to the right, before inversion.
921+
If None , no reordering is done.
922+
Default: ``None``.
923+
924+
Returns
925+
-------
926+
out : dpnp.ndarray
927+
The tensor with shape ``Q`` such that ``b.shape + Q == a.shape``.
928+
929+
See Also
930+
--------
931+
:obj:`dpnp.linalg.tensordot` : Compute tensor dot product along specified axes.
932+
:obj:`dpnp.linalg.tensorinv` : Compute the `inverse` of a tensor.
933+
934+
Examples
935+
--------
936+
>>> import dpnp as np
937+
>>> a = np.eye(2*3*4)
938+
>>> a.shape = (2*3, 4, 2, 3, 4)
939+
>>> b = np.random.randn(2*3, 4)
940+
>>> x = np.linalg.tensorsolve(a, b)
941+
>>> x.shape
942+
(2, 3, 4)
943+
>>> np.allclose(np.tensordot(a, x, axes=3), b)
944+
array([ True])
945+
946+
"""
947+
948+
dpnp.check_supported_arrays_type(a, b)
949+
a_ndim = a.ndim
950+
951+
if axes is not None:
952+
all_axes = list(range(a_ndim))
953+
for k in axes:
954+
all_axes.remove(k)
955+
all_axes.insert(a_ndim, k)
956+
a = a.transpose(tuple(all_axes))
957+
958+
old_shape = a.shape[-(a_ndim - b.ndim) :]
959+
prod = numpy.prod(old_shape)
960+
961+
if a.size != prod**2:
962+
raise dpnp.linalg.LinAlgError(
963+
"Input arrays must satisfy the requirement \
964+
prod(a.shape[b.ndim:]) == prod(a.shape[:b.ndim])"
965+
)
966+
967+
a = a.reshape(-1, prod)
968+
b = b.ravel()
969+
res = solve(a, b)
970+
return res.reshape(old_shape)

0 commit comments

Comments
 (0)