Skip to content

Commit b7ee111

Browse files
Address remarks
1 parent 0293a3f commit b7ee111

File tree

3 files changed

+6
-5
lines changed

3 files changed

+6
-5
lines changed

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,7 @@ def pinv(a, rcond=1e-15, hermitian=False):
489489
----------
490490
a : (..., M, N) {dpnp.ndarray, usm_ndarray}
491491
Matrix or stack of matrices to be pseudo-inverted.
492-
rcond : float or dpnp.ndarray of float, optional
492+
rcond : {float, array_like}, optional
493493
Cutoff for small singular values.
494494
Singular values less than or equal to ``rcond * largest_singular_value``
495495
are set to zero.

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1010,18 +1010,19 @@ def dpnp_pinv(a, rcond=1e-15, hermitian=False):
10101010
10111011
"""
10121012

1013-
rcond = dpnp.array(rcond, device=a.sycl_device, sycl_queue=a.sycl_queue)
1013+
rcond = dpnp.array(rcond, usm_type=a.usm_type, sycl_queue=a.sycl_queue)
10141014
if a.size == 0:
1015-
res_type = _common_type(a)
10161015
m, n = a.shape[-2:]
10171016
if m == 0 or n == 0:
10181017
res_type = a.dtype
1018+
else:
1019+
res_type = _common_type(a)
10191020
return dpnp.empty_like(a, shape=(a.shape[:-2] + (n, m)), dtype=res_type)
10201021

10211022
u, s, vt = dpnp_svd(a.conj(), full_matrices=False, hermitian=hermitian)
10221023

10231024
# discard small singular values
1024-
cutoff = rcond * dpnp.amax(s, axis=-1)
1025+
cutoff = rcond * dpnp.max(s, axis=-1)
10251026
leq = s <= cutoff[..., None]
10261027
dpnp.reciprocal(s, out=s)
10271028
s[leq] = 0

tests/test_usm_type.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -880,7 +880,7 @@ def test_pinv(shape, hermitian, usm_type):
880880
["r", "raw", "complete", "reduced"],
881881
ids=["r", "raw", "complete", "reduced"],
882882
)
883-
def test_pinv(shape, mode, usm_type):
883+
def test_qr(shape, mode, usm_type):
884884
count_elems = numpy.prod(shape)
885885
a = dp.arange(count_elems, usm_type=usm_type).reshape(shape)
886886

0 commit comments

Comments
 (0)