Skip to content

Commit 4e02172

Browse files
Update dpnp.linalg.solve() to align NumPy 2.0 (#2198)
* Update solve with broadcasting to align numpy 2.0 * Update and add more tests for solve() * Keep only solve() logic for numpy 2.0 compatibility * Update cupy tests for solve() * Align TestSolve with cupy tests * Cover case b.ndim==0 * Add notes for solve()
1 parent 88911fb commit 4e02172

File tree

5 files changed

+109
-52
lines changed

5 files changed

+109
-52
lines changed

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1612,18 +1612,24 @@ def solve(a, b):
16121612
----------
16131613
a : (..., M, M) {dpnp.ndarray, usm_ndarray}
16141614
Coefficient matrix.
1615-
b : {(…, M,), (, M, K)} {dpnp.ndarray, usm_ndarray}
1615+
b : {(M,), (..., M, K)} {dpnp.ndarray, usm_ndarray}
16161616
Ordinate or "dependent variable" values.
16171617
16181618
Returns
16191619
-------
1620-
out : {(, M,), (, M, K)} dpnp.ndarray
1620+
out : {(..., M,), (..., M, K)} dpnp.ndarray
16211621
Solution to the system `ax = b`. Returned shape is identical to `b`.
16221622
16231623
See Also
16241624
--------
16251625
:obj:`dpnp.dot` : Returns the dot product of two arrays.
16261626
1627+
Notes
1628+
-----
1629+
The `b` array is only treated as a shape (M,) column vector if it is
1630+
exactly 1-dimensional. In all other instances it is treated as a stack
1631+
of (M, K) matrices.
1632+
16271633
Examples
16281634
--------
16291635
>>> import dpnp as dp
@@ -1644,14 +1650,38 @@ def solve(a, b):
16441650
assert_stacked_2d(a)
16451651
assert_stacked_square(a)
16461652

1647-
if not (
1648-
a.ndim in [b.ndim, b.ndim + 1] and a.shape[:-1] == b.shape[: a.ndim - 1]
1649-
):
1650-
raise dpnp.linalg.LinAlgError(
1651-
"a must have (..., M, M) shape and b must have (..., M) "
1652-
"or (..., M, K)"
1653+
a_shape = a.shape
1654+
b_shape = b.shape
1655+
b_ndim = b.ndim
1656+
1657+
# compatible with numpy>=2.0
1658+
if b_ndim == 0:
1659+
raise ValueError("b must have at least one dimension")
1660+
if b_ndim == 1:
1661+
if a_shape[-1] != b.size:
1662+
raise ValueError(
1663+
"a must have (..., M, M) shape and b must have (M,) "
1664+
"for one-dimensional b"
1665+
)
1666+
b = dpnp.broadcast_to(b, a_shape[:-1])
1667+
return dpnp_solve(a, b)
1668+
1669+
if a_shape[-1] != b_shape[-2]:
1670+
raise ValueError(
1671+
"a must have (..., M, M) shape and b must have (..., M, K) shape"
16531672
)
16541673

1674+
# Use dpnp.broadcast_shapes() to align the resulting batch shapes
1675+
broadcasted_batch_shape = dpnp.broadcast_shapes(a_shape[:-2], b_shape[:-2])
1676+
1677+
a_broadcasted_shape = broadcasted_batch_shape + a_shape[-2:]
1678+
b_broadcasted_shape = broadcasted_batch_shape + b_shape[-2:]
1679+
1680+
if a_shape != a_broadcasted_shape:
1681+
a = dpnp.broadcast_to(a, a_broadcasted_shape)
1682+
if b_shape != b_broadcasted_shape:
1683+
b = dpnp.broadcast_to(b, b_broadcasted_shape)
1684+
16551685
return dpnp_solve(a, b)
16561686

16571687

dpnp/tests/test_linalg.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2694,6 +2694,36 @@ def test_solve(self, dtype):
26942694

26952695
assert_allclose(expected, result, rtol=1e-06)
26962696

2697+
@testing.with_requires("numpy>=2.0")
2698+
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
2699+
@pytest.mark.parametrize(
2700+
"a_shape, b_shape",
2701+
[
2702+
((4, 4), (2, 2, 4, 3)),
2703+
((2, 5, 5), (1, 5, 3)),
2704+
((2, 4, 4), (2, 2, 4, 2)),
2705+
((3, 2, 2), (3, 1, 2, 1)),
2706+
((2, 2, 2, 2, 2), (2,)),
2707+
((2, 2, 2, 2, 2), (2, 3)),
2708+
],
2709+
)
2710+
def test_solve_broadcast(self, a_shape, b_shape, dtype):
2711+
# Set seed_value=81 to prevent
2712+
# random generation of the input singular matrix
2713+
a_np = generate_random_numpy_array(a_shape, dtype, seed_value=81)
2714+
2715+
# Set seed_value=76 to prevent
2716+
# random generation of the input singular matrix
2717+
b_np = generate_random_numpy_array(b_shape, dtype, seed_value=76)
2718+
2719+
a_dp = inp.array(a_np)
2720+
b_dp = inp.array(b_np)
2721+
2722+
expected = numpy.linalg.solve(a_np, b_np)
2723+
result = inp.linalg.solve(a_dp, b_dp)
2724+
2725+
assert_dtype_allclose(result, expected)
2726+
26972727
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
26982728
def test_solve_nrhs_greater_n(self, dtype):
26992729
# Test checking the case when nrhs > n for
@@ -2800,6 +2830,10 @@ def test_solve_errors(self):
28002830
inp.linalg.LinAlgError, inp.linalg.solve, a_dp_ndim_1, b_dp
28012831
)
28022832

2833+
# b.ndim == 0
2834+
b_dp_ndim_0 = inp.array(2)
2835+
assert_raises(ValueError, inp.linalg.solve, a_dp, b_dp_ndim_0)
2836+
28032837

28042838
class TestSlogdet:
28052839
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))

dpnp/tests/test_sycl_queue.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2392,40 +2392,34 @@ def test_where(device):
23922392
ids=[device.filter_string for device in valid_devices],
23932393
)
23942394
@pytest.mark.parametrize(
2395-
"matrix, vector",
2395+
"matrix, rhs",
23962396
[
23972397
([[1, 2], [3, 5]], numpy.empty((2, 0))),
23982398
([[1, 2], [3, 5]], [1, 2]),
23992399
(
24002400
[
2401-
[[1, 1, 1], [0, 2, 5], [2, 5, -1]],
2402-
[[3, -1, 1], [1, 2, 3], [2, 3, 1]],
2403-
[[1, 4, 1], [1, 2, -2], [4, 1, 2]],
2401+
[[1, 1], [0, 2]],
2402+
[[3, -1], [1, 2]],
2403+
],
2404+
[
2405+
[[6, -4], [9, -6]],
2406+
[[15, 1], [15, 1]],
24042407
],
2405-
[[6, -4, 27], [9, -6, 15], [15, 1, 11]],
24062408
),
24072409
],
24082410
ids=[
2409-
"2D_Matrix_Empty_Vector",
2410-
"2D_Matrix_1D_Vector",
2411-
"3D_Matrix_and_Vectors",
2411+
"2D_Matrix_Empty_RHS",
2412+
"2D_Matrix_1D_RHS",
2413+
"3D_Matrix_and_3D_RHS",
24122414
],
24132415
)
2414-
def test_solve(matrix, vector, device):
2416+
def test_solve(matrix, rhs, device):
24152417
a_np = numpy.array(matrix)
2416-
b_np = numpy.array(vector)
2418+
b_np = numpy.array(rhs)
24172419

24182420
a_dp = dpnp.array(a_np, device=device)
24192421
b_dp = dpnp.array(b_np, device=device)
24202422

2421-
# In numpy 2.0 the broadcast ambiguity has been removed and now
2422-
# b is treaded as a single vector if and only if it is 1-dimensional;
2423-
# for other cases this signature must be followed
2424-
# (..., m, m), (..., m, n) -> (..., m, n)
2425-
# https://github.com/numpy/numpy/pull/25914
2426-
if a_dp.ndim > 2 and numpy.lib.NumpyVersion(numpy.__version__) >= "2.0.0":
2427-
pytest.skip("SAT-6928")
2428-
24292423
result = dpnp.linalg.solve(a_dp, b_dp)
24302424
expected = numpy.linalg.solve(a_np, b_np)
24312425
assert_dtype_allclose(result, expected)

dpnp/tests/test_usm_type.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1285,37 +1285,39 @@ def test_fftshift(self, func, usm_type):
12851285
"usm_type_matrix", list_of_usm_types, ids=list_of_usm_types
12861286
)
12871287
@pytest.mark.parametrize(
1288-
"usm_type_vector", list_of_usm_types, ids=list_of_usm_types
1288+
"usm_type_rhs", list_of_usm_types, ids=list_of_usm_types
12891289
)
12901290
@pytest.mark.parametrize(
1291-
"matrix, vector",
1291+
"matrix, rhs",
12921292
[
1293-
([[1, 2], [3, 5]], dp.empty((2, 0))),
1293+
([[1, 2], [3, 5]], numpy.empty((2, 0))),
12941294
([[1, 2], [3, 5]], [1, 2]),
12951295
(
12961296
[
1297-
[[1, 1, 1], [0, 2, 5], [2, 5, -1]],
1298-
[[3, -1, 1], [1, 2, 3], [2, 3, 1]],
1299-
[[1, 4, 1], [1, 2, -2], [4, 1, 2]],
1297+
[[1, 1], [0, 2]],
1298+
[[3, -1], [1, 2]],
1299+
],
1300+
[
1301+
[[6, -4], [9, -6]],
1302+
[[15, 1], [15, 1]],
13001303
],
1301-
[[6, -4, 27], [9, -6, 15], [15, 1, 11]],
13021304
),
13031305
],
13041306
ids=[
1305-
"2D_Matrix_Empty_Vector",
1306-
"2D_Matrix_1D_Vector",
1307-
"3D_Matrix_and_Vectors",
1307+
"2D_Matrix_Empty_RHS",
1308+
"2D_Matrix_1D_RHS",
1309+
"3D_Matrix_and_3D_RHS",
13081310
],
13091311
)
1310-
def test_solve(matrix, vector, usm_type_matrix, usm_type_vector):
1312+
def test_solve(matrix, rhs, usm_type_matrix, usm_type_rhs):
13111313
x = dp.array(matrix, usm_type=usm_type_matrix)
1312-
y = dp.array(vector, usm_type=usm_type_vector)
1314+
y = dp.array(rhs, usm_type=usm_type_rhs)
13131315
z = dp.linalg.solve(x, y)
13141316

13151317
assert x.usm_type == usm_type_matrix
1316-
assert y.usm_type == usm_type_vector
1318+
assert y.usm_type == usm_type_rhs
13171319
assert z.usm_type == du.get_coerced_usm_type(
1318-
[usm_type_matrix, usm_type_vector]
1320+
[usm_type_matrix, usm_type_rhs]
13191321
)
13201322

13211323

dpnp/tests/third_party/cupy/linalg_tests/test_solve.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def check_x(self, a_shape, b_shape, xp, dtype):
4747
testing.assert_array_equal(b_copy, b)
4848
return result
4949

50+
@testing.with_requires("numpy>=2.0")
5051
def test_solve(self):
5152
self.check_x((4, 4), (4,))
5253
self.check_x((5, 5), (5, 2))
@@ -55,15 +56,9 @@ def test_solve(self):
5556
self.check_x((0, 0), (0,))
5657
self.check_x((0, 0), (0, 2))
5758
self.check_x((0, 2, 2), (0, 2, 3))
58-
# In numpy 2.0 the broadcast ambiguity has been removed and now
59-
# b is treaded as a single vector if and only if it is 1-dimensional;
60-
# for other cases this signature must be followed
61-
# (..., m, m), (..., m, n) -> (..., m, n)
62-
# https://github.com/numpy/numpy/pull/25914
63-
if numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0":
64-
self.check_x((2, 4, 4), (2, 4))
65-
self.check_x((2, 3, 2, 2), (2, 3, 2))
66-
self.check_x((0, 2, 2), (0, 2))
59+
# Allowed since numpy 2
60+
self.check_x((2, 3, 3), (3,))
61+
self.check_x((2, 5, 3, 3), (3,))
6762

6863
def check_shape(self, a_shape, b_shape, error_types):
6964
for xp, error_type in error_types.items():
@@ -82,6 +77,7 @@ def test_solve_singular_empty(self, xp):
8277
# LinAlgError("Singular matrix") is not raised
8378
return xp.linalg.solve(a, b)
8479

80+
@testing.with_requires("numpy>=2.0")
8581
def test_invalid_shape(self):
8682
linalg_errors = {
8783
numpy: numpy.linalg.LinAlgError,
@@ -96,11 +92,12 @@ def test_invalid_shape(self):
9692
self.check_shape((3, 3), (2,), value_errors)
9793
self.check_shape((3, 3), (2, 2), value_errors)
9894
self.check_shape((3, 3, 4), (3,), linalg_errors)
99-
# Since numpy >= 2.0, this case does not raise an error
100-
if numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0":
101-
self.check_shape((2, 3, 3), (3,), value_errors)
10295
self.check_shape((3, 3), (0,), value_errors)
10396
self.check_shape((0, 3, 4), (3,), linalg_errors)
97+
# Not allowed since numpy 2.0
98+
self.check_shape((0, 2, 2), (0, 2), value_errors)
99+
self.check_shape((2, 4, 4), (2, 4), value_errors)
100+
self.check_shape((2, 3, 2, 2), (2, 3, 2), value_errors)
104101

105102

106103
@testing.parameterize(

0 commit comments

Comments
 (0)