Skip to content

Commit 2ebe7eb

Browse files
Update solve with broadcasting to align numpy 2.0
1 parent cd23361 commit 2ebe7eb

File tree

1 file changed

+47
-8
lines changed

1 file changed

+47
-8
lines changed

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1612,12 +1612,12 @@ def solve(a, b):
16121612
----------
16131613
a : (..., M, M) {dpnp.ndarray, usm_ndarray}
16141614
Coefficient matrix.
1615-
b : {(…, M,), (, M, K)} {dpnp.ndarray, usm_ndarray}
1615+
b : {(M,), (..., M, K)} {dpnp.ndarray, usm_ndarray}
16161616
Ordinate or "dependent variable" values.
16171617
16181618
Returns
16191619
-------
1620-
out : {(, M,), (, M, K)} dpnp.ndarray
1620+
out : {(..., M,), (..., M, K)} dpnp.ndarray
16211621
Solution to the system `ax = b`. Returned shape is identical to `b`.
16221622
16231623
See Also
@@ -1644,14 +1644,53 @@ def solve(a, b):
16441644
assert_stacked_2d(a)
16451645
assert_stacked_square(a)
16461646

1647-
if not (
1648-
a.ndim in [b.ndim, b.ndim + 1] and a.shape[:-1] == b.shape[: a.ndim - 1]
1649-
):
1650-
raise dpnp.linalg.LinAlgError(
1651-
"a must have (..., M, M) shape and b must have (..., M) "
1652-
"or (..., M, K)"
1647+
a_ndim = a.ndim
1648+
b_ndim = b.ndim
1649+
1650+
a_shape = a.shape
1651+
b_shape = b.shape
1652+
1653+
if numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0":
1654+
if not (
1655+
a_ndim in [b_ndim, b_ndim + 1]
1656+
and a_shape[:-1] == b_shape[: a_ndim - 1]
1657+
):
1658+
raise dpnp.linalg.LinAlgError(
1659+
"a must have (..., M, M) shape and b must have (..., M) "
1660+
"or (..., M, K)"
1661+
)
1662+
1663+
else: # compatible with numpy>=2.0
1664+
if b_ndim == 0:
1665+
raise ValueError("b must have at least one dimension")
1666+
if b_ndim == 1:
1667+
if a_shape[-1] != b.size:
1668+
raise ValueError(
1669+
"a must have (..., M, M) shape and b must have (M,) "
1670+
"for one-dimensional b"
1671+
)
1672+
b = dpnp.broadcast_to(b, a_shape[:-1])
1673+
return dpnp_solve(a, b)
1674+
1675+
if a_shape[-1] != b_shape[-2]:
1676+
raise ValueError(
1677+
"a must have (..., M, M) shape and b must have "
1678+
"(..., M, K) shape"
1679+
)
1680+
1681+
# Use dpnp.broadcast_shapes() to align the resulting batch shapes
1682+
broadcasted_batch_shape = dpnp.broadcast_shapes(
1683+
a_shape[:-2], b_shape[:-2]
16531684
)
16541685

1686+
a_broadcasted_shape = broadcasted_batch_shape + a_shape[-2:]
1687+
b_broadcasted_shape = broadcasted_batch_shape + b_shape[-2:]
1688+
1689+
if a_shape != a_broadcasted_shape:
1690+
a = dpnp.broadcast_to(a, a_broadcasted_shape)
1691+
if b_shape != b_broadcasted_shape:
1692+
b = dpnp.broadcast_to(b, b_broadcasted_shape)
1693+
16551694
return dpnp_solve(a, b)
16561695

16571696

0 commit comments

Comments
 (0)