Skip to content

Commit 41831be

Browse files
Update cupy tests for dpnp.linalg.solve()
1 parent d2ef11a commit 41831be

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

tests/third_party/cupy/linalg_tests/test_solve.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,20 @@ def check_x(self, a_shape, b_shape, xp, dtype):
5050
def test_solve(self):
5151
self.check_x((4, 4), (4,))
5252
self.check_x((5, 5), (5, 2))
53-
self.check_x((2, 4, 4), (2, 4))
5453
self.check_x((2, 5, 5), (2, 5, 2))
55-
self.check_x((2, 3, 2, 2), (2, 3, 2))
5654
self.check_x((2, 3, 3, 3), (2, 3, 3, 2))
5755
self.check_x((0, 0), (0,))
5856
self.check_x((0, 0), (0, 2))
59-
self.check_x((0, 2, 2), (0, 2))
6057
self.check_x((0, 2, 2), (0, 2, 3))
58+
# In numpy 2.0 the broadcast ambiguity has been removed and now
59+
# b is treaded as a single vector if and only if it is 1-dimensional;
60+
# for other cases this signature must be followed
61+
# (..., m, m), (..., m, n) -> (..., m, n)
62+
# https://github.com/numpy/numpy/pull/25914
63+
if numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0":
64+
self.check_x((2, 4, 4), (2, 4))
65+
self.check_x((2, 3, 2, 2), (2, 3, 2))
66+
self.check_x((0, 2, 2), (0, 2))
6167

6268
def check_shape(self, a_shape, b_shape, error_types):
6369
for xp, error_type in error_types.items():
@@ -90,7 +96,9 @@ def test_invalid_shape(self):
9096
self.check_shape((3, 3), (2,), value_errors)
9197
self.check_shape((3, 3), (2, 2), value_errors)
9298
self.check_shape((3, 3, 4), (3,), linalg_errors)
93-
self.check_shape((2, 3, 3), (3,), value_errors)
99+
# Since numpy >= 2.0, this case does not raise an error
100+
if numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0":
101+
self.check_shape((2, 3, 3), (3,), value_errors)
94102
self.check_shape((3, 3), (0,), value_errors)
95103
self.check_shape((0, 3, 4), (3,), linalg_errors)
96104

0 commit comments

Comments
 (0)