@@ -595,13 +595,40 @@ def dpnp_svd_batch(a, uv_type, s_type, full_matrices=True, compute_uv=True):
595
595
return out_s
596
596
597
597
598
- def dpnp_svd (a , full_matrices = True , compute_uv = True ):
598
+ def dpnp_svd (a , full_matrices = True , compute_uv = True , hermitian = False ):
599
599
"""
600
600
dpnp_svd(a)
601
601
602
602
Return the singular value decomposition (SVD).
603
603
"""
604
604
605
+ if hermitian :
606
+ check_stacked_square (a )
607
+
608
+ # _gesvd returns eigenvalues with s ** 2 sorted descending,
609
+ # but dpnp.linalg.eigh returns s sorted ascending so we re-order the eigenvalues
610
+ # and related arrays to have the correct order
611
+ if compute_uv :
612
+ s , u = dpnp .linalg .eigh (a )
613
+ sgn = dpnp .sign (s )
614
+ s = dpnp .absolute (s )
615
+ sidx = dpnp .argsort (s )[..., ::- 1 ]
616
+ # Rearrange the signs according to sorted indices
617
+ sgn = dpnp .take_along_axis (sgn , sidx , axis = - 1 )
618
+ # Sort the singular values in descending order
619
+ s = dpnp .take_along_axis (s , sidx , axis = - 1 )
620
+ # Rearrange the eigenvectors according to sorted indices
621
+ u = dpnp .take_along_axis (u , sidx [..., None , :], axis = - 1 )
622
+ # Singular values are unsigned, move the sign into v
623
+ # Compute V^T adjusting for the sign and conjugating
624
+ vt = dpnp .transpose (u * sgn [..., None , :]).conjugate ()
625
+ return u , s , vt
626
+ else :
627
+ # TODO: use dpnp.linalg.eighvals when it is updated
628
+ s , _ = dpnp .linalg .eigh (a )
629
+ s = dpnp .abs (s )
630
+ return dpnp .sort (s )[..., ::- 1 ]
631
+
605
632
a_usm_type = a .usm_type
606
633
a_sycl_queue = a .sycl_queue
607
634
@@ -611,113 +638,112 @@ def dpnp_svd(a, full_matrices=True, compute_uv=True):
611
638
if a .ndim > 2 :
612
639
return dpnp_svd_batch (a , uv_type , s_type , full_matrices , compute_uv )
613
640
614
- else :
615
- n , m = a .shape
641
+ n , m = a .shape
616
642
617
- if m == 0 or n == 0 :
618
- s = dpnp .empty (
619
- (0 ,),
620
- dtype = s_type ,
621
- usm_type = a_usm_type ,
622
- sycl_queue = a_sycl_queue ,
623
- )
624
- if compute_uv :
625
- if full_matrices :
626
- u_shape = (n ,)
627
- vt_shape = (m ,)
628
- else :
629
- u_shape = (n , 0 )
630
- vt_shape = (0 , m )
631
-
632
- u = dpnp .eye (
633
- * u_shape ,
634
- dtype = uv_type ,
635
- usm_type = a_usm_type ,
636
- sycl_queue = a_sycl_queue ,
637
- )
638
- vt = dpnp .eye (
639
- * vt_shape ,
640
- dtype = uv_type ,
641
- usm_type = a_usm_type ,
642
- sycl_queue = a_sycl_queue ,
643
- )
644
- return u , s , vt
645
- else :
646
- return s
647
-
648
- # `a` must be copied because gesvd destroys the input matrix
649
- # `a` must be traspotted if m < n
650
- if m >= n :
651
- x = a
652
- a_h = dpnp .empty_like (a , order = "C" , dtype = uv_type )
653
- trans_flag = False
654
- else :
655
- m , n = a .shape
656
- x = a .transpose ()
657
- a_h = dpnp .empty_like (x , order = "C" , dtype = uv_type )
658
- trans_flag = True
659
-
660
- a_usm_arr = dpnp .get_usm_ndarray (x )
661
-
662
- # use DPCTL tensor function to fill the сopy of the input array
663
- # from the input array
664
- a_ht_copy_ev , a_copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
665
- src = a_usm_arr , dst = a_h .get_array (), sycl_queue = a_sycl_queue
643
+ if m == 0 or n == 0 :
644
+ s = dpnp .empty (
645
+ (0 ,),
646
+ dtype = s_type ,
647
+ usm_type = a_usm_type ,
648
+ sycl_queue = a_sycl_queue ,
666
649
)
667
-
668
- k = n # = min(m, n) where m >= n is ensured above
669
650
if compute_uv :
670
651
if full_matrices :
671
- u_shape = (m , m )
672
- vt_shape = (n , n )
673
- jobu = ord ("A" )
674
- jobvt = ord ("A" )
652
+ u_shape = (n ,)
653
+ vt_shape = (m ,)
675
654
else :
676
- u_shape = x .shape
677
- vt_shape = (k , n )
678
- jobu = ord ("S" )
679
- jobvt = ord ("S" )
655
+ u_shape = (n , 0 )
656
+ vt_shape = (0 , m )
657
+
658
+ u = dpnp .eye (
659
+ * u_shape ,
660
+ dtype = uv_type ,
661
+ usm_type = a_usm_type ,
662
+ sycl_queue = a_sycl_queue ,
663
+ )
664
+ vt = dpnp .eye (
665
+ * vt_shape ,
666
+ dtype = uv_type ,
667
+ usm_type = a_usm_type ,
668
+ sycl_queue = a_sycl_queue ,
669
+ )
670
+ return u , s , vt
680
671
else :
681
- u_shape = vt_shape = ()
682
- jobu = ord ("N" )
683
- jobvt = ord ("N" )
672
+ return s
684
673
685
- u_h = dpnp .empty (
686
- u_shape ,
687
- dtype = uv_type ,
688
- usm_type = a_usm_type ,
689
- sycl_queue = a_sycl_queue ,
690
- )
691
- vt_h = dpnp .empty (
692
- vt_shape ,
693
- dtype = uv_type ,
694
- usm_type = a_usm_type ,
695
- sycl_queue = a_sycl_queue ,
696
- )
697
- s_h = dpnp .empty (
698
- k , dtype = s_type , usm_type = a_usm_type , sycl_queue = a_sycl_queue
699
- )
674
+ # `a` must be copied because gesvd destroys the input matrix
675
+ # `a` must be traspotted if m < n
676
+ if m >= n :
677
+ x = a
678
+ a_h = dpnp .empty_like (a , order = "C" , dtype = uv_type )
679
+ trans_flag = False
680
+ else :
681
+ m , n = a .shape
682
+ x = a .transpose ()
683
+ a_h = dpnp .empty_like (x , order = "C" , dtype = uv_type )
684
+ trans_flag = True
700
685
701
- ht_lapack_ev , _ = li ._gesvd (
702
- a_sycl_queue ,
703
- jobu ,
704
- jobvt ,
705
- m ,
706
- n ,
707
- a_h .get_array (),
708
- s_h .get_array (),
709
- u_h .get_array (),
710
- vt_h .get_array (),
711
- [a_copy_ev ],
712
- )
686
+ a_usm_arr = dpnp .get_usm_ndarray (x )
713
687
714
- ht_lapack_ev .wait ()
715
- a_ht_copy_ev .wait ()
688
+ # use DPCTL tensor function to fill the сopy of the input array
689
+ # from the input array
690
+ a_ht_copy_ev , a_copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
691
+ src = a_usm_arr , dst = a_h .get_array (), sycl_queue = a_sycl_queue
692
+ )
716
693
717
- if compute_uv :
718
- if trans_flag :
719
- return u_h .transpose (), s_h , vt_h .transpose ()
720
- else :
721
- return vt_h , s_h , u_h
694
+ k = n # = min(m, n) where m >= n is ensured above
695
+ if compute_uv :
696
+ if full_matrices :
697
+ u_shape = (m , m )
698
+ vt_shape = (n , n )
699
+ jobu = ord ("A" )
700
+ jobvt = ord ("A" )
722
701
else :
723
- return s_h
702
+ u_shape = x .shape
703
+ vt_shape = (k , n )
704
+ jobu = ord ("S" )
705
+ jobvt = ord ("S" )
706
+ else :
707
+ u_shape = vt_shape = ()
708
+ jobu = ord ("N" )
709
+ jobvt = ord ("N" )
710
+
711
+ u_h = dpnp .empty (
712
+ u_shape ,
713
+ dtype = uv_type ,
714
+ usm_type = a_usm_type ,
715
+ sycl_queue = a_sycl_queue ,
716
+ )
717
+ vt_h = dpnp .empty (
718
+ vt_shape ,
719
+ dtype = uv_type ,
720
+ usm_type = a_usm_type ,
721
+ sycl_queue = a_sycl_queue ,
722
+ )
723
+ s_h = dpnp .empty (
724
+ k , dtype = s_type , usm_type = a_usm_type , sycl_queue = a_sycl_queue
725
+ )
726
+
727
+ ht_lapack_ev , _ = li ._gesvd (
728
+ a_sycl_queue ,
729
+ jobu ,
730
+ jobvt ,
731
+ m ,
732
+ n ,
733
+ a_h .get_array (),
734
+ s_h .get_array (),
735
+ u_h .get_array (),
736
+ vt_h .get_array (),
737
+ [a_copy_ev ],
738
+ )
739
+
740
+ ht_lapack_ev .wait ()
741
+ a_ht_copy_ev .wait ()
742
+
743
+ if compute_uv :
744
+ if trans_flag :
745
+ return u_h .transpose (), s_h , vt_h .transpose ()
746
+ else :
747
+ return vt_h , s_h , u_h
748
+ else :
749
+ return s_h
0 commit comments