Skip to content

Commit a907757

Browse files
Merge 778d1ca into 367e74e
2 parents 367e74e + 778d1ca commit a907757

File tree

3 files changed

+70
-15
lines changed

3 files changed

+70
-15
lines changed

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1612,12 +1612,12 @@ def solve(a, b):
16121612
----------
16131613
a : (..., M, M) {dpnp.ndarray, usm_ndarray}
16141614
Coefficient matrix.
1615-
b : {(…, M,), (, M, K)} {dpnp.ndarray, usm_ndarray}
1615+
b : {(M,), (..., M, K)} {dpnp.ndarray, usm_ndarray}
16161616
Ordinate or "dependent variable" values.
16171617
16181618
Returns
16191619
-------
1620-
out : {(, M,), (, M, K)} dpnp.ndarray
1620+
out : {(..., M,), (..., M, K)} dpnp.ndarray
16211621
Solution to the system `ax = b`. Returned shape is identical to `b`.
16221622
16231623
See Also
@@ -1644,14 +1644,38 @@ def solve(a, b):
16441644
assert_stacked_2d(a)
16451645
assert_stacked_square(a)
16461646

1647-
if not (
1648-
a.ndim in [b.ndim, b.ndim + 1] and a.shape[:-1] == b.shape[: a.ndim - 1]
1649-
):
1650-
raise dpnp.linalg.LinAlgError(
1651-
"a must have (..., M, M) shape and b must have (..., M) "
1652-
"or (..., M, K)"
1647+
a_shape = a.shape
1648+
b_shape = b.shape
1649+
b_ndim = b.ndim
1650+
1651+
# compatible with numpy>=2.0
1652+
if b_ndim == 0:
1653+
raise ValueError("b must have at least one dimension")
1654+
if b_ndim == 1:
1655+
if a_shape[-1] != b.size:
1656+
raise ValueError(
1657+
"a must have (..., M, M) shape and b must have (M,) "
1658+
"for one-dimensional b"
1659+
)
1660+
b = dpnp.broadcast_to(b, a_shape[:-1])
1661+
return dpnp_solve(a, b)
1662+
1663+
if a_shape[-1] != b_shape[-2]:
1664+
raise ValueError(
1665+
"a must have (..., M, M) shape and b must have (..., M, K) shape"
16531666
)
16541667

1668+
# Use dpnp.broadcast_shapes() to align the resulting batch shapes
1669+
broadcasted_batch_shape = dpnp.broadcast_shapes(a_shape[:-2], b_shape[:-2])
1670+
1671+
a_broadcasted_shape = broadcasted_batch_shape + a_shape[-2:]
1672+
b_broadcasted_shape = broadcasted_batch_shape + b_shape[-2:]
1673+
1674+
if a_shape != a_broadcasted_shape:
1675+
a = dpnp.broadcast_to(a, a_broadcasted_shape)
1676+
if b_shape != b_broadcasted_shape:
1677+
b = dpnp.broadcast_to(b, b_broadcasted_shape)
1678+
16551679
return dpnp_solve(a, b)
16561680

16571681

dpnp/tests/test_linalg.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2694,6 +2694,36 @@ def test_solve(self, dtype):
26942694

26952695
assert_allclose(expected, result, rtol=1e-06)
26962696

2697+
@testing.with_requires("numpy>=2.0")
2698+
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
2699+
@pytest.mark.parametrize(
2700+
"a_shape, b_shape",
2701+
[
2702+
((4, 4), (2, 2, 4, 3)),
2703+
((2, 5, 5), (1, 5, 3)),
2704+
((2, 4, 4), (2, 2, 4, 2)),
2705+
((3, 2, 2), (3, 1, 2, 1)),
2706+
((2, 2, 2, 2, 2), (2,)),
2707+
((2, 2, 2, 2, 2), (2, 3)),
2708+
],
2709+
)
2710+
def test_solve_broadcast(self, a_shape, b_shape, dtype):
2711+
# Set seed_value=81 to prevent
2712+
# random generation of the input singular matrix
2713+
a_np = generate_random_numpy_array(a_shape, dtype, seed_value=81)
2714+
2715+
# Set seed_value=76 to prevent
2716+
# random generation of the input singular matrix
2717+
b_np = generate_random_numpy_array(b_shape, dtype, seed_value=76)
2718+
2719+
a_dp = inp.array(a_np)
2720+
b_dp = inp.array(b_np)
2721+
2722+
expected = numpy.linalg.solve(a_np, b_np)
2723+
result = inp.linalg.solve(a_dp, b_dp)
2724+
2725+
assert_dtype_allclose(result, expected)
2726+
26972727
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
26982728
def test_solve_nrhs_greater_n(self, dtype):
26992729
# Test checking the case when nrhs > n for

dpnp/tests/third_party/cupy/linalg_tests/test_solve.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,9 @@ def test_solve(self):
6060
# for other cases this signature must be followed
6161
# (..., m, m), (..., m, n) -> (..., m, n)
6262
# https://github.com/numpy/numpy/pull/25914
63-
if numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0":
64-
self.check_x((2, 4, 4), (2, 4))
65-
self.check_x((2, 3, 2, 2), (2, 3, 2))
66-
self.check_x((0, 2, 2), (0, 2))
63+
if numpy.lib.NumpyVersion(numpy.__version__) >= "2.0.0":
64+
self.check_x((2, 3, 3), (3,))
65+
self.check_x((2, 5, 3, 3), (3,))
6766

6867
def check_shape(self, a_shape, b_shape, error_types):
6968
for xp, error_type in error_types.items():
@@ -96,11 +95,13 @@ def test_invalid_shape(self):
9695
self.check_shape((3, 3), (2,), value_errors)
9796
self.check_shape((3, 3), (2, 2), value_errors)
9897
self.check_shape((3, 3, 4), (3,), linalg_errors)
99-
# Since numpy >= 2.0, this case does not raise an error
100-
if numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0":
101-
self.check_shape((2, 3, 3), (3,), value_errors)
10298
self.check_shape((3, 3), (0,), value_errors)
10399
self.check_shape((0, 3, 4), (3,), linalg_errors)
100+
# Not allowed since numpy 2.0
101+
if numpy.lib.NumpyVersion(numpy.__version__) >= "2.0.0":
102+
self.check_shape((0, 2, 2), (0, 2), value_errors)
103+
self.check_shape((2, 4, 4), (2, 4), value_errors)
104+
self.check_shape((2, 3, 2, 2), (2, 3, 2), value_errors)
104105

105106

106107
@testing.parameterize(

0 commit comments

Comments
 (0)