Skip to content

Commit e5d3127

Browse files
Add dpnp.linalg.tensorinv() implementation (#1752)
* Add a new dpnp.linalg.tensorinv impl * Add tests for tensorinv --------- Co-authored-by: Anton <[email protected]>
1 parent e44469c commit e5d3127

File tree

5 files changed

+168
-0
lines changed

5 files changed

+168
-0
lines changed

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@
7878
"solve",
7979
"svd",
8080
"slogdet",
81+
"tensorinv",
8182
]
8283

8384

@@ -930,3 +931,61 @@ def slogdet(a):
930931
check_stacked_square(a)
931932

932933
return dpnp_slogdet(a)
934+
935+
936+
def tensorinv(a, ind=2):
937+
"""
938+
Compute the `inverse` of a tensor.
939+
940+
For full documentation refer to :obj:`numpy.linalg.tensorinv`.
941+
942+
Parameters
943+
----------
944+
a : {dpnp.ndarray, usm_ndarray}
945+
Tensor to `invert`. Its shape must be 'square', i. e.,
946+
``prod(a.shape[:ind]) == prod(a.shape[ind:])``.
947+
ind : int
948+
Number of first indices that are involved in the inverse sum.
949+
Must be a positive integer.
950+
Default: 2.
951+
952+
Returns
953+
-------
954+
out : dpnp.ndarray
955+
The inverse of a tensor whose shape is equivalent to
956+
``a.shape[ind:] + a.shape[:ind]``.
957+
958+
See Also
959+
--------
960+
:obj:`dpnp.linalg.tensordot` : Compute tensor dot product along specified axes.
961+
:obj:`dpnp.linalg.tensorsolve` : Solve the tensor equation ``a x = b`` for x.
962+
963+
Examples
964+
--------
965+
>>> import dpnp as np
966+
>>> a = np.eye(4*6)
967+
>>> a.shape = (4, 6, 8, 3)
968+
>>> ainv = np.linalg.tensorinv(a, ind=2)
969+
>>> ainv.shape
970+
(8, 3, 4, 6)
971+
972+
>>> a = np.eye(4*6)
973+
>>> a.shape = (24, 8, 3)
974+
>>> ainv = np.linalg.tensorinv(a, ind=1)
975+
>>> ainv.shape
976+
(8, 3, 24)
977+
978+
"""
979+
980+
dpnp.check_supported_arrays_type(a)
981+
982+
if ind <= 0:
983+
raise ValueError("Invalid ind argument")
984+
985+
old_shape = a.shape
986+
inv_shape = old_shape[ind:] + old_shape[:ind]
987+
prod = numpy.prod(old_shape[ind:])
988+
a = a.reshape(prod, -1)
989+
a_inv = inv(a)
990+
991+
return a_inv.reshape(*inv_shape)

tests/test_linalg.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1459,3 +1459,42 @@ def test_pinv_errors(self):
14591459
a_dp_q = inp.array(a_dp, sycl_queue=a_queue)
14601460
rcond_dp_q = inp.array([0.5], dtype="float32", sycl_queue=rcond_queue)
14611461
assert_raises(ValueError, inp.linalg.pinv, a_dp_q, rcond_dp_q)
1462+
1463+
1464+
class TestTensorinv:
1465+
@pytest.mark.parametrize("dtype", get_all_dtypes())
1466+
@pytest.mark.parametrize(
1467+
"shape, ind",
1468+
[
1469+
((4, 6, 8, 3), 2),
1470+
((24, 8, 3), 1),
1471+
],
1472+
ids=[
1473+
"(4, 6, 8, 3)",
1474+
"(24, 8, 3)",
1475+
],
1476+
)
1477+
def test_tensorinv(self, dtype, shape, ind):
1478+
a = numpy.eye(24, dtype=dtype).reshape(shape)
1479+
a_dp = inp.array(a)
1480+
1481+
ainv = numpy.linalg.tensorinv(a, ind=ind)
1482+
ainv_dp = inp.linalg.tensorinv(a_dp, ind=ind)
1483+
1484+
assert ainv.shape == ainv_dp.shape
1485+
assert_dtype_allclose(ainv_dp, ainv)
1486+
1487+
def test_test_tensorinv_errors(self):
1488+
a_dp = inp.eye(24, dtype="float32").reshape(4, 6, 8, 3)
1489+
1490+
# unsupported type `a`
1491+
a_np = inp.asnumpy(a_dp)
1492+
assert_raises(TypeError, inp.linalg.pinv, a_np)
1493+
1494+
# unsupported type `ind`
1495+
assert_raises(TypeError, inp.linalg.tensorinv, a_dp, 2.0)
1496+
assert_raises(TypeError, inp.linalg.tensorinv, a_dp, [2.0])
1497+
assert_raises(ValueError, inp.linalg.tensorinv, a_dp, -1)
1498+
1499+
# non-square
1500+
assert_raises(inp.linalg.LinAlgError, inp.linalg.tensorinv, a_dp, 1)

tests/test_sycl_queue.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1873,3 +1873,21 @@ def test_pinv(shape, hermitian, rcond_as_array, device):
18731873
B_queue = B_result.sycl_queue
18741874

18751875
assert_sycl_queue_equal(B_queue, a_dp.sycl_queue)
1876+
1877+
1878+
@pytest.mark.parametrize(
1879+
"device",
1880+
valid_devices,
1881+
ids=[device.filter_string for device in valid_devices],
1882+
)
1883+
def test_tensorinv(device):
1884+
a_np = numpy.eye(12).reshape(12, 4, 3)
1885+
a_dp = dpnp.array(a_np, device=device)
1886+
1887+
result = dpnp.linalg.tensorinv(a_dp, ind=1)
1888+
expected = numpy.linalg.tensorinv(a_np, ind=1)
1889+
assert_dtype_allclose(result, expected)
1890+
1891+
result_queue = result.sycl_queue
1892+
1893+
assert_sycl_queue_equal(result_queue, a_dp.sycl_queue)

tests/test_usm_type.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1027,3 +1027,11 @@ def test_qr(shape, mode, usm_type):
10271027

10281028
assert a.usm_type == dp_q.usm_type
10291029
assert a.usm_type == dp_r.usm_type
1030+
1031+
1032+
@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)
1033+
def test_tensorinv(usm_type):
1034+
a = dp.eye(12, usm_type=usm_type).reshape(12, 4, 3)
1035+
ainv = dp.linalg.tensorinv(a, ind=1)
1036+
1037+
assert a.usm_type == ainv.usm_type

tests/third_party/cupy/linalg_tests/test_solve.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,3 +208,47 @@ def test_pinv_size_0(self):
208208
self.check_x((0, 0), rcond=1e-15)
209209
self.check_x((0, 2, 3), rcond=1e-15)
210210
self.check_x((2, 0, 3), rcond=1e-15)
211+
212+
213+
class TestTensorInv(unittest.TestCase):
214+
@testing.for_dtypes("ifdFD")
215+
@_condition.retry(10)
216+
def check_x(self, a_shape, ind, dtype):
217+
a_cpu = numpy.random.randint(0, 10, size=a_shape).astype(dtype)
218+
a_gpu = cupy.asarray(a_cpu)
219+
a_gpu_copy = a_gpu.copy()
220+
result_cpu = numpy.linalg.tensorinv(a_cpu, ind=ind)
221+
result_gpu = cupy.linalg.tensorinv(a_gpu, ind=ind)
222+
assert_dtype_allclose(result_gpu, result_cpu)
223+
testing.assert_array_equal(a_gpu_copy, a_gpu)
224+
225+
def check_shape(self, a_shape, ind):
226+
a = cupy.random.rand(*a_shape)
227+
with self.assertRaises(
228+
(numpy.linalg.LinAlgError, cupy.linalg.LinAlgError)
229+
):
230+
cupy.linalg.tensorinv(a, ind=ind)
231+
232+
def check_ind(self, a_shape, ind):
233+
a = cupy.random.rand(*a_shape)
234+
with self.assertRaises(ValueError):
235+
cupy.linalg.tensorinv(a, ind=ind)
236+
237+
def test_tensorinv(self):
238+
self.check_x((12, 3, 4), ind=1)
239+
self.check_x((3, 8, 24), ind=2)
240+
self.check_x((18, 3, 3, 2), ind=1)
241+
self.check_x((1, 4, 2, 2), ind=2)
242+
self.check_x((2, 3, 5, 30), ind=3)
243+
self.check_x((24, 2, 2, 3, 2), ind=1)
244+
self.check_x((3, 4, 2, 3, 2), ind=2)
245+
self.check_x((1, 2, 3, 2, 3), ind=3)
246+
self.check_x((3, 2, 1, 2, 12), ind=4)
247+
248+
def test_invalid_shape(self):
249+
self.check_shape((2, 3, 4), ind=1)
250+
self.check_shape((1, 2, 3, 4), ind=3)
251+
252+
def test_invalid_index(self):
253+
self.check_ind((12, 3, 4), ind=-1)
254+
self.check_ind((18, 3, 3, 2), ind=0)

0 commit comments

Comments
 (0)