|
77 | 77 | "solve",
|
78 | 78 | "svd",
|
79 | 79 | "slogdet",
|
| 80 | + "tensorsolve", |
80 | 81 | ]
|
81 | 82 |
|
82 | 83 |
|
@@ -897,3 +898,73 @@ def slogdet(a):
|
897 | 898 | check_stacked_square(a)
|
898 | 899 |
|
899 | 900 | return dpnp_slogdet(a)
|
| 901 | + |
| 902 | + |
| 903 | +def tensorsolve(a, b, axes=None): |
| 904 | + """ |
| 905 | + Solve the tensor equation ``a x = b`` for x. |
| 906 | +
|
| 907 | + For full documentation refer to :obj:`numpy.linalg.tensorsolve`. |
| 908 | +
|
| 909 | + Parameters |
| 910 | + ---------- |
| 911 | + a : {dpnp.ndarray, usm_ndarray} |
| 912 | + Coefficient tensor, of shape ``b.shape + Q``. `Q`, a tuple, equals |
| 913 | + the shape of that sub-tensor of `a` consisting of the appropriate |
| 914 | + number of its rightmost indices, and must be such that |
| 915 | + ``prod(Q) == prod(b.shape)`` (in which sense `a` is said to be |
| 916 | + 'square'). |
| 917 | + b : {dpnp.ndarray, usm_ndarray} |
| 918 | + Right-hand tensor, which can be of any shape. |
| 919 | + axes : tuple of ints, optional |
| 920 | + Axes in `a` to reorder to the right, before inversion. |
| 921 | + If None , no reordering is done. |
| 922 | + Default: ``None``. |
| 923 | +
|
| 924 | + Returns |
| 925 | + ------- |
| 926 | + out : dpnp.ndarray |
| 927 | + The tensor with shape ``Q`` such that ``b.shape + Q == a.shape``. |
| 928 | +
|
| 929 | + See Also |
| 930 | + -------- |
| 931 | + :obj:`dpnp.linalg.tensordot` : Compute tensor dot product along specified axes. |
| 932 | + :obj:`dpnp.linalg.tensorinv` : Compute the `inverse` of a tensor. |
| 933 | +
|
| 934 | + Examples |
| 935 | + -------- |
| 936 | + >>> import dpnp as np |
| 937 | + >>> a = np.eye(2*3*4) |
| 938 | + >>> a.shape = (2*3, 4, 2, 3, 4) |
| 939 | + >>> b = np.random.randn(2*3, 4) |
| 940 | + >>> x = np.linalg.tensorsolve(a, b) |
| 941 | + >>> x.shape |
| 942 | + (2, 3, 4) |
| 943 | + >>> np.allclose(np.tensordot(a, x, axes=3), b) |
| 944 | + array([ True]) |
| 945 | +
|
| 946 | + """ |
| 947 | + |
| 948 | + dpnp.check_supported_arrays_type(a, b) |
| 949 | + a_ndim = a.ndim |
| 950 | + |
| 951 | + if axes is not None: |
| 952 | + all_axes = list(range(a_ndim)) |
| 953 | + for k in axes: |
| 954 | + all_axes.remove(k) |
| 955 | + all_axes.insert(a_ndim, k) |
| 956 | + a = a.transpose(tuple(all_axes)) |
| 957 | + |
| 958 | + old_shape = a.shape[-(a_ndim - b.ndim) :] |
| 959 | + prod = numpy.prod(old_shape) |
| 960 | + |
| 961 | + if a.size != prod**2: |
| 962 | + raise dpnp.linalg.LinAlgError( |
| 963 | + "Input arrays must satisfy the requirement \ |
| 964 | + prod(a.shape[b.ndim:]) == prod(a.shape[:b.ndim])" |
| 965 | + ) |
| 966 | + |
| 967 | + a = a.reshape(-1, prod) |
| 968 | + b = b.ravel() |
| 969 | + res = solve(a, b) |
| 970 | + return res.reshape(old_shape) |
0 commit comments