Skip to content

Commit 6eab947

Browse files
Support 4d and more array for dpnp_eigh
1 parent 7b7bea9 commit 6eab947

File tree

1 file changed

+18
-34
lines changed

1 file changed

+18
-34
lines changed

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 18 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -759,10 +759,6 @@ def dpnp_eigh(a, UPLO, eigen_mode="V"):
759759
return w, v
760760
return w
761761

762-
a_sycl_queue = a.sycl_queue
763-
a_order = "C" if a.flags.c_contiguous else "F"
764-
a_usm_arr = dpnp.get_usm_ndarray(a)
765-
766762
# `eigen_mode` can be either "N" or "V", specifying the computation mode
767763
# for OneMKL LAPACK `syevd` and `heevd` routines.
768764
# "V" (default) means both eigenvectors and eigenvalues will be calculated
@@ -776,42 +772,27 @@ def dpnp_eigh(a, UPLO, eigen_mode="V"):
776772
"_heevd" if dpnp.issubdtype(v_type, dpnp.complexfloating) else "_syevd"
777773
)
778774

775+
a_sycl_queue = a.sycl_queue
776+
a_order = "C" if a.flags.c_contiguous else "F"
777+
779778
if a.ndim > 2:
779+
orig_shape = a.shape
780+
# get 3d input array by reshape
781+
a = a.reshape(-1, orig_shape[-2], orig_shape[-1])
782+
a_usm_arr = dpnp.get_usm_ndarray(a)
783+
784+
# allocate a memory for dpnp array of eigenvalues
780785
w = dpnp.empty_like(
781786
a,
782-
shape=a.shape[:-1],
787+
shape=orig_shape[:-1],
783788
dtype=w_type,
784789
)
790+
w_orig_shape = w.shape
791+
# get 2d dpnp array with eigenvalues by reshape
792+
w = w.reshape(-1, w_orig_shape[-1])
785793

786794
# need to loop over the 1st dimension to get eigenvalues and eigenvectors of 3d matrix A
787795
batch_size = a.shape[0]
788-
if batch_size == 0:
789-
return (
790-
(w, dpnp.empty_like(a, dtype=v_type))
791-
if eigen_mode == "V"
792-
else w
793-
)
794-
795-
# When `eigen_mode == "N"` (jobz == 0), OneMKL LAPACK does not overwrite the input array.
796-
# If the input array 'a' is already F-contiguous and matches the target data type,
797-
# we can avoid unnecessary memory allocation and data copying.
798-
if eigen_mode == "N" and a_order == "F" and a.dtype == v_type:
799-
ht_list_ev = [None] * batch_size
800-
for i in range(batch_size):
801-
# call LAPACK extension function to get eigenvalues of a portion of matrix A
802-
ht_list_ev[i], _ = getattr(li, lapack_func)(
803-
a_sycl_queue,
804-
jobz,
805-
uplo,
806-
a[i].get_array(),
807-
w[i].get_array(),
808-
depends=[],
809-
)
810-
811-
dpctl.SyclEvent.wait_for(ht_list_ev)
812-
813-
return w
814-
815796
eig_vecs = [None] * batch_size
816797
ht_list_ev = [None] * batch_size * 2
817798
for i in range(batch_size):
@@ -838,15 +819,18 @@ def dpnp_eigh(a, UPLO, eigen_mode="V"):
838819

839820
dpctl.SyclEvent.wait_for(ht_list_ev)
840821

822+
w = w.reshape(w_orig_shape)
823+
841824
if eigen_mode == "V":
842825
# combine the list of eigenvectors into a single array
843-
v = dpnp.array(eig_vecs, order=a_order)
826+
v = dpnp.array(eig_vecs, order=a_order).reshape(orig_shape)
844827
return w, v
845828
return w
846829

847830
else:
831+
a_usm_arr = dpnp.get_usm_ndarray(a)
848832
ht_list_ev = []
849-
copy_ev = None
833+
copy_ev = dpctl.SyclEvent()
850834

851835
# When `eigen_mode == "N"` (jobz == 0), OneMKL LAPACK does not overwrite the input array.
852836
# If the input array 'a' is already F-contiguous and matches the target data type,

0 commit comments

Comments
 (0)