Skip to content

Commit 7f03075

Browse files
Add additional checks for rcond parameter
1 parent b7ee111 commit 7f03075

File tree

3 files changed

+28
-4
lines changed

3 files changed

+28
-4
lines changed

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -489,10 +489,10 @@ def pinv(a, rcond=1e-15, hermitian=False):
489489
----------
490490
a : (..., M, N) {dpnp.ndarray, usm_ndarray}
491491
Matrix or stack of matrices to be pseudo-inverted.
492-
rcond : {float, array_like}, optional
492+
rcond : {float, dpnp.ndarray, usm_ndarray}, optional
493493
Cutoff for small singular values.
494494
Singular values less than or equal to ``rcond * largest_singular_value``
495-
are set to zero.
495+
are set to zero. Broadcasts against the stack of matrices.
496496
Default: ``1e-15``.
497497
hermitian : bool, optional
498498
If ``True``, a is assumed to be Hermitian (symmetric if real-valued),
@@ -520,6 +520,7 @@ def pinv(a, rcond=1e-15, hermitian=False):
520520
"""
521521

522522
dpnp.check_supported_arrays_type(a)
523+
dpnp.check_supported_arrays_type(rcond, scalar_type=True)
523524
check_stacked_2d(a)
524525

525526
return dpnp_pinv(a, rcond=rcond, hermitian=hermitian)

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1010,7 +1010,6 @@ def dpnp_pinv(a, rcond=1e-15, hermitian=False):
10101010
10111011
"""
10121012

1013-
rcond = dpnp.array(rcond, usm_type=a.usm_type, sycl_queue=a.sycl_queue)
10141013
if a.size == 0:
10151014
m, n = a.shape[-2:]
10161015
if m == 0 or n == 0:
@@ -1019,6 +1018,14 @@ def dpnp_pinv(a, rcond=1e-15, hermitian=False):
10191018
res_type = _common_type(a)
10201019
return dpnp.empty_like(a, shape=(a.shape[:-2] + (n, m)), dtype=res_type)
10211020

1021+
if dpnp.is_supported_array_type(rcond):
1022+
# Check that `a` and `rcond` are allocated on the same device
1023+
# and have the same queue. Otherwise, `ValueError`` will be raised.
1024+
get_usm_allocations([a, rcond])
1025+
else:
1026+
# Allocate dpnp.ndarray if rcond is a scalar
1027+
rcond = dpnp.array(rcond, usm_type=a.usm_type, sycl_queue=a.sycl_queue)
1028+
10221029
u, s, vt = dpnp_svd(a.conj(), full_matrices=False, hermitian=hermitian)
10231030

10241031
# discard small singular values

tests/test_linalg.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1289,10 +1289,26 @@ def test_pinv_strides(self):
12891289
def test_pinv_errors(self):
12901290
a_dp = inp.array([[1, 2], [3, 4]], dtype="float32")
12911291

1292-
# unsupported type
1292+
# unsupported type `a`
12931293
a_np = inp.asnumpy(a_dp)
12941294
assert_raises(TypeError, inp.linalg.pinv, a_np)
12951295

1296+
# unsupported type `rcond`
1297+
rcond = numpy.array(0.5, dtype="float32")
1298+
assert_raises(TypeError, inp.linalg.pinv, a_dp, rcond)
1299+
assert_raises(TypeError, inp.linalg.pinv, a_dp, [0.5])
1300+
1301+
# non-broadcastable `rcond`
1302+
rcond_dp = inp.array([0.5], dtype="float32")
1303+
assert_raises(ValueError, inp.linalg.pinv, a_dp, rcond_dp)
1304+
12961305
# a.ndim < 2
12971306
a_dp_ndim_1 = a_dp.flatten()
12981307
assert_raises(inp.linalg.LinAlgError, inp.linalg.pinv, a_dp_ndim_1)
1308+
1309+
# diffetent queue
1310+
a_queue = dpctl.SyclQueue()
1311+
rcond_queue = dpctl.SyclQueue()
1312+
a_dp_q = inp.array(a_dp, sycl_queue=a_queue)
1313+
rcond_dp_q = inp.array([0.5], dtype="float32", sycl_queue=rcond_queue)
1314+
assert_raises(ValueError, inp.linalg.pinv, a_dp_q, rcond_dp_q)

0 commit comments

Comments
 (0)