Skip to content

Commit 0766347

Browse files
Update test_sycl_queue
1 parent 34aa37c commit 0766347

File tree

1 file changed

+28
-10
lines changed

1 file changed

+28
-10
lines changed

tests/test_sycl_queue.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1663,20 +1663,26 @@ def test_slogdet(shape, is_empty, device):
16631663

16641664

16651665
@pytest.mark.parametrize(
1666-
"shape, hermitian",
1666+
"shape, hermitian, rcond_as_array",
16671667
[
1668-
((4, 4), False),
1669-
((2, 0), False),
1670-
((4, 4), True),
1671-
((2, 2, 3), False),
1672-
((0, 2, 3), False),
1673-
((1, 0, 3), False),
1668+
((4, 4), False, False),
1669+
((4, 4), False, True),
1670+
((2, 0), False, False),
1671+
((4, 4), True, False),
1672+
((4, 4), True, True),
1673+
((2, 2, 3), False, False),
1674+
((2, 2, 3), False, True),
1675+
((0, 2, 3), False, False),
1676+
((1, 0, 3), False, False),
16741677
],
16751678
ids=[
16761679
"(4, 4)",
1680+
"(4, 4), rcond_as_array",
16771681
"(2, 0)",
16781682
"(2, 2), hermitian)",
1683+
"(2, 2), hermitian, rcond_as_array)",
16791684
"(2, 2, 3)",
1685+
"(2, 2, 3), rcond_as_array",
16801686
"(0, 2, 3)",
16811687
"(1, 0, 3)",
16821688
],
@@ -1686,7 +1692,7 @@ def test_slogdet(shape, is_empty, device):
16861692
valid_devices,
16871693
ids=[device.filter_string for device in valid_devices],
16881694
)
1689-
def test_pinv(shape, hermitian, device):
1695+
def test_pinv(shape, hermitian, rcond_as_array, device):
16901696
if hermitian:
16911697
a_np = numpy.random.randn(*shape) + 1j * numpy.random.randn(*shape)
16921698
a_np = numpy.conj(a_np.T) @ a_np
@@ -1695,8 +1701,20 @@ def test_pinv(shape, hermitian, device):
16951701

16961702
a_dp = dpnp.array(a_np, device=device)
16971703

1698-
B_result = dpnp.linalg.pinv(a_dp, hermitian=hermitian)
1699-
B_expected = numpy.linalg.pinv(a_np, hermitian=hermitian)
1704+
if rcond_as_array:
1705+
rcond_np = numpy.array(1e-15)
1706+
rcond_dp = dpnp.array(1e-15, device=device)
1707+
1708+
B_result = dpnp.linalg.pinv(a_dp, rcond=rcond_dp, hermitian=hermitian)
1709+
B_expected = numpy.linalg.pinv(
1710+
a_np, rcond=rcond_np, hermitian=hermitian
1711+
)
1712+
1713+
else:
1714+
# rcond == 1e-15 by default
1715+
B_result = dpnp.linalg.pinv(a_dp, hermitian=hermitian)
1716+
B_expected = numpy.linalg.pinv(a_np, hermitian=hermitian)
1717+
17001718
assert_allclose(B_expected, B_result, rtol=1e-3, atol=1e-4)
17011719

17021720
B_queue = B_result.sycl_queue

0 commit comments

Comments
 (0)