Skip to content

Commit e5a6d37

Browse files
authored
Merge branch 'master' into impl-searchsorted
2 parents e9a21f7 + 7ca1aff commit e5a6d37

File tree

5 files changed

+339
-0
lines changed

5 files changed

+339
-0
lines changed

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@
7878
"solve",
7979
"svd",
8080
"slogdet",
81+
"tensorinv",
82+
"tensorsolve",
8183
]
8284

8385

@@ -930,3 +932,132 @@ def slogdet(a):
930932
check_stacked_square(a)
931933

932934
return dpnp_slogdet(a)
935+
936+
937+
def tensorinv(a, ind=2):
938+
"""
939+
Compute the 'inverse' of an N-dimensional array.
940+
941+
For full documentation refer to :obj:`numpy.linalg.tensorinv`.
942+
943+
Parameters
944+
----------
945+
a : {dpnp.ndarray, usm_ndarray}
946+
Tensor to `invert`. Its shape must be 'square', i. e.,
947+
``prod(a.shape[:ind]) == prod(a.shape[ind:])``.
948+
ind : int, optional
949+
Number of first indices that are involved in the inverse sum.
950+
Must be a positive integer.
951+
Default: 2.
952+
953+
Returns
954+
-------
955+
out : dpnp.ndarray
956+
The inverse of a tensor whose shape is equivalent to
957+
``a.shape[ind:] + a.shape[:ind]``.
958+
959+
See Also
960+
--------
961+
:obj:`dpnp.linalg.tensordot` : Compute tensor dot product along specified axes.
962+
:obj:`dpnp.linalg.tensorsolve` : Solve the tensor equation ``a x = b`` for x.
963+
964+
Examples
965+
--------
966+
>>> import dpnp as np
967+
>>> a = np.eye(4*6)
968+
>>> a.shape = (4, 6, 8, 3)
969+
>>> ainv = np.linalg.tensorinv(a, ind=2)
970+
>>> ainv.shape
971+
(8, 3, 4, 6)
972+
973+
>>> a = np.eye(4*6)
974+
>>> a.shape = (24, 8, 3)
975+
>>> ainv = np.linalg.tensorinv(a, ind=1)
976+
>>> ainv.shape
977+
(8, 3, 24)
978+
979+
"""
980+
981+
dpnp.check_supported_arrays_type(a)
982+
983+
if ind <= 0:
984+
raise ValueError("Invalid ind argument")
985+
986+
old_shape = a.shape
987+
inv_shape = old_shape[ind:] + old_shape[:ind]
988+
prod = numpy.prod(old_shape[ind:])
989+
a = a.reshape(prod, -1)
990+
a_inv = inv(a)
991+
992+
return a_inv.reshape(*inv_shape)
993+
994+
995+
def tensorsolve(a, b, axes=None):
996+
"""
997+
Solve the tensor equation ``a x = b`` for x.
998+
999+
For full documentation refer to :obj:`numpy.linalg.tensorsolve`.
1000+
1001+
Parameters
1002+
----------
1003+
a : {dpnp.ndarray, usm_ndarray}
1004+
Coefficient tensor, of shape ``b.shape + Q``. `Q`, a tuple, equals
1005+
the shape of that sub-tensor of `a` consisting of the appropriate
1006+
number of its rightmost indices, and must be such that
1007+
``prod(Q) == prod(b.shape)`` (in which sense `a` is said to be
1008+
'square').
1009+
b : {dpnp.ndarray, usm_ndarray}
1010+
Right-hand tensor, which can be of any shape.
1011+
axes : tuple of ints, optional
1012+
Axes in `a` to reorder to the right, before inversion.
1013+
If ``None`` , no reordering is done.
1014+
Default: ``None``.
1015+
1016+
Returns
1017+
-------
1018+
out : dpnp.ndarray
1019+
The tensor with shape ``Q`` such that ``b.shape + Q == a.shape``.
1020+
1021+
See Also
1022+
--------
1023+
:obj:`dpnp.linalg.tensordot` : Compute tensor dot product along specified axes.
1024+
:obj:`dpnp.linalg.tensorinv` : Compute the 'inverse' of an N-dimensional array.
1025+
:obj:`dpnp.einsum` : Evaluates the Einstein summation convention on the operands.
1026+
1027+
Examples
1028+
--------
1029+
>>> import dpnp as np
1030+
>>> a = np.eye(2*3*4)
1031+
>>> a.shape = (2*3, 4, 2, 3, 4)
1032+
>>> b = np.random.randn(2*3, 4)
1033+
>>> x = np.linalg.tensorsolve(a, b)
1034+
>>> x.shape
1035+
(2, 3, 4)
1036+
>>> np.allclose(np.tensordot(a, x, axes=3), b)
1037+
array([ True])
1038+
1039+
"""
1040+
1041+
dpnp.check_supported_arrays_type(a, b)
1042+
a_ndim = a.ndim
1043+
1044+
if axes is not None:
1045+
all_axes = list(range(a_ndim))
1046+
for k in axes:
1047+
all_axes.remove(k)
1048+
all_axes.insert(a_ndim, k)
1049+
a = a.transpose(tuple(all_axes))
1050+
1051+
old_shape = a.shape[-(a_ndim - b.ndim) :]
1052+
prod = numpy.prod(old_shape)
1053+
1054+
if a.size != prod**2:
1055+
raise dpnp.linalg.LinAlgError(
1056+
"Input arrays must satisfy the requirement \
1057+
prod(a.shape[b.ndim:]) == prod(a.shape[:b.ndim])"
1058+
)
1059+
1060+
a = a.reshape(-1, prod)
1061+
b = b.ravel()
1062+
res = solve(a, b)
1063+
return res.reshape(old_shape)

tests/test_linalg.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1459,3 +1459,86 @@ def test_pinv_errors(self):
14591459
a_dp_q = inp.array(a_dp, sycl_queue=a_queue)
14601460
rcond_dp_q = inp.array([0.5], dtype="float32", sycl_queue=rcond_queue)
14611461
assert_raises(ValueError, inp.linalg.pinv, a_dp_q, rcond_dp_q)
1462+
1463+
1464+
class TestTensorinv:
1465+
@pytest.mark.parametrize("dtype", get_all_dtypes())
1466+
@pytest.mark.parametrize(
1467+
"shape, ind",
1468+
[
1469+
((4, 6, 8, 3), 2),
1470+
((24, 8, 3), 1),
1471+
],
1472+
ids=[
1473+
"(4, 6, 8, 3)",
1474+
"(24, 8, 3)",
1475+
],
1476+
)
1477+
def test_tensorinv(self, dtype, shape, ind):
1478+
a = numpy.eye(24, dtype=dtype).reshape(shape)
1479+
a_dp = inp.array(a)
1480+
1481+
ainv = numpy.linalg.tensorinv(a, ind=ind)
1482+
ainv_dp = inp.linalg.tensorinv(a_dp, ind=ind)
1483+
1484+
assert ainv.shape == ainv_dp.shape
1485+
assert_dtype_allclose(ainv_dp, ainv)
1486+
1487+
def test_test_tensorinv_errors(self):
1488+
a_dp = inp.eye(24, dtype="float32").reshape(4, 6, 8, 3)
1489+
1490+
# unsupported type `a`
1491+
a_np = inp.asnumpy(a_dp)
1492+
assert_raises(TypeError, inp.linalg.pinv, a_np)
1493+
1494+
# unsupported type `ind`
1495+
assert_raises(TypeError, inp.linalg.tensorinv, a_dp, 2.0)
1496+
assert_raises(TypeError, inp.linalg.tensorinv, a_dp, [2.0])
1497+
assert_raises(ValueError, inp.linalg.tensorinv, a_dp, -1)
1498+
1499+
# non-square
1500+
assert_raises(inp.linalg.LinAlgError, inp.linalg.tensorinv, a_dp, 1)
1501+
1502+
1503+
class TestTensorsolve:
1504+
@pytest.mark.parametrize("dtype", get_all_dtypes())
1505+
@pytest.mark.parametrize(
1506+
"axes",
1507+
[None, (1,), (2,)],
1508+
ids=[
1509+
"None",
1510+
"(1,)",
1511+
"(2,)",
1512+
],
1513+
)
1514+
def test_tensorsolve_axes(self, dtype, axes):
1515+
a = numpy.eye(12).reshape(12, 3, 4).astype(dtype)
1516+
b = numpy.ones(a.shape[0], dtype=dtype)
1517+
1518+
a_dp = inp.array(a)
1519+
b_dp = inp.array(b)
1520+
1521+
res_np = numpy.linalg.tensorsolve(a, b, axes=axes)
1522+
res_dp = inp.linalg.tensorsolve(a_dp, b_dp, axes=axes)
1523+
1524+
assert res_np.shape == res_dp.shape
1525+
assert_dtype_allclose(res_dp, res_np)
1526+
1527+
def test_tensorsolve_errors(self):
1528+
a_dp = inp.eye(24, dtype="float32").reshape(4, 6, 8, 3)
1529+
b_dp = inp.ones(a_dp.shape[:2], dtype="float32")
1530+
1531+
# unsupported type `a` and `b`
1532+
a_np = inp.asnumpy(a_dp)
1533+
b_np = inp.asnumpy(b_dp)
1534+
assert_raises(TypeError, inp.linalg.tensorsolve, a_np, b_dp)
1535+
assert_raises(TypeError, inp.linalg.tensorsolve, a_dp, b_np)
1536+
1537+
# unsupported type `axes`
1538+
assert_raises(TypeError, inp.linalg.tensorsolve, a_dp, 2.0)
1539+
assert_raises(TypeError, inp.linalg.tensorsolve, a_dp, -2)
1540+
1541+
# incorrect axes
1542+
assert_raises(
1543+
inp.linalg.LinAlgError, inp.linalg.tensorsolve, a_dp, b_dp, (1,)
1544+
)

tests/test_sycl_queue.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1896,3 +1896,42 @@ def test_pinv(shape, hermitian, rcond_as_array, device):
18961896
B_queue = B_result.sycl_queue
18971897

18981898
assert_sycl_queue_equal(B_queue, a_dp.sycl_queue)
1899+
1900+
1901+
@pytest.mark.parametrize(
1902+
"device",
1903+
valid_devices,
1904+
ids=[device.filter_string for device in valid_devices],
1905+
)
1906+
def test_tensorinv(device):
1907+
a_np = numpy.eye(12).reshape(12, 4, 3)
1908+
a_dp = dpnp.array(a_np, device=device)
1909+
1910+
result = dpnp.linalg.tensorinv(a_dp, ind=1)
1911+
expected = numpy.linalg.tensorinv(a_np, ind=1)
1912+
assert_dtype_allclose(result, expected)
1913+
1914+
result_queue = result.sycl_queue
1915+
1916+
assert_sycl_queue_equal(result_queue, a_dp.sycl_queue)
1917+
1918+
1919+
@pytest.mark.parametrize(
1920+
"device",
1921+
valid_devices,
1922+
ids=[device.filter_string for device in valid_devices],
1923+
)
1924+
def test_tensorsolve(device):
1925+
a_np = numpy.random.randn(3, 2, 6).astype(dpnp.default_float_type())
1926+
b_np = numpy.ones(a_np.shape[:2], dtype=a_np.dtype)
1927+
1928+
a_dp = dpnp.array(a_np, device=device)
1929+
b_dp = dpnp.array(b_np, device=device)
1930+
1931+
result = dpnp.linalg.tensorsolve(a_dp, b_dp)
1932+
expected = numpy.linalg.tensorsolve(a_np, b_np)
1933+
assert_dtype_allclose(result, expected)
1934+
1935+
result_queue = result.sycl_queue
1936+
1937+
assert_sycl_queue_equal(result_queue, a_dp.sycl_queue)

tests/test_usm_type.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1041,3 +1041,25 @@ def test_qr(shape, mode, usm_type):
10411041

10421042
assert a.usm_type == dp_q.usm_type
10431043
assert a.usm_type == dp_r.usm_type
1044+
1045+
1046+
@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)
1047+
def test_tensorinv(usm_type):
1048+
a = dp.eye(12, usm_type=usm_type).reshape(12, 4, 3)
1049+
ainv = dp.linalg.tensorinv(a, ind=1)
1050+
1051+
assert a.usm_type == ainv.usm_type
1052+
1053+
1054+
@pytest.mark.parametrize("usm_type_a", list_of_usm_types, ids=list_of_usm_types)
1055+
@pytest.mark.parametrize("usm_type_b", list_of_usm_types, ids=list_of_usm_types)
1056+
def test_tensorsolve(usm_type_a, usm_type_b):
1057+
data = numpy.random.randn(3, 2, 6)
1058+
a = dp.array(data, usm_type=usm_type_a)
1059+
b = dp.ones(a.shape[:2], dtype=a.dtype, usm_type=usm_type_b)
1060+
1061+
result = dp.linalg.tensorsolve(a, b)
1062+
1063+
assert a.usm_type == usm_type_a
1064+
assert b.usm_type == usm_type_b
1065+
assert result.usm_type == du.get_coerced_usm_type([usm_type_a, usm_type_b])

tests/third_party/cupy/linalg_tests/test_solve.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,26 @@ def test_invalid_shape(self):
9595
self.check_shape((0, 3, 4), (3,), linalg_errors)
9696

9797

98+
@testing.parameterize(
99+
*testing.product(
100+
{
101+
"a_shape": [(2, 3, 6), (3, 4, 4, 3)],
102+
"axes": [None, (0, 2)],
103+
}
104+
)
105+
)
106+
@testing.fix_random()
107+
class TestTensorSolve(unittest.TestCase):
108+
@testing.for_dtypes("ifdFD")
109+
@testing.numpy_cupy_allclose(atol=0.02, type_check=has_support_aspect64())
110+
def test_tensorsolve(self, xp, dtype):
111+
a_shape = self.a_shape
112+
b_shape = self.a_shape[:2]
113+
a = testing.shaped_random(a_shape, xp, dtype=dtype, seed=0)
114+
b = testing.shaped_random(b_shape, xp, dtype=dtype, seed=1)
115+
return xp.linalg.tensorsolve(a, b, axes=self.axes)
116+
117+
98118
@testing.parameterize(
99119
*testing.product(
100120
{
@@ -208,3 +228,47 @@ def test_pinv_size_0(self):
208228
self.check_x((0, 0), rcond=1e-15)
209229
self.check_x((0, 2, 3), rcond=1e-15)
210230
self.check_x((2, 0, 3), rcond=1e-15)
231+
232+
233+
class TestTensorInv(unittest.TestCase):
234+
@testing.for_dtypes("ifdFD")
235+
@_condition.retry(10)
236+
def check_x(self, a_shape, ind, dtype):
237+
a_cpu = numpy.random.randint(0, 10, size=a_shape).astype(dtype)
238+
a_gpu = cupy.asarray(a_cpu)
239+
a_gpu_copy = a_gpu.copy()
240+
result_cpu = numpy.linalg.tensorinv(a_cpu, ind=ind)
241+
result_gpu = cupy.linalg.tensorinv(a_gpu, ind=ind)
242+
assert_dtype_allclose(result_gpu, result_cpu)
243+
testing.assert_array_equal(a_gpu_copy, a_gpu)
244+
245+
def check_shape(self, a_shape, ind):
246+
a = cupy.random.rand(*a_shape)
247+
with self.assertRaises(
248+
(numpy.linalg.LinAlgError, cupy.linalg.LinAlgError)
249+
):
250+
cupy.linalg.tensorinv(a, ind=ind)
251+
252+
def check_ind(self, a_shape, ind):
253+
a = cupy.random.rand(*a_shape)
254+
with self.assertRaises(ValueError):
255+
cupy.linalg.tensorinv(a, ind=ind)
256+
257+
def test_tensorinv(self):
258+
self.check_x((12, 3, 4), ind=1)
259+
self.check_x((3, 8, 24), ind=2)
260+
self.check_x((18, 3, 3, 2), ind=1)
261+
self.check_x((1, 4, 2, 2), ind=2)
262+
self.check_x((2, 3, 5, 30), ind=3)
263+
self.check_x((24, 2, 2, 3, 2), ind=1)
264+
self.check_x((3, 4, 2, 3, 2), ind=2)
265+
self.check_x((1, 2, 3, 2, 3), ind=3)
266+
self.check_x((3, 2, 1, 2, 12), ind=4)
267+
268+
def test_invalid_shape(self):
269+
self.check_shape((2, 3, 4), ind=1)
270+
self.check_shape((1, 2, 3, 4), ind=3)
271+
272+
def test_invalid_index(self):
273+
self.check_ind((12, 3, 4), ind=-1)
274+
self.check_ind((18, 3, 3, 2), ind=0)

0 commit comments

Comments
 (0)