Skip to content

Commit 791def5

Browse files
committed
address comments
1 parent b646cf1 commit 791def5

File tree

3 files changed

+23
-17
lines changed

3 files changed

+23
-17
lines changed

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1073,7 +1073,7 @@ def matrix_rank(A, tol=None, hermitian=False, *, rtol=None):
10731073
----------
10741074
A : {(M,), (..., M, N)} {dpnp.ndarray, usm_ndarray}
10751075
Input vector or stack of matrices.
1076-
tol : (...) {float, dpnp.ndarray, usm_ndarray}, optional
1076+
tol : (...) {None, float, dpnp.ndarray, usm_ndarray}, optional
10771077
Threshold below which SVD values are considered zero. Only `tol` or
10781078
`rtol` can be set at a time. If none of them are provided, defaults
10791079
to ``S.max() * max(M, N) * eps`` where `S` is an array with singular
@@ -1083,7 +1083,7 @@ def matrix_rank(A, tol=None, hermitian=False, *, rtol=None):
10831083
If ``True``, `A` is assumed to be Hermitian (symmetric if real-valued),
10841084
enabling a more efficient method for finding singular values.
10851085
Default: ``False``.
1086-
rtol : (...) {float, dpnp.ndarray, usm_ndarray}, optional
1086+
rtol : (...) {None, float, dpnp.ndarray, usm_ndarray}, optional
10871087
Parameter for the relative tolerance component. Only `tol` or `rtol`
10881088
can be set at a time. If none of them are provided, defaults to
10891089
``max(M, N) * eps`` where `eps` is the epsilon value for datatype
@@ -1479,7 +1479,7 @@ def pinv(a, rcond=None, hermitian=False, *, rtol=None):
14791479
----------
14801480
a : (..., M, N) {dpnp.ndarray, usm_ndarray}
14811481
Matrix or stack of matrices to be pseudo-inverted.
1482-
rcond : (...) {float, dpnp.ndarray, usm_ndarray}, optional
1482+
rcond : (...) {None, float, dpnp.ndarray, usm_ndarray}, optional
14831483
Cutoff for small singular values.
14841484
Singular values less than or equal to ``rcond * largest_singular_value``
14851485
are set to zero. Broadcasts against the stack of matrices.
@@ -1490,7 +1490,7 @@ def pinv(a, rcond=None, hermitian=False, *, rtol=None):
14901490
If ``True``, a is assumed to be Hermitian (symmetric if real-valued),
14911491
enabling a more efficient method for finding singular values.
14921492
Default: ``False``.
1493-
rtol : (...) {float, dpnp.ndarray, usm_ndarray}, optional
1493+
rtol : (...) {None, float, dpnp.ndarray, usm_ndarray}, optional
14941494
Same as `rcond`, but it's an Array API compatible parameter name.
14951495
Only `rcond` or `rtol` can be set at a time. If none of them are
14961496
provided, defaults to ``max(M, N) * dpnp.finfo(a.dtype).eps``.

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2234,10 +2234,13 @@ def dpnp_matrix_rank(A, tol=None, hermitian=False, rtol=None):
22342234
if rtol is None:
22352235
rtol = max(A.shape[-2:]) * dpnp.finfo(S.dtype).eps
22362236
elif not dpnp.isscalar(rtol):
2237+
# Add a new axis to make it broadcastable against S
2238+
# needed for S > tol comparison below
22372239
rtol = rtol[..., None]
22382240
tol = S.max(axis=-1, keepdims=True) * rtol
22392241
elif not dpnp.isscalar(tol):
2240-
# Add a new axis to match NumPy's output
2242+
# Add a new axis to make it broadcastable against S,
2243+
# needed for S > tol comparison below
22412244
tol = tol[..., None]
22422245

22432246
return dpnp.count_nonzero(S > tol, axis=-1)

tests/test_linalg.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2100,25 +2100,28 @@ def test_matrix_rank_tolerance(self, high_tol, low_tol):
21002100
# rtol kwarg was added in numpy 2.0
21012101
@testing.with_requires("numpy>=2.0")
21022102
@pytest.mark.parametrize(
2103-
"rtol",
2104-
[0.99e-6, numpy.array(1.01e-6), numpy.array([0.99e-6])],
2103+
"tol",
2104+
[0.99e-6, numpy.array(1.01e-6), numpy.ones(4) * [0.99e-6]],
21052105
ids=["float", "0-D array", "1-D array"],
21062106
)
2107-
def test_matrix_rank_rtol(self, rtol):
2108-
a = numpy.eye(4)
2109-
a[-1, -1] = 1e-6
2107+
def test_matrix_rank_tol(self, tol):
2108+
a = numpy.zeros((4, 3, 2))
21102109
a_dp = inp.array(a)
21112110

2112-
if isinstance(rtol, numpy.ndarray):
2113-
dp_rtol = inp.array(
2114-
rtol, usm_type=a_dp.usm_type, sycl_queue=a_dp.sycl_queue
2111+
if isinstance(tol, numpy.ndarray):
2112+
dp_tol = inp.array(
2113+
tol, usm_type=a_dp.usm_type, sycl_queue=a_dp.sycl_queue
21152114
)
21162115
else:
2117-
dp_rtol = rtol
2116+
dp_tol = tol
21182117

2119-
expected = numpy.linalg.matrix_rank(a, rtol=rtol)
2120-
result = inp.linalg.matrix_rank(a_dp, rtol=dp_rtol)
2121-
assert expected == result
2118+
expected = numpy.linalg.matrix_rank(a, rtol=tol)
2119+
result = inp.linalg.matrix_rank(a_dp, rtol=dp_tol)
2120+
assert_dtype_allclose(result, expected)
2121+
2122+
expected = numpy.linalg.matrix_rank(a, tol=tol)
2123+
result = inp.linalg.matrix_rank(a_dp, tol=dp_tol)
2124+
assert_dtype_allclose(result, expected)
21222125

21232126
def test_matrix_rank_errors(self):
21242127
a_dp = inp.array([[1, 2], [3, 4]], dtype="float32")

0 commit comments

Comments
 (0)