@@ -1612,12 +1612,12 @@ def solve(a, b):
1612
1612
----------
1613
1613
a : (..., M, M) {dpnp.ndarray, usm_ndarray}
1614
1614
Coefficient matrix.
1615
- b : {(…, M,), (… , M, K)} {dpnp.ndarray, usm_ndarray}
1615
+ b : {(M,), (... , M, K)} {dpnp.ndarray, usm_ndarray}
1616
1616
Ordinate or "dependent variable" values.
1617
1617
1618
1618
Returns
1619
1619
-------
1620
- out : {(… , M,), (… , M, K)} dpnp.ndarray
1620
+ out : {(... , M,), (... , M, K)} dpnp.ndarray
1621
1621
Solution to the system `ax = b`. Returned shape is identical to `b`.
1622
1622
1623
1623
See Also
@@ -1644,14 +1644,53 @@ def solve(a, b):
1644
1644
assert_stacked_2d (a )
1645
1645
assert_stacked_square (a )
1646
1646
1647
- if not (
1648
- a .ndim in [b .ndim , b .ndim + 1 ] and a .shape [:- 1 ] == b .shape [: a .ndim - 1 ]
1649
- ):
1650
- raise dpnp .linalg .LinAlgError (
1651
- "a must have (..., M, M) shape and b must have (..., M) "
1652
- "or (..., M, K)"
1647
+ a_ndim = a .ndim
1648
+ b_ndim = b .ndim
1649
+
1650
+ a_shape = a .shape
1651
+ b_shape = b .shape
1652
+
1653
+ if numpy .lib .NumpyVersion (numpy .__version__ ) < "2.0.0" :
1654
+ if not (
1655
+ a_ndim in [b_ndim , b_ndim + 1 ]
1656
+ and a_shape [:- 1 ] == b_shape [: a_ndim - 1 ]
1657
+ ):
1658
+ raise dpnp .linalg .LinAlgError (
1659
+ "a must have (..., M, M) shape and b must have (..., M) "
1660
+ "or (..., M, K)"
1661
+ )
1662
+
1663
+ else : # compatible with numpy>=2.0
1664
+ if b_ndim == 0 :
1665
+ raise ValueError ("b must have at least one dimension" )
1666
+ if b_ndim == 1 :
1667
+ if a_shape [- 1 ] != b .size :
1668
+ raise ValueError (
1669
+ "a must have (..., M, M) shape and b must have (M,) "
1670
+ "for one-dimensional b"
1671
+ )
1672
+ b = dpnp .broadcast_to (b , a_shape [:- 1 ])
1673
+ return dpnp_solve (a , b )
1674
+
1675
+ if a_shape [- 1 ] != b_shape [- 2 ]:
1676
+ raise ValueError (
1677
+ "a must have (..., M, M) shape and b must have "
1678
+ "(..., M, K) shape"
1679
+ )
1680
+
1681
+ # Use dpnp.broadcast_shapes() to align the resulting batch shapes
1682
+ broadcasted_batch_shape = dpnp .broadcast_shapes (
1683
+ a_shape [:- 2 ], b_shape [:- 2 ]
1653
1684
)
1654
1685
1686
+ a_broadcasted_shape = broadcasted_batch_shape + a_shape [- 2 :]
1687
+ b_broadcasted_shape = broadcasted_batch_shape + b_shape [- 2 :]
1688
+
1689
+ if a_shape != a_broadcasted_shape :
1690
+ a = dpnp .broadcast_to (a , a_broadcasted_shape )
1691
+ if b_shape != b_broadcasted_shape :
1692
+ b = dpnp .broadcast_to (b , b_broadcasted_shape )
1693
+
1655
1694
return dpnp_solve (a , b )
1656
1695
1657
1696
0 commit comments