Skip to content

Commit 96a2e41

Browse files
Support usm_ndarray batched input for dpnp.linalg (#1880)
* Add usm_ndarray input support for linalg * Add test_usm_ndarray_input_batch to test_linalg.py * Add usm_ndarray input support for dpnp_iface_linearalgebra * Add test_usm_ndarray_linearalgebra_batch to test_linalg.py * Apply comments --------- Co-authored-by: Anton <[email protected]>
1 parent a813fae commit 96a2e41

File tree

4 files changed

+100
-19
lines changed

4 files changed

+100
-19
lines changed

dpnp/dpnp_iface_linearalgebra.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -892,13 +892,13 @@ def outer(a, b, out=None):
892892
dpnp.check_supported_arrays_type(a, b, scalar_type=True, all_scalars=False)
893893
if dpnp.isscalar(a):
894894
x1 = a
895-
x2 = b.ravel()[None, :]
895+
x2 = dpnp.ravel(b)[None, :]
896896
elif dpnp.isscalar(b):
897-
x1 = a.ravel()[:, None]
897+
x1 = dpnp.ravel(a)[:, None]
898898
x2 = b
899899
else:
900-
x1 = a.ravel()
901-
x2 = b.ravel()
900+
x1 = dpnp.ravel(a)
901+
x2 = dpnp.ravel(b)
902902

903903
return dpnp.multiply.outer(x1, x2, out=out)
904904

@@ -1056,8 +1056,8 @@ def tensordot(a, b, axes=2):
10561056
newshape_b = (n1, n2)
10571057
oldb = [b_shape[axis] for axis in notin]
10581058

1059-
at = a.transpose(newaxes_a).reshape(newshape_a)
1060-
bt = b.transpose(newaxes_b).reshape(newshape_b)
1059+
at = dpnp.transpose(a, newaxes_a).reshape(newshape_a)
1060+
bt = dpnp.transpose(b, newaxes_b).reshape(newshape_b)
10611061
res = dpnp.matmul(at, bt)
10621062

10631063
return res.reshape(olda + oldb)

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1354,7 +1354,7 @@ def tensorinv(a, ind=2):
13541354
old_shape = a.shape
13551355
inv_shape = old_shape[ind:] + old_shape[:ind]
13561356
prod = numpy.prod(old_shape[ind:])
1357-
a = a.reshape(prod, -1)
1357+
a = dpnp.reshape(a, (prod, -1))
13581358
a_inv = inv(a)
13591359

13601360
return a_inv.reshape(*inv_shape)
@@ -1428,7 +1428,7 @@ def tensorsolve(a, b, axes=None):
14281428
"prod(a.shape[b.ndim:]) == prod(a.shape[:b.ndim])"
14291429
)
14301430

1431-
a = a.reshape(-1, prod)
1432-
b = b.ravel()
1431+
a = dpnp.reshape(a, (-1, prod))
1432+
b = dpnp.ravel(b)
14331433
res = solve(a, b)
14341434
return res.reshape(old_shape)

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def _batched_eigh(a, UPLO, eigen_mode, w_type, v_type):
9999
is_cpu_device = a.sycl_device.has_aspect_cpu
100100
orig_shape = a.shape
101101
# get 3d input array by reshape
102-
a = a.reshape(-1, orig_shape[-2], orig_shape[-1])
102+
a = dpnp.reshape(a, (-1, orig_shape[-2], orig_shape[-1]))
103103
a_usm_arr = dpnp.get_usm_ndarray(a)
104104

105105
# allocate a memory for dpnp array of eigenvalues
@@ -191,7 +191,7 @@ def _batched_inv(a, res_type):
191191

192192
orig_shape = a.shape
193193
# get 3d input arrays by reshape
194-
a = a.reshape(-1, orig_shape[-2], orig_shape[-1])
194+
a = dpnp.reshape(a, (-1, orig_shape[-2], orig_shape[-1]))
195195
batch_size = a.shape[0]
196196
a_usm_arr = dpnp.get_usm_ndarray(a)
197197
a_sycl_queue = a.sycl_queue
@@ -280,11 +280,11 @@ def _batched_solve(a, b, exec_q, res_usm_type, res_type):
280280
if a.ndim > 3:
281281
# get 3d input arrays by reshape
282282
if a.ndim == b.ndim:
283-
b = b.reshape(-1, b_shape[-2], b_shape[-1])
283+
b = dpnp.reshape(b, (-1, b_shape[-2], b_shape[-1]))
284284
else:
285-
b = b.reshape(-1, b_shape[-1])
285+
b = dpnp.reshape(b, (-1, b_shape[-1]))
286286

287-
a = a.reshape(-1, a_shape[-2], a_shape[-1])
287+
a = dpnp.reshape(a, (-1, a_shape[-2], a_shape[-1]))
288288

289289
a_usm_arr = dpnp.get_usm_ndarray(a)
290290
b_usm_arr = dpnp.get_usm_ndarray(b)
@@ -386,7 +386,7 @@ def _batched_qr(a, mode="reduced"):
386386
a_sycl_queue = a.sycl_queue
387387

388388
# get 3d input arrays by reshape
389-
a = a.reshape(-1, m, n)
389+
a = dpnp.reshape(a, (-1, m, n))
390390

391391
a = a.swapaxes(-2, -1)
392392
a_usm_arr = dpnp.get_usm_ndarray(a)
@@ -537,7 +537,7 @@ def _batched_svd(
537537

538538
if a.ndim > 3:
539539
# get 3d input arrays by reshape
540-
a = a.reshape(prod(a.shape[:-2]), a.shape[-2], a.shape[-1])
540+
a = dpnp.reshape(a, (prod(a.shape[:-2]), a.shape[-2], a.shape[-1]))
541541
reshape = True
542542

543543
batch_size = a.shape[0]
@@ -830,7 +830,7 @@ def _lu_factor(a, res_type):
830830
if a.ndim > 2:
831831
orig_shape = a.shape
832832
# get 3d input arrays by reshape
833-
a = a.reshape(-1, n, n)
833+
a = dpnp.reshape(a, (-1, n, n))
834834
batch_size = a.shape[0]
835835
a_usm_arr = dpnp.get_usm_ndarray(a)
836836

@@ -1743,7 +1743,7 @@ def dpnp_cholesky_batch(a, upper_lower, res_type):
17431743

17441744
orig_shape = a.shape
17451745
# get 3d input arrays by reshape
1746-
a = a.reshape(-1, n, n)
1746+
a = dpnp.reshape(a, (-1, n, n))
17471747
batch_size = a.shape[0]
17481748
a_usm_arr = dpnp.get_usm_ndarray(a)
17491749

@@ -2171,7 +2171,7 @@ def dpnp_matrix_power(a, n):
21712171
# `result` will hold the final matrix power,
21722172
# while `acc` serves as an accumulator for the intermediate matrix powers.
21732173
result = None
2174-
acc = a.copy()
2174+
acc = dpnp.copy(a)
21752175
while n > 0:
21762176
n, bit = divmod(n, 2)
21772177
if bit:

tests/test_linalg.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,87 @@ def vvsort(val, vec, size, xp):
5757
vec[:, imax] = temp
5858

5959

60+
# check linear algebra functions from dpnp.linalg
61+
# with multidimensional usm_ndarray as input
62+
@pytest.mark.parametrize(
63+
"func, gen_kwargs, func_kwargs",
64+
[
65+
pytest.param("cholesky", {"hermitian": True}, {}),
66+
pytest.param("cond", {}, {}),
67+
pytest.param("det", {}, {}),
68+
pytest.param("eig", {}, {}),
69+
pytest.param("eigh", {"hermitian": True}, {}),
70+
pytest.param("eigvals", {}, {}),
71+
pytest.param("eigvalsh", {"hermitian": True}, {}),
72+
pytest.param("inv", {}, {}),
73+
pytest.param("matrix_power", {}, {"n": 4}),
74+
pytest.param("matrix_rank", {}, {}),
75+
pytest.param("norm", {}, {}),
76+
pytest.param("pinv", {}, {}),
77+
pytest.param("qr", {}, {}),
78+
pytest.param("slogdet", {}, {}),
79+
pytest.param("solve", {}, {}),
80+
pytest.param("svd", {}, {}),
81+
pytest.param("tensorinv", {}, {"ind": 1}),
82+
pytest.param("tensorsolve", {}, {}),
83+
],
84+
)
85+
def test_usm_ndarray_linalg_batch(func, gen_kwargs, func_kwargs):
86+
shape = (
87+
(2, 2, 3, 3) if func not in ["tensorinv", "tensorsolve"] else (4, 2, 2)
88+
)
89+
90+
if func == "tensorsolve":
91+
shape_b = (4,)
92+
dpt_args = [
93+
dpt.asarray(
94+
generate_random_numpy_array(shape, seed_value=81, **gen_kwargs)
95+
),
96+
dpt.asarray(
97+
generate_random_numpy_array(
98+
shape_b, seed_value=81, **gen_kwargs
99+
)
100+
),
101+
]
102+
elif func in ["lstsq", "solve"]:
103+
dpt_args = [
104+
dpt.asarray(
105+
generate_random_numpy_array(shape, seed_value=81, **gen_kwargs)
106+
)
107+
for _ in range(2)
108+
]
109+
else:
110+
dpt_args = [
111+
dpt.asarray(generate_random_numpy_array(shape, **gen_kwargs))
112+
]
113+
114+
result = getattr(inp.linalg, func)(*dpt_args, **func_kwargs)
115+
116+
if isinstance(result, tuple):
117+
for res in result:
118+
assert isinstance(res, inp.ndarray)
119+
else:
120+
assert isinstance(result, inp.ndarray)
121+
122+
123+
# check linear algebra functions from dpnp
124+
# with multidimensional usm_ndarray as input
125+
@pytest.mark.parametrize(
126+
"func", ["dot", "inner", "kron", "matmul", "outer", "tensordot", "vdot"]
127+
)
128+
def test_usm_ndarray_linearalgebra_batch(func):
129+
shape = (2, 2, 2, 2)
130+
131+
dpt_args = [
132+
dpt.asarray(generate_random_numpy_array(shape, seed_value=81))
133+
for _ in range(2)
134+
]
135+
136+
result = getattr(inp, func)(*dpt_args)
137+
138+
assert isinstance(result, inp.ndarray)
139+
140+
60141
class TestCholesky:
61142
@pytest.mark.parametrize(
62143
"array",

0 commit comments

Comments
 (0)