Skip to content

Commit 7ca1aff

Browse files
Add dpnp.linalg.tensorsolve() implementation (#1753)
* add dpnp.linalg.tensorsolve impl * Add tests for tensorsolve * Add test_tensorsolve_axes * Address remarks * Address remarks for #1752
1 parent e5d3127 commit 7ca1aff

File tree

5 files changed

+173
-2
lines changed

5 files changed

+173
-2
lines changed

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 74 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
"svd",
8080
"slogdet",
8181
"tensorinv",
82+
"tensorsolve",
8283
]
8384

8485

@@ -935,7 +936,7 @@ def slogdet(a):
935936

936937
def tensorinv(a, ind=2):
937938
"""
938-
Compute the `inverse` of a tensor.
939+
Compute the 'inverse' of an N-dimensional array.
939940
940941
For full documentation refer to :obj:`numpy.linalg.tensorinv`.
941942
@@ -944,7 +945,7 @@ def tensorinv(a, ind=2):
944945
a : {dpnp.ndarray, usm_ndarray}
945946
Tensor to `invert`. Its shape must be 'square', i. e.,
946947
``prod(a.shape[:ind]) == prod(a.shape[ind:])``.
947-
ind : int
948+
ind : int, optional
948949
Number of first indices that are involved in the inverse sum.
949950
Must be a positive integer.
950951
Default: 2.
@@ -989,3 +990,74 @@ def tensorinv(a, ind=2):
989990
a_inv = inv(a)
990991

991992
return a_inv.reshape(*inv_shape)
993+
994+
995+
def tensorsolve(a, b, axes=None):
996+
"""
997+
Solve the tensor equation ``a x = b`` for x.
998+
999+
For full documentation refer to :obj:`numpy.linalg.tensorsolve`.
1000+
1001+
Parameters
1002+
----------
1003+
a : {dpnp.ndarray, usm_ndarray}
1004+
Coefficient tensor, of shape ``b.shape + Q``. `Q`, a tuple, equals
1005+
the shape of that sub-tensor of `a` consisting of the appropriate
1006+
number of its rightmost indices, and must be such that
1007+
``prod(Q) == prod(b.shape)`` (in which sense `a` is said to be
1008+
'square').
1009+
b : {dpnp.ndarray, usm_ndarray}
1010+
Right-hand tensor, which can be of any shape.
1011+
axes : tuple of ints, optional
1012+
Axes in `a` to reorder to the right, before inversion.
1013+
If ``None`` , no reordering is done.
1014+
Default: ``None``.
1015+
1016+
Returns
1017+
-------
1018+
out : dpnp.ndarray
1019+
The tensor with shape ``Q`` such that ``b.shape + Q == a.shape``.
1020+
1021+
See Also
1022+
--------
1023+
:obj:`dpnp.linalg.tensordot` : Compute tensor dot product along specified axes.
1024+
:obj:`dpnp.linalg.tensorinv` : Compute the 'inverse' of an N-dimensional array.
1025+
:obj:`dpnp.einsum` : Evaluates the Einstein summation convention on the operands.
1026+
1027+
Examples
1028+
--------
1029+
>>> import dpnp as np
1030+
>>> a = np.eye(2*3*4)
1031+
>>> a.shape = (2*3, 4, 2, 3, 4)
1032+
>>> b = np.random.randn(2*3, 4)
1033+
>>> x = np.linalg.tensorsolve(a, b)
1034+
>>> x.shape
1035+
(2, 3, 4)
1036+
>>> np.allclose(np.tensordot(a, x, axes=3), b)
1037+
array([ True])
1038+
1039+
"""
1040+
1041+
dpnp.check_supported_arrays_type(a, b)
1042+
a_ndim = a.ndim
1043+
1044+
if axes is not None:
1045+
all_axes = list(range(a_ndim))
1046+
for k in axes:
1047+
all_axes.remove(k)
1048+
all_axes.insert(a_ndim, k)
1049+
a = a.transpose(tuple(all_axes))
1050+
1051+
old_shape = a.shape[-(a_ndim - b.ndim) :]
1052+
prod = numpy.prod(old_shape)
1053+
1054+
if a.size != prod**2:
1055+
raise dpnp.linalg.LinAlgError(
1056+
"Input arrays must satisfy the requirement \
1057+
prod(a.shape[b.ndim:]) == prod(a.shape[:b.ndim])"
1058+
)
1059+
1060+
a = a.reshape(-1, prod)
1061+
b = b.ravel()
1062+
res = solve(a, b)
1063+
return res.reshape(old_shape)

tests/test_linalg.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1498,3 +1498,47 @@ def test_test_tensorinv_errors(self):
14981498

14991499
# non-square
15001500
assert_raises(inp.linalg.LinAlgError, inp.linalg.tensorinv, a_dp, 1)
1501+
1502+
1503+
class TestTensorsolve:
1504+
@pytest.mark.parametrize("dtype", get_all_dtypes())
1505+
@pytest.mark.parametrize(
1506+
"axes",
1507+
[None, (1,), (2,)],
1508+
ids=[
1509+
"None",
1510+
"(1,)",
1511+
"(2,)",
1512+
],
1513+
)
1514+
def test_tensorsolve_axes(self, dtype, axes):
1515+
a = numpy.eye(12).reshape(12, 3, 4).astype(dtype)
1516+
b = numpy.ones(a.shape[0], dtype=dtype)
1517+
1518+
a_dp = inp.array(a)
1519+
b_dp = inp.array(b)
1520+
1521+
res_np = numpy.linalg.tensorsolve(a, b, axes=axes)
1522+
res_dp = inp.linalg.tensorsolve(a_dp, b_dp, axes=axes)
1523+
1524+
assert res_np.shape == res_dp.shape
1525+
assert_dtype_allclose(res_dp, res_np)
1526+
1527+
def test_tensorsolve_errors(self):
1528+
a_dp = inp.eye(24, dtype="float32").reshape(4, 6, 8, 3)
1529+
b_dp = inp.ones(a_dp.shape[:2], dtype="float32")
1530+
1531+
# unsupported type `a` and `b`
1532+
a_np = inp.asnumpy(a_dp)
1533+
b_np = inp.asnumpy(b_dp)
1534+
assert_raises(TypeError, inp.linalg.tensorsolve, a_np, b_dp)
1535+
assert_raises(TypeError, inp.linalg.tensorsolve, a_dp, b_np)
1536+
1537+
# unsupported type `axes`
1538+
assert_raises(TypeError, inp.linalg.tensorsolve, a_dp, 2.0)
1539+
assert_raises(TypeError, inp.linalg.tensorsolve, a_dp, -2)
1540+
1541+
# incorrect axes
1542+
assert_raises(
1543+
inp.linalg.LinAlgError, inp.linalg.tensorsolve, a_dp, b_dp, (1,)
1544+
)

tests/test_sycl_queue.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1891,3 +1891,24 @@ def test_tensorinv(device):
18911891
result_queue = result.sycl_queue
18921892

18931893
assert_sycl_queue_equal(result_queue, a_dp.sycl_queue)
1894+
1895+
1896+
@pytest.mark.parametrize(
1897+
"device",
1898+
valid_devices,
1899+
ids=[device.filter_string for device in valid_devices],
1900+
)
1901+
def test_tensorsolve(device):
1902+
a_np = numpy.random.randn(3, 2, 6).astype(dpnp.default_float_type())
1903+
b_np = numpy.ones(a_np.shape[:2], dtype=a_np.dtype)
1904+
1905+
a_dp = dpnp.array(a_np, device=device)
1906+
b_dp = dpnp.array(b_np, device=device)
1907+
1908+
result = dpnp.linalg.tensorsolve(a_dp, b_dp)
1909+
expected = numpy.linalg.tensorsolve(a_np, b_np)
1910+
assert_dtype_allclose(result, expected)
1911+
1912+
result_queue = result.sycl_queue
1913+
1914+
assert_sycl_queue_equal(result_queue, a_dp.sycl_queue)

tests/test_usm_type.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1035,3 +1035,17 @@ def test_tensorinv(usm_type):
10351035
ainv = dp.linalg.tensorinv(a, ind=1)
10361036

10371037
assert a.usm_type == ainv.usm_type
1038+
1039+
1040+
@pytest.mark.parametrize("usm_type_a", list_of_usm_types, ids=list_of_usm_types)
1041+
@pytest.mark.parametrize("usm_type_b", list_of_usm_types, ids=list_of_usm_types)
1042+
def test_tensorsolve(usm_type_a, usm_type_b):
1043+
data = numpy.random.randn(3, 2, 6)
1044+
a = dp.array(data, usm_type=usm_type_a)
1045+
b = dp.ones(a.shape[:2], dtype=a.dtype, usm_type=usm_type_b)
1046+
1047+
result = dp.linalg.tensorsolve(a, b)
1048+
1049+
assert a.usm_type == usm_type_a
1050+
assert b.usm_type == usm_type_b
1051+
assert result.usm_type == du.get_coerced_usm_type([usm_type_a, usm_type_b])

tests/third_party/cupy/linalg_tests/test_solve.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,26 @@ def test_invalid_shape(self):
9595
self.check_shape((0, 3, 4), (3,), linalg_errors)
9696

9797

98+
@testing.parameterize(
99+
*testing.product(
100+
{
101+
"a_shape": [(2, 3, 6), (3, 4, 4, 3)],
102+
"axes": [None, (0, 2)],
103+
}
104+
)
105+
)
106+
@testing.fix_random()
107+
class TestTensorSolve(unittest.TestCase):
108+
@testing.for_dtypes("ifdFD")
109+
@testing.numpy_cupy_allclose(atol=0.02, type_check=has_support_aspect64())
110+
def test_tensorsolve(self, xp, dtype):
111+
a_shape = self.a_shape
112+
b_shape = self.a_shape[:2]
113+
a = testing.shaped_random(a_shape, xp, dtype=dtype, seed=0)
114+
b = testing.shaped_random(b_shape, xp, dtype=dtype, seed=1)
115+
return xp.linalg.tensorsolve(a, b, axes=self.axes)
116+
117+
98118
@testing.parameterize(
99119
*testing.product(
100120
{

0 commit comments

Comments
 (0)