Skip to content

Commit 0be3132

Browse files
Add hermitian argument support
1 parent 33c4f5e commit 0be3132

File tree

2 files changed

+127
-106
lines changed

2 files changed

+127
-106
lines changed

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -645,9 +645,4 @@ def svd(a, full_matrices=True, compute_uv=True, hermitian=False):
645645
dpnp.check_supported_arrays_type(a)
646646
check_stacked_2d(a)
647647

648-
if hermitian is True:
649-
raise NotImplementedError(
650-
"hermitian keyword argument is only supported with its default value."
651-
)
652-
653-
return dpnp_svd(a, full_matrices, compute_uv)
648+
return dpnp_svd(a, full_matrices, compute_uv, hermitian)

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 126 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -595,13 +595,40 @@ def dpnp_svd_batch(a, uv_type, s_type, full_matrices=True, compute_uv=True):
595595
return out_s
596596

597597

598-
def dpnp_svd(a, full_matrices=True, compute_uv=True):
598+
def dpnp_svd(a, full_matrices=True, compute_uv=True, hermitian=False):
599599
"""
600600
dpnp_svd(a)
601601
602602
Return the singular value decomposition (SVD).
603603
"""
604604

605+
if hermitian:
606+
check_stacked_square(a)
607+
608+
# _gesvd returns singular values sorted in descending order,
609+
# whereas dpnp.linalg.eigh returns eigenvalues sorted in ascending order, so we
610+
# re-order the absolute eigenvalues and related arrays to match the _gesvd ordering
611+
if compute_uv:
612+
s, u = dpnp.linalg.eigh(a)
613+
sgn = dpnp.sign(s)
614+
s = dpnp.absolute(s)
615+
sidx = dpnp.argsort(s)[..., ::-1]
616+
# Rearrange the signs according to sorted indices
617+
sgn = dpnp.take_along_axis(sgn, sidx, axis=-1)
618+
# Sort the singular values in descending order
619+
s = dpnp.take_along_axis(s, sidx, axis=-1)
620+
# Rearrange the eigenvectors according to sorted indices
621+
u = dpnp.take_along_axis(u, sidx[..., None, :], axis=-1)
622+
# Singular values are unsigned, move the sign into v
623+
# Compute V^T adjusting for the sign and conjugating
624+
vt = dpnp.transpose(u * sgn[..., None, :]).conjugate()
625+
return u, s, vt
626+
else:
627+
# TODO: use dpnp.linalg.eigvalsh when it is updated
628+
s, _ = dpnp.linalg.eigh(a)
629+
s = dpnp.abs(s)
630+
return dpnp.sort(s)[..., ::-1]
631+
605632
a_usm_type = a.usm_type
606633
a_sycl_queue = a.sycl_queue
607634

@@ -611,113 +638,112 @@ def dpnp_svd(a, full_matrices=True, compute_uv=True):
611638
if a.ndim > 2:
612639
return dpnp_svd_batch(a, uv_type, s_type, full_matrices, compute_uv)
613640

614-
else:
615-
n, m = a.shape
641+
n, m = a.shape
616642

617-
if m == 0 or n == 0:
618-
s = dpnp.empty(
619-
(0,),
620-
dtype=s_type,
621-
usm_type=a_usm_type,
622-
sycl_queue=a_sycl_queue,
623-
)
624-
if compute_uv:
625-
if full_matrices:
626-
u_shape = (n,)
627-
vt_shape = (m,)
628-
else:
629-
u_shape = (n, 0)
630-
vt_shape = (0, m)
631-
632-
u = dpnp.eye(
633-
*u_shape,
634-
dtype=uv_type,
635-
usm_type=a_usm_type,
636-
sycl_queue=a_sycl_queue,
637-
)
638-
vt = dpnp.eye(
639-
*vt_shape,
640-
dtype=uv_type,
641-
usm_type=a_usm_type,
642-
sycl_queue=a_sycl_queue,
643-
)
644-
return u, s, vt
645-
else:
646-
return s
647-
648-
# `a` must be copied because gesvd destroys the input matrix
649-
# `a` must be traspotted if m < n
650-
if m >= n:
651-
x = a
652-
a_h = dpnp.empty_like(a, order="C", dtype=uv_type)
653-
trans_flag = False
654-
else:
655-
m, n = a.shape
656-
x = a.transpose()
657-
a_h = dpnp.empty_like(x, order="C", dtype=uv_type)
658-
trans_flag = True
659-
660-
a_usm_arr = dpnp.get_usm_ndarray(x)
661-
662-
# use DPCTL tensor function to fill the сopy of the input array
663-
# from the input array
664-
a_ht_copy_ev, a_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
665-
src=a_usm_arr, dst=a_h.get_array(), sycl_queue=a_sycl_queue
643+
if m == 0 or n == 0:
644+
s = dpnp.empty(
645+
(0,),
646+
dtype=s_type,
647+
usm_type=a_usm_type,
648+
sycl_queue=a_sycl_queue,
666649
)
667-
668-
k = n # = min(m, n) where m >= n is ensured above
669650
if compute_uv:
670651
if full_matrices:
671-
u_shape = (m, m)
672-
vt_shape = (n, n)
673-
jobu = ord("A")
674-
jobvt = ord("A")
652+
u_shape = (n,)
653+
vt_shape = (m,)
675654
else:
676-
u_shape = x.shape
677-
vt_shape = (k, n)
678-
jobu = ord("S")
679-
jobvt = ord("S")
655+
u_shape = (n, 0)
656+
vt_shape = (0, m)
657+
658+
u = dpnp.eye(
659+
*u_shape,
660+
dtype=uv_type,
661+
usm_type=a_usm_type,
662+
sycl_queue=a_sycl_queue,
663+
)
664+
vt = dpnp.eye(
665+
*vt_shape,
666+
dtype=uv_type,
667+
usm_type=a_usm_type,
668+
sycl_queue=a_sycl_queue,
669+
)
670+
return u, s, vt
680671
else:
681-
u_shape = vt_shape = ()
682-
jobu = ord("N")
683-
jobvt = ord("N")
672+
return s
684673

685-
u_h = dpnp.empty(
686-
u_shape,
687-
dtype=uv_type,
688-
usm_type=a_usm_type,
689-
sycl_queue=a_sycl_queue,
690-
)
691-
vt_h = dpnp.empty(
692-
vt_shape,
693-
dtype=uv_type,
694-
usm_type=a_usm_type,
695-
sycl_queue=a_sycl_queue,
696-
)
697-
s_h = dpnp.empty(
698-
k, dtype=s_type, usm_type=a_usm_type, sycl_queue=a_sycl_queue
699-
)
674+
# `a` must be copied because gesvd destroys the input matrix
675+
# `a` must be transposed if m < n
676+
if m >= n:
677+
x = a
678+
a_h = dpnp.empty_like(a, order="C", dtype=uv_type)
679+
trans_flag = False
680+
else:
681+
m, n = a.shape
682+
x = a.transpose()
683+
a_h = dpnp.empty_like(x, order="C", dtype=uv_type)
684+
trans_flag = True
700685

701-
ht_lapack_ev, _ = li._gesvd(
702-
a_sycl_queue,
703-
jobu,
704-
jobvt,
705-
m,
706-
n,
707-
a_h.get_array(),
708-
s_h.get_array(),
709-
u_h.get_array(),
710-
vt_h.get_array(),
711-
[a_copy_ev],
712-
)
686+
a_usm_arr = dpnp.get_usm_ndarray(x)
713687

714-
ht_lapack_ev.wait()
715-
a_ht_copy_ev.wait()
688+
# use DPCTL tensor function to fill the copy of the input array
689+
# from the input array
690+
a_ht_copy_ev, a_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
691+
src=a_usm_arr, dst=a_h.get_array(), sycl_queue=a_sycl_queue
692+
)
716693

717-
if compute_uv:
718-
if trans_flag:
719-
return u_h.transpose(), s_h, vt_h.transpose()
720-
else:
721-
return vt_h, s_h, u_h
694+
k = n # = min(m, n) where m >= n is ensured above
695+
if compute_uv:
696+
if full_matrices:
697+
u_shape = (m, m)
698+
vt_shape = (n, n)
699+
jobu = ord("A")
700+
jobvt = ord("A")
722701
else:
723-
return s_h
702+
u_shape = x.shape
703+
vt_shape = (k, n)
704+
jobu = ord("S")
705+
jobvt = ord("S")
706+
else:
707+
u_shape = vt_shape = ()
708+
jobu = ord("N")
709+
jobvt = ord("N")
710+
711+
u_h = dpnp.empty(
712+
u_shape,
713+
dtype=uv_type,
714+
usm_type=a_usm_type,
715+
sycl_queue=a_sycl_queue,
716+
)
717+
vt_h = dpnp.empty(
718+
vt_shape,
719+
dtype=uv_type,
720+
usm_type=a_usm_type,
721+
sycl_queue=a_sycl_queue,
722+
)
723+
s_h = dpnp.empty(
724+
k, dtype=s_type, usm_type=a_usm_type, sycl_queue=a_sycl_queue
725+
)
726+
727+
ht_lapack_ev, _ = li._gesvd(
728+
a_sycl_queue,
729+
jobu,
730+
jobvt,
731+
m,
732+
n,
733+
a_h.get_array(),
734+
s_h.get_array(),
735+
u_h.get_array(),
736+
vt_h.get_array(),
737+
[a_copy_ev],
738+
)
739+
740+
ht_lapack_ev.wait()
741+
a_ht_copy_ev.wait()
742+
743+
if compute_uv:
744+
if trans_flag:
745+
return u_h.transpose(), s_h, vt_h.transpose()
746+
else:
747+
return vt_h, s_h, u_h
748+
else:
749+
return s_h

0 commit comments

Comments
 (0)