Skip to content

Commit f73bcf3

Browse files
committed
Update linalg tests
1 parent 5ba2f41 commit f73bcf3

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

dpnp/tests/test_linalg.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1935,7 +1935,7 @@ def test_matrix_rank(self, data, dtype):
19351935

19361936
np_rank = numpy.linalg.matrix_rank(a)
19371937
dp_rank = dpnp.linalg.matrix_rank(a_dp)
1938-
assert np_rank == dp_rank
1938+
assert dp_rank.asnumpy() == np_rank
19391939

19401940
@pytest.mark.parametrize("dtype", get_all_dtypes())
19411941
@pytest.mark.parametrize(
@@ -1953,7 +1953,7 @@ def test_matrix_rank_hermitian(self, data, dtype):
19531953

19541954
np_rank = numpy.linalg.matrix_rank(a, hermitian=True)
19551955
dp_rank = dpnp.linalg.matrix_rank(a_dp, hermitian=True)
1956-
assert np_rank == dp_rank
1956+
assert dp_rank.asnumpy() == np_rank
19571957

19581958
@pytest.mark.parametrize(
19591959
"high_tol, low_tol",
@@ -1986,15 +1986,15 @@ def test_matrix_rank_tolerance(self, high_tol, low_tol):
19861986
dp_rank_high_tol = dpnp.linalg.matrix_rank(
19871987
a_dp, hermitian=True, tol=dp_high_tol
19881988
)
1989-
assert np_rank_high_tol == dp_rank_high_tol
1989+
assert dp_rank_high_tol.asnumpy() == np_rank_high_tol
19901990

19911991
np_rank_low_tol = numpy.linalg.matrix_rank(
19921992
a, hermitian=True, tol=low_tol
19931993
)
19941994
dp_rank_low_tol = dpnp.linalg.matrix_rank(
19951995
a_dp, hermitian=True, tol=dp_low_tol
19961996
)
1997-
assert np_rank_low_tol == dp_rank_low_tol
1997+
assert dp_rank_low_tol.asnumpy() == np_rank_low_tol
19981998

19991999
# rtol kwarg was added in numpy 2.0
20002000
@testing.with_requires("numpy>=2.0")
@@ -2789,15 +2789,14 @@ def check_decomposition(
27892789
for i in range(min(dp_a.shape[-2], dp_a.shape[-1])):
27902790
dpnp_diag_s[..., i, i] = dp_s[..., i]
27912791
reconstructed = dpnp.dot(dp_u, dpnp.dot(dpnp_diag_s, dp_vt))
2792-
# TODO: use assert dpnp.allclose() inside check_decomposition()
2793-
# when it will support complex dtypes
2794-
assert_allclose(dp_a, reconstructed, rtol=tol, atol=1e-4)
2792+
2793+
assert dpnp.allclose(dp_a, reconstructed, rtol=tol, atol=1e-4)
27952794

27962795
assert_allclose(dp_s, np_s, rtol=tol, atol=1e-03)
27972796

27982797
if compute_vt:
27992798
for i in range(min(dp_a.shape[-2], dp_a.shape[-1])):
2800-
if np_u[..., 0, i] * dp_u[..., 0, i] < 0:
2799+
if np_u[..., 0, i] * dpnp.asnumpy(dp_u[..., 0, i]) < 0:
28012800
np_u[..., :, i] = -np_u[..., :, i]
28022801
np_vt[..., i, :] = -np_vt[..., i, :]
28032802
for i in range(numpy.count_nonzero(np_s > tol)):

0 commit comments

Comments
 (0)