Skip to content

Commit 6416534

Browse files
committed
Update linalg tests
1 parent e6dda79 commit 6416534

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

dpnp/tests/test_linalg.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1935,7 +1935,7 @@ def test_matrix_rank(self, data, dtype):
19351935

19361936
np_rank = numpy.linalg.matrix_rank(a)
19371937
dp_rank = dpnp.linalg.matrix_rank(a_dp)
1938-
assert np_rank == dp_rank
1938+
assert dp_rank.asnumpy() == np_rank
19391939

19401940
@pytest.mark.parametrize("dtype", get_all_dtypes())
19411941
@pytest.mark.parametrize(
@@ -1953,7 +1953,7 @@ def test_matrix_rank_hermitian(self, data, dtype):
19531953

19541954
np_rank = numpy.linalg.matrix_rank(a, hermitian=True)
19551955
dp_rank = dpnp.linalg.matrix_rank(a_dp, hermitian=True)
1956-
assert np_rank == dp_rank
1956+
assert dp_rank.asnumpy() == np_rank
19571957

19581958
@pytest.mark.parametrize(
19591959
"high_tol, low_tol",
@@ -1986,15 +1986,15 @@ def test_matrix_rank_tolerance(self, high_tol, low_tol):
19861986
dp_rank_high_tol = dpnp.linalg.matrix_rank(
19871987
a_dp, hermitian=True, tol=dp_high_tol
19881988
)
1989-
assert np_rank_high_tol == dp_rank_high_tol
1989+
assert dp_rank_high_tol.asnumpy() == np_rank_high_tol
19901990

19911991
np_rank_low_tol = numpy.linalg.matrix_rank(
19921992
a, hermitian=True, tol=low_tol
19931993
)
19941994
dp_rank_low_tol = dpnp.linalg.matrix_rank(
19951995
a_dp, hermitian=True, tol=dp_low_tol
19961996
)
1997-
assert np_rank_low_tol == dp_rank_low_tol
1997+
assert dp_rank_low_tol.asnumpy() == np_rank_low_tol
19981998

19991999
# rtol kwarg was added in numpy 2.0
20002000
@testing.with_requires("numpy>=2.0")
@@ -2807,15 +2807,14 @@ def check_decomposition(
28072807
for i in range(min(dp_a.shape[-2], dp_a.shape[-1])):
28082808
dpnp_diag_s[..., i, i] = dp_s[..., i]
28092809
reconstructed = dpnp.dot(dp_u, dpnp.dot(dpnp_diag_s, dp_vt))
2810-
# TODO: use assert dpnp.allclose() inside check_decomposition()
2811-
# when it will support complex dtypes
2812-
assert_allclose(dp_a, reconstructed, rtol=tol, atol=1e-4)
2810+
2811+
assert dpnp.allclose(dp_a, reconstructed, rtol=tol, atol=1e-4)
28132812

28142813
assert_allclose(dp_s, np_s, rtol=tol, atol=1e-03)
28152814

28162815
if compute_vt:
28172816
for i in range(min(dp_a.shape[-2], dp_a.shape[-1])):
2818-
if np_u[..., 0, i] * dp_u[..., 0, i] < 0:
2817+
if np_u[..., 0, i] * dpnp.asnumpy(dp_u[..., 0, i]) < 0:
28192818
np_u[..., :, i] = -np_u[..., :, i]
28202819
np_vt[..., i, :] = -np_vt[..., i, :]
28212820
for i in range(numpy.count_nonzero(np_s > tol)):

0 commit comments

Comments
 (0)