@@ -759,10 +759,6 @@ def dpnp_eigh(a, UPLO, eigen_mode="V"):
759
759
return w , v
760
760
return w
761
761
762
- a_sycl_queue = a .sycl_queue
763
- a_order = "C" if a .flags .c_contiguous else "F"
764
- a_usm_arr = dpnp .get_usm_ndarray (a )
765
-
766
762
# `eigen_mode` can be either "N" or "V", specifying the computation mode
767
763
# for OneMKL LAPACK `syevd` and `heevd` routines.
768
764
# "V" (default) means both eigenvectors and eigenvalues will be calculated
@@ -776,42 +772,27 @@ def dpnp_eigh(a, UPLO, eigen_mode="V"):
776
772
"_heevd" if dpnp .issubdtype (v_type , dpnp .complexfloating ) else "_syevd"
777
773
)
778
774
775
+ a_sycl_queue = a .sycl_queue
776
+ a_order = "C" if a .flags .c_contiguous else "F"
777
+
779
778
if a .ndim > 2 :
779
+ orig_shape = a .shape
780
+ # get 3d input array by reshape
781
+ a = a .reshape (- 1 , orig_shape [- 2 ], orig_shape [- 1 ])
782
+ a_usm_arr = dpnp .get_usm_ndarray (a )
783
+
784
+ # allocate a memory for dpnp array of eigenvalues
780
785
w = dpnp .empty_like (
781
786
a ,
782
- shape = a . shape [:- 1 ],
787
+ shape = orig_shape [:- 1 ],
783
788
dtype = w_type ,
784
789
)
790
+ w_orig_shape = w .shape
791
+ # get 2d dpnp array with eigenvalues by reshape
792
+ w = w .reshape (- 1 , w_orig_shape [- 1 ])
785
793
786
794
# need to loop over the 1st dimension to get eigenvalues and eigenvectors of 3d matrix A
787
795
batch_size = a .shape [0 ]
788
- if batch_size == 0 :
789
- return (
790
- (w , dpnp .empty_like (a , dtype = v_type ))
791
- if eigen_mode == "V"
792
- else w
793
- )
794
-
795
- # When `eigen_mode == "N"` (jobz == 0), OneMKL LAPACK does not overwrite the input array.
796
- # If the input array 'a' is already F-contiguous and matches the target data type,
797
- # we can avoid unnecessary memory allocation and data copying.
798
- if eigen_mode == "N" and a_order == "F" and a .dtype == v_type :
799
- ht_list_ev = [None ] * batch_size
800
- for i in range (batch_size ):
801
- # call LAPACK extension function to get eigenvalues of a portion of matrix A
802
- ht_list_ev [i ], _ = getattr (li , lapack_func )(
803
- a_sycl_queue ,
804
- jobz ,
805
- uplo ,
806
- a [i ].get_array (),
807
- w [i ].get_array (),
808
- depends = [],
809
- )
810
-
811
- dpctl .SyclEvent .wait_for (ht_list_ev )
812
-
813
- return w
814
-
815
796
eig_vecs = [None ] * batch_size
816
797
ht_list_ev = [None ] * batch_size * 2
817
798
for i in range (batch_size ):
@@ -838,15 +819,18 @@ def dpnp_eigh(a, UPLO, eigen_mode="V"):
838
819
839
820
dpctl .SyclEvent .wait_for (ht_list_ev )
840
821
822
+ w = w .reshape (w_orig_shape )
823
+
841
824
if eigen_mode == "V" :
842
825
# combine the list of eigenvectors into a single array
843
- v = dpnp .array (eig_vecs , order = a_order )
826
+ v = dpnp .array (eig_vecs , order = a_order ). reshape ( orig_shape )
844
827
return w , v
845
828
return w
846
829
847
830
else :
831
+ a_usm_arr = dpnp .get_usm_ndarray (a )
848
832
ht_list_ev = []
849
- copy_ev = None
833
+ copy_ev = dpctl . SyclEvent ()
850
834
851
835
# When `eigen_mode == "N"` (jobz == 0), OneMKL LAPACK does not overwrite the input array.
852
836
# If the input array 'a' is already F-contiguous and matches the target data type,
0 commit comments