Skip to content

Add dpnp.linalg.tensorsolve() implementation #1753

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Mar 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 74 additions & 2 deletions dpnp/linalg/dpnp_iface_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
"svd",
"slogdet",
"tensorinv",
"tensorsolve",
]


Expand Down Expand Up @@ -935,7 +936,7 @@ def slogdet(a):

def tensorinv(a, ind=2):
"""
Compute the `inverse` of a tensor.
Compute the 'inverse' of an N-dimensional array.

For full documentation refer to :obj:`numpy.linalg.tensorinv`.

Expand All @@ -944,7 +945,7 @@ def tensorinv(a, ind=2):
a : {dpnp.ndarray, usm_ndarray}
Tensor to `invert`. Its shape must be 'square', i. e.,
``prod(a.shape[:ind]) == prod(a.shape[ind:])``.
ind : int
ind : int, optional
Number of first indices that are involved in the inverse sum.
Must be a positive integer.
Default: 2.
Expand Down Expand Up @@ -989,3 +990,74 @@ def tensorinv(a, ind=2):
a_inv = inv(a)

return a_inv.reshape(*inv_shape)


def tensorsolve(a, b, axes=None):
"""
Solve the tensor equation ``a x = b`` for x.

For full documentation refer to :obj:`numpy.linalg.tensorsolve`.

Parameters
----------
a : {dpnp.ndarray, usm_ndarray}
Coefficient tensor, of shape ``b.shape + Q``. `Q`, a tuple, equals
the shape of that sub-tensor of `a` consisting of the appropriate
number of its rightmost indices, and must be such that
``prod(Q) == prod(b.shape)`` (in which sense `a` is said to be
'square').
b : {dpnp.ndarray, usm_ndarray}
Right-hand tensor, which can be of any shape.
axes : tuple of ints, optional
Axes in `a` to reorder to the right, before inversion.
If ``None`` , no reordering is done.
Default: ``None``.

Returns
-------
out : dpnp.ndarray
The tensor with shape ``Q`` such that ``b.shape + Q == a.shape``.

See Also
--------
:obj:`dpnp.linalg.tensordot` : Compute tensor dot product along specified axes.
:obj:`dpnp.linalg.tensorinv` : Compute the 'inverse' of an N-dimensional array.
:obj:`dpnp.einsum` : Evaluates the Einstein summation convention on the operands.

Examples
--------
>>> import dpnp as np
>>> a = np.eye(2*3*4)
>>> a.shape = (2*3, 4, 2, 3, 4)
>>> b = np.random.randn(2*3, 4)
>>> x = np.linalg.tensorsolve(a, b)
>>> x.shape
(2, 3, 4)
>>> np.allclose(np.tensordot(a, x, axes=3), b)
array([ True])

"""

dpnp.check_supported_arrays_type(a, b)
a_ndim = a.ndim

if axes is not None:
all_axes = list(range(a_ndim))
for k in axes:
all_axes.remove(k)
all_axes.insert(a_ndim, k)
a = a.transpose(tuple(all_axes))

old_shape = a.shape[-(a_ndim - b.ndim) :]
prod = numpy.prod(old_shape)

if a.size != prod**2:
raise dpnp.linalg.LinAlgError(
"Input arrays must satisfy the requirement \
prod(a.shape[b.ndim:]) == prod(a.shape[:b.ndim])"
)

a = a.reshape(-1, prod)
b = b.ravel()
res = solve(a, b)
return res.reshape(old_shape)
44 changes: 44 additions & 0 deletions tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1498,3 +1498,47 @@ def test_test_tensorinv_errors(self):

# non-square
assert_raises(inp.linalg.LinAlgError, inp.linalg.tensorinv, a_dp, 1)


class TestTensorsolve:
@pytest.mark.parametrize("dtype", get_all_dtypes())
@pytest.mark.parametrize(
"axes",
[None, (1,), (2,)],
ids=[
"None",
"(1,)",
"(2,)",
],
)
def test_tensorsolve_axes(self, dtype, axes):
a = numpy.eye(12).reshape(12, 3, 4).astype(dtype)
b = numpy.ones(a.shape[0], dtype=dtype)

a_dp = inp.array(a)
b_dp = inp.array(b)

res_np = numpy.linalg.tensorsolve(a, b, axes=axes)
res_dp = inp.linalg.tensorsolve(a_dp, b_dp, axes=axes)

assert res_np.shape == res_dp.shape
assert_dtype_allclose(res_dp, res_np)

def test_tensorsolve_errors(self):
a_dp = inp.eye(24, dtype="float32").reshape(4, 6, 8, 3)
b_dp = inp.ones(a_dp.shape[:2], dtype="float32")

# unsupported type `a` and `b`
a_np = inp.asnumpy(a_dp)
b_np = inp.asnumpy(b_dp)
assert_raises(TypeError, inp.linalg.tensorsolve, a_np, b_dp)
assert_raises(TypeError, inp.linalg.tensorsolve, a_dp, b_np)

# unsupported type `axes`
assert_raises(TypeError, inp.linalg.tensorsolve, a_dp, 2.0)
assert_raises(TypeError, inp.linalg.tensorsolve, a_dp, -2)

# incorrect axes
assert_raises(
inp.linalg.LinAlgError, inp.linalg.tensorsolve, a_dp, b_dp, (1,)
)
21 changes: 21 additions & 0 deletions tests/test_sycl_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -1891,3 +1891,24 @@ def test_tensorinv(device):
result_queue = result.sycl_queue

assert_sycl_queue_equal(result_queue, a_dp.sycl_queue)


@pytest.mark.parametrize(
"device",
valid_devices,
ids=[device.filter_string for device in valid_devices],
)
def test_tensorsolve(device):
a_np = numpy.random.randn(3, 2, 6).astype(dpnp.default_float_type())
b_np = numpy.ones(a_np.shape[:2], dtype=a_np.dtype)

a_dp = dpnp.array(a_np, device=device)
b_dp = dpnp.array(b_np, device=device)

result = dpnp.linalg.tensorsolve(a_dp, b_dp)
expected = numpy.linalg.tensorsolve(a_np, b_np)
assert_dtype_allclose(result, expected)

result_queue = result.sycl_queue

assert_sycl_queue_equal(result_queue, a_dp.sycl_queue)
14 changes: 14 additions & 0 deletions tests/test_usm_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -1035,3 +1035,17 @@ def test_tensorinv(usm_type):
ainv = dp.linalg.tensorinv(a, ind=1)

assert a.usm_type == ainv.usm_type


@pytest.mark.parametrize("usm_type_a", list_of_usm_types, ids=list_of_usm_types)
@pytest.mark.parametrize("usm_type_b", list_of_usm_types, ids=list_of_usm_types)
def test_tensorsolve(usm_type_a, usm_type_b):
data = numpy.random.randn(3, 2, 6)
a = dp.array(data, usm_type=usm_type_a)
b = dp.ones(a.shape[:2], dtype=a.dtype, usm_type=usm_type_b)

result = dp.linalg.tensorsolve(a, b)

assert a.usm_type == usm_type_a
assert b.usm_type == usm_type_b
assert result.usm_type == du.get_coerced_usm_type([usm_type_a, usm_type_b])
20 changes: 20 additions & 0 deletions tests/third_party/cupy/linalg_tests/test_solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,26 @@ def test_invalid_shape(self):
self.check_shape((0, 3, 4), (3,), linalg_errors)


@testing.parameterize(
*testing.product(
{
"a_shape": [(2, 3, 6), (3, 4, 4, 3)],
"axes": [None, (0, 2)],
}
)
)
@testing.fix_random()
class TestTensorSolve(unittest.TestCase):
@testing.for_dtypes("ifdFD")
@testing.numpy_cupy_allclose(atol=0.02, type_check=has_support_aspect64())
def test_tensorsolve(self, xp, dtype):
a_shape = self.a_shape
b_shape = self.a_shape[:2]
a = testing.shaped_random(a_shape, xp, dtype=dtype, seed=0)
b = testing.shaped_random(b_shape, xp, dtype=dtype, seed=1)
return xp.linalg.tensorsolve(a, b, axes=self.axes)


@testing.parameterize(
*testing.product(
{
Expand Down