Skip to content

Support usm_ndarray batched input for dpnp.linalg #1880

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions dpnp/linalg/dpnp_iface_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1354,7 +1354,7 @@ def tensorinv(a, ind=2):
old_shape = a.shape
inv_shape = old_shape[ind:] + old_shape[:ind]
prod = numpy.prod(old_shape[ind:])
a = a.reshape(prod, -1)
a = dpnp.reshape(a, (prod, -1))
a_inv = inv(a)

return a_inv.reshape(*inv_shape)
Expand Down Expand Up @@ -1428,7 +1428,7 @@ def tensorsolve(a, b, axes=None):
"prod(a.shape[b.ndim:]) == prod(a.shape[:b.ndim])"
)

a = a.reshape(-1, prod)
b = b.ravel()
a = dpnp.reshape(a, (-1, prod))
b = dpnp.ravel(b)
res = solve(a, b)
return res.reshape(old_shape)
20 changes: 10 additions & 10 deletions dpnp/linalg/dpnp_utils_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def _batched_eigh(a, UPLO, eigen_mode, w_type, v_type):
is_cpu_device = a.sycl_device.has_aspect_cpu
orig_shape = a.shape
# get 3d input array by reshape
a = a.reshape(-1, orig_shape[-2], orig_shape[-1])
a = dpnp.reshape(a, (-1, orig_shape[-2], orig_shape[-1]))
a_usm_arr = dpnp.get_usm_ndarray(a)

# allocate a memory for dpnp array of eigenvalues
Expand Down Expand Up @@ -191,7 +191,7 @@ def _batched_inv(a, res_type):

orig_shape = a.shape
# get 3d input arrays by reshape
a = a.reshape(-1, orig_shape[-2], orig_shape[-1])
a = dpnp.reshape(a, (-1, orig_shape[-2], orig_shape[-1]))
batch_size = a.shape[0]
a_usm_arr = dpnp.get_usm_ndarray(a)
a_sycl_queue = a.sycl_queue
Expand Down Expand Up @@ -280,11 +280,11 @@ def _batched_solve(a, b, exec_q, res_usm_type, res_type):
if a.ndim > 3:
# get 3d input arrays by reshape
if a.ndim == b.ndim:
b = b.reshape(-1, b_shape[-2], b_shape[-1])
b = dpnp.reshape(b, (-1, b_shape[-2], b_shape[-1]))
else:
b = b.reshape(-1, b_shape[-1])
b = dpnp.reshape(b, (-1, b_shape[-1]))

a = a.reshape(-1, a_shape[-2], a_shape[-1])
a = dpnp.reshape(a, (-1, a_shape[-2], a_shape[-1]))

a_usm_arr = dpnp.get_usm_ndarray(a)
b_usm_arr = dpnp.get_usm_ndarray(b)
Expand Down Expand Up @@ -386,7 +386,7 @@ def _batched_qr(a, mode="reduced"):
a_sycl_queue = a.sycl_queue

# get 3d input arrays by reshape
a = a.reshape(-1, m, n)
a = dpnp.reshape(a, (-1, m, n))

a = a.swapaxes(-2, -1)
a_usm_arr = dpnp.get_usm_ndarray(a)
Expand Down Expand Up @@ -537,7 +537,7 @@ def _batched_svd(

if a.ndim > 3:
# get 3d input arrays by reshape
a = a.reshape(prod(a.shape[:-2]), a.shape[-2], a.shape[-1])
a = dpnp.reshape(a, (prod(a.shape[:-2]), a.shape[-2], a.shape[-1]))
reshape = True

batch_size = a.shape[0]
Expand Down Expand Up @@ -830,7 +830,7 @@ def _lu_factor(a, res_type):
if a.ndim > 2:
orig_shape = a.shape
# get 3d input arrays by reshape
a = a.reshape(-1, n, n)
a = dpnp.reshape(a, (-1, n, n))
batch_size = a.shape[0]
a_usm_arr = dpnp.get_usm_ndarray(a)

Expand Down Expand Up @@ -1743,7 +1743,7 @@ def dpnp_cholesky_batch(a, upper_lower, res_type):

orig_shape = a.shape
# get 3d input arrays by reshape
a = a.reshape(-1, n, n)
a = dpnp.reshape(a, (-1, n, n))
batch_size = a.shape[0]
a_usm_arr = dpnp.get_usm_ndarray(a)

Expand Down Expand Up @@ -2171,7 +2171,7 @@ def dpnp_matrix_power(a, n):
# `result` will hold the final matrix power,
# while `acc` serves as an accumulator for the intermediate matrix powers.
result = None
acc = a.copy()
acc = dpnp.copy(a)
while n > 0:
n, bit = divmod(n, 2)
if bit:
Expand Down
66 changes: 66 additions & 0 deletions tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,72 @@ def vvsort(val, vec, size, xp):
vec[:, imax] = temp


@pytest.mark.parametrize(
"func, gen_kwargs, func_kwargs",
[
pytest.param("cholesky", {"hermitian": True}, {}),
pytest.param("cond", {}, {}),
pytest.param("det", {}, {}),
pytest.param("eig", {}, {}),
pytest.param("eigh", {"hermitian": True}, {}),
pytest.param("eigvals", {}, {}),
pytest.param("eigvalsh", {"hermitian": True}, {}),
pytest.param("inv", {}, {}),
pytest.param("matrix_power", {}, {"n": 4}),
pytest.param("matrix_rank", {}, {}),
pytest.param("norm", {}, {}),
pytest.param("pinv", {}, {}),
pytest.param("qr", {}, {}),
pytest.param("slogdet", {}, {}),
pytest.param("solve", {}, {}),
pytest.param("svd", {}, {}),
pytest.param("tensorinv", {}, {"ind": 1}),
pytest.param("tensorsolve", {}, {}),
],
)
def test_usm_ndarray_input_batch(func, gen_kwargs, func_kwargs):
shape = (
(2, 2, 3, 3) if func not in ["tensorinv", "tensorsolve"] else (4, 2, 2)
)

if func in ["lstsq", "solve", "tensorsolve"]:
if func == "tensorsolve":
shape_b = (4,)
dpt_args = [
dpt.asarray(
generate_random_numpy_array(
shape, seed_value=81, **gen_kwargs
)
),
dpt.asarray(
generate_random_numpy_array(
shape_b, seed_value=81, **gen_kwargs
)
),
]
else:
dpt_args = [
dpt.asarray(
generate_random_numpy_array(
shape, seed_value=81, **gen_kwargs
)
)
for _ in range(2)
]
else:
dpt_args = [
dpt.asarray(generate_random_numpy_array(shape, **gen_kwargs))
]

result = getattr(inp.linalg, func)(*dpt_args, **func_kwargs)

if isinstance(result, tuple):
for res in result:
assert isinstance(res, inp.ndarray)
else:
assert isinstance(result, inp.ndarray)


class TestCholesky:
@pytest.mark.parametrize(
"array",
Expand Down
Loading