@@ -1663,20 +1663,26 @@ def test_slogdet(shape, is_empty, device):
1663
1663
1664
1664
1665
1665
@pytest .mark .parametrize (
1666
- "shape, hermitian" ,
1666
+ "shape, hermitian, rcond_as_array " ,
1667
1667
[
1668
- ((4 , 4 ), False ),
1669
- ((2 , 0 ), False ),
1670
- ((4 , 4 ), True ),
1671
- ((2 , 2 , 3 ), False ),
1672
- ((0 , 2 , 3 ), False ),
1673
- ((1 , 0 , 3 ), False ),
1668
+ ((4 , 4 ), False , False ),
1669
+ ((4 , 4 ), False , True ),
1670
+ ((2 , 0 ), False , False ),
1671
+ ((4 , 4 ), True , False ),
1672
+ ((4 , 4 ), True , True ),
1673
+ ((2 , 2 , 3 ), False , False ),
1674
+ ((2 , 2 , 3 ), False , True ),
1675
+ ((0 , 2 , 3 ), False , False ),
1676
+ ((1 , 0 , 3 ), False , False ),
1674
1677
],
1675
1678
ids = [
1676
1679
"(4, 4)" ,
1680
+ "(4, 4), rcond_as_array" ,
1677
1681
"(2, 0)" ,
1678
1682
"(2, 2), hermitian)" ,
1683
+ "(2, 2), hermitian, rcond_as_array)" ,
1679
1684
"(2, 2, 3)" ,
1685
+ "(2, 2, 3), rcond_as_array" ,
1680
1686
"(0, 2, 3)" ,
1681
1687
"(1, 0, 3)" ,
1682
1688
],
@@ -1686,7 +1692,7 @@ def test_slogdet(shape, is_empty, device):
1686
1692
valid_devices ,
1687
1693
ids = [device .filter_string for device in valid_devices ],
1688
1694
)
1689
- def test_pinv (shape , hermitian , device ):
1695
+ def test_pinv (shape , hermitian , rcond_as_array , device ):
1690
1696
if hermitian :
1691
1697
a_np = numpy .random .randn (* shape ) + 1j * numpy .random .randn (* shape )
1692
1698
a_np = numpy .conj (a_np .T ) @ a_np
@@ -1695,8 +1701,20 @@ def test_pinv(shape, hermitian, device):
1695
1701
1696
1702
a_dp = dpnp .array (a_np , device = device )
1697
1703
1698
- B_result = dpnp .linalg .pinv (a_dp , hermitian = hermitian )
1699
- B_expected = numpy .linalg .pinv (a_np , hermitian = hermitian )
1704
+ if rcond_as_array :
1705
+ rcond_np = numpy .array (1e-15 )
1706
+ rcond_dp = dpnp .array (1e-15 , device = device )
1707
+
1708
+ B_result = dpnp .linalg .pinv (a_dp , rcond = rcond_dp , hermitian = hermitian )
1709
+ B_expected = numpy .linalg .pinv (
1710
+ a_np , rcond = rcond_np , hermitian = hermitian
1711
+ )
1712
+
1713
+ else :
1714
+ # rcond == 1e-15 by default
1715
+ B_result = dpnp .linalg .pinv (a_dp , hermitian = hermitian )
1716
+ B_expected = numpy .linalg .pinv (a_np , hermitian = hermitian )
1717
+
1700
1718
assert_allclose (B_expected , B_result , rtol = 1e-3 , atol = 1e-4 )
1701
1719
1702
1720
B_queue = B_result .sycl_queue
0 commit comments