Skip to content

Commit 7b3c48c

Browse files
Update cupy tests for dpnp.linalg.solve() (#2074)
* Update cupy tests for dpnp.linalg.solve() * Update CHANGELOG.md
1 parent bd6546e commit 7b3c48c

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ In addition, this release completes implementation of `dpnp.fft` module and adds
107107
* Improved implementation of `dpnp.kron` to avoid unnecessary copy for non-contiguous arrays [#2059](https://github.com/IntelPython/dpnp/pull/2059)
108108
* Updated the test suit for `dpnp.fft` module [#2071](https://github.com/IntelPython/dpnp/pull/2071)
109109
* Reworked `dpnp.clip` implementation to align with Python Array API 2023.12 specification [#2048](https://github.com/IntelPython/dpnp/pull/2048)
110+
* Skipped outdated tests for `dpnp.linalg.solve` due to compatibility issues with NumPy 2.0 [#2074](https://github.com/IntelPython/dpnp/pull/2074)
110111

111112
### Fixed
112113

tests/third_party/cupy/linalg_tests/test_solve.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,20 @@ def check_x(self, a_shape, b_shape, xp, dtype):
5050
def test_solve(self):
5151
self.check_x((4, 4), (4,))
5252
self.check_x((5, 5), (5, 2))
53-
self.check_x((2, 4, 4), (2, 4))
5453
self.check_x((2, 5, 5), (2, 5, 2))
55-
self.check_x((2, 3, 2, 2), (2, 3, 2))
5654
self.check_x((2, 3, 3, 3), (2, 3, 3, 2))
5755
self.check_x((0, 0), (0,))
5856
self.check_x((0, 0), (0, 2))
59-
self.check_x((0, 2, 2), (0, 2))
6057
self.check_x((0, 2, 2), (0, 2, 3))
58+
# In numpy 2.0 the broadcast ambiguity has been removed and now
59+
# b is treaded as a single vector if and only if it is 1-dimensional;
60+
# for other cases this signature must be followed
61+
# (..., m, m), (..., m, n) -> (..., m, n)
62+
# https://github.com/numpy/numpy/pull/25914
63+
if numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0":
64+
self.check_x((2, 4, 4), (2, 4))
65+
self.check_x((2, 3, 2, 2), (2, 3, 2))
66+
self.check_x((0, 2, 2), (0, 2))
6167

6268
def check_shape(self, a_shape, b_shape, error_types):
6369
for xp, error_type in error_types.items():
@@ -90,7 +96,9 @@ def test_invalid_shape(self):
9096
self.check_shape((3, 3), (2,), value_errors)
9197
self.check_shape((3, 3), (2, 2), value_errors)
9298
self.check_shape((3, 3, 4), (3,), linalg_errors)
93-
self.check_shape((2, 3, 3), (3,), value_errors)
99+
# Since numpy >= 2.0, this case does not raise an error
100+
if numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0":
101+
self.check_shape((2, 3, 3), (3,), value_errors)
94102
self.check_shape((3, 3), (0,), value_errors)
95103
self.check_shape((0, 3, 4), (3,), linalg_errors)
96104

0 commit comments

Comments
 (0)