@@ -50,14 +50,20 @@ def check_x(self, a_shape, b_shape, xp, dtype):
50
50
def test_solve (self ):
51
51
self .check_x ((4 , 4 ), (4 ,))
52
52
self .check_x ((5 , 5 ), (5 , 2 ))
53
- self .check_x ((2 , 4 , 4 ), (2 , 4 ))
54
53
self .check_x ((2 , 5 , 5 ), (2 , 5 , 2 ))
55
- self .check_x ((2 , 3 , 2 , 2 ), (2 , 3 , 2 ))
56
54
self .check_x ((2 , 3 , 3 , 3 ), (2 , 3 , 3 , 2 ))
57
55
self .check_x ((0 , 0 ), (0 ,))
58
56
self .check_x ((0 , 0 ), (0 , 2 ))
59
- self .check_x ((0 , 2 , 2 ), (0 , 2 ))
60
57
self .check_x ((0 , 2 , 2 ), (0 , 2 , 3 ))
58
+ # In numpy 2.0 the broadcast ambiguity has been removed and now
59
+ # b is treaded as a single vector if and only if it is 1-dimensional;
60
+ # for other cases this signature must be followed
61
+ # (..., m, m), (..., m, n) -> (..., m, n)
62
+ # https://github.com/numpy/numpy/pull/25914
63
+ if numpy .lib .NumpyVersion (numpy .__version__ ) < "2.0.0" :
64
+ self .check_x ((2 , 4 , 4 ), (2 , 4 ))
65
+ self .check_x ((2 , 3 , 2 , 2 ), (2 , 3 , 2 ))
66
+ self .check_x ((0 , 2 , 2 ), (0 , 2 ))
61
67
62
68
def check_shape (self , a_shape , b_shape , error_types ):
63
69
for xp , error_type in error_types .items ():
@@ -90,7 +96,9 @@ def test_invalid_shape(self):
90
96
self .check_shape ((3 , 3 ), (2 ,), value_errors )
91
97
self .check_shape ((3 , 3 ), (2 , 2 ), value_errors )
92
98
self .check_shape ((3 , 3 , 4 ), (3 ,), linalg_errors )
93
- self .check_shape ((2 , 3 , 3 ), (3 ,), value_errors )
99
+ # Since numpy >= 2.0, this case does not raise an error
100
+ if numpy .lib .NumpyVersion (numpy .__version__ ) < "2.0.0" :
101
+ self .check_shape ((2 , 3 , 3 ), (3 ,), value_errors )
94
102
self .check_shape ((3 , 3 ), (0 ,), value_errors )
95
103
self .check_shape ((0 , 3 , 4 ), (3 ,), linalg_errors )
96
104
0 commit comments