Skip to content

Commit c03ba5c

Browse files
Update and add more tests for solve()
1 parent 2ebe7eb commit c03ba5c

File tree

2 files changed

+62
-2
lines changed

2 files changed

+62
-2
lines changed

dpnp/tests/test_linalg.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2694,6 +2694,36 @@ def test_solve(self, dtype):
26942694

26952695
assert_allclose(expected, result, rtol=1e-06)
26962696

2697+
@testing.with_requires("numpy>=2.0")
2698+
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
2699+
@pytest.mark.parametrize(
2700+
"a_shape, b_shape",
2701+
[
2702+
((4, 4), (2, 2, 4, 3)),
2703+
((2, 5, 5), (1, 5, 3)),
2704+
((2, 4, 4), (2, 2, 4, 2)),
2705+
((3, 2, 2), (3, 1, 2, 1)),
2706+
((2, 2, 2, 2, 2), (2,)),
2707+
((2, 2, 2, 2, 2), (2, 3)),
2708+
],
2709+
)
2710+
def test_solve_broadcast(self, a_shape, b_shape, dtype):
2711+
# Set seed_value=81 to prevent
2712+
# random generation of the input singular matrix
2713+
a_np = generate_random_numpy_array(a_shape, dtype, seed_value=81)
2714+
2715+
# Set seed_value=76 to prevent
2716+
# random generation of the input singular matrix
2717+
b_np = generate_random_numpy_array(b_shape, dtype, seed_value=76)
2718+
2719+
a_dp = inp.array(a_np)
2720+
b_dp = inp.array(b_np)
2721+
2722+
expected = numpy.linalg.solve(a_np, b_np)
2723+
result = inp.linalg.solve(a_dp, b_dp)
2724+
2725+
assert_dtype_allclose(result, expected)
2726+
26972727
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
26982728
def test_solve_nrhs_greater_n(self, dtype):
26992729
# Test checking the case when nrhs > n for

dpnp/tests/third_party/cupy/linalg_tests/test_solve.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ def test_solve(self):
6464
self.check_x((2, 4, 4), (2, 4))
6565
self.check_x((2, 3, 2, 2), (2, 3, 2))
6666
self.check_x((0, 2, 2), (0, 2))
67+
else: # Allowed since numpy 2.0
68+
self.check_x((2, 3, 3), (3,))
69+
self.check_x((2, 5, 3, 3), (3,))
6770

6871
def check_shape(self, a_shape, b_shape, error_types):
6972
for xp, error_type in error_types.items():
@@ -96,11 +99,38 @@ def test_invalid_shape(self):
9699
self.check_shape((3, 3), (2,), value_errors)
97100
self.check_shape((3, 3), (2, 2), value_errors)
98101
self.check_shape((3, 3, 4), (3,), linalg_errors)
102+
self.check_shape((3, 3), (0,), value_errors)
103+
self.check_shape((0, 3, 4), (3,), linalg_errors)
99104
# Since numpy >= 2.0, this case does not raise an error
100105
if numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0":
101106
self.check_shape((2, 3, 3), (3,), value_errors)
102-
self.check_shape((3, 3), (0,), value_errors)
103-
self.check_shape((0, 3, 4), (3,), linalg_errors)
107+
else:
108+
# Not allowed since numpy 2
109+
self.check_shape(
110+
(0, 2, 2),
111+
(
112+
0,
113+
2,
114+
),
115+
value_errors,
116+
)
117+
self.check_shape(
118+
(2, 4, 4),
119+
(
120+
2,
121+
4,
122+
),
123+
value_errors,
124+
)
125+
self.check_shape(
126+
(2, 3, 2, 2),
127+
(
128+
2,
129+
3,
130+
2,
131+
),
132+
value_errors,
133+
)
104134

105135

106136
@testing.parameterize(

0 commit comments

Comments
 (0)