@@ -864,7 +864,7 @@ def dpnp_svd_batch(a, uv_type, s_type, full_matrices=True, compute_uv=True):
864
864
reshape = True
865
865
866
866
batch_size = a .shape [0 ]
867
- n , m = a .shape [- 2 :]
867
+ m , n = a .shape [- 2 :]
868
868
869
869
if batch_size == 0 :
870
870
k = min (m , n )
@@ -876,11 +876,11 @@ def dpnp_svd_batch(a, uv_type, s_type, full_matrices=True, compute_uv=True):
876
876
)
877
877
if compute_uv :
878
878
if full_matrices :
879
- u_shape = batch_shape_orig + (n , n )
880
- vt_shape = batch_shape_orig + (m , m )
879
+ u_shape = batch_shape_orig + (m , m )
880
+ vt_shape = batch_shape_orig + (n , n )
881
881
else :
882
- u_shape = batch_shape_orig + (n , k )
883
- vt_shape = batch_shape_orig + (k , m )
882
+ u_shape = batch_shape_orig + (m , k )
883
+ vt_shape = batch_shape_orig + (k , n )
884
884
885
885
u = dpnp .empty (
886
886
u_shape ,
@@ -908,27 +908,27 @@ def dpnp_svd_batch(a, uv_type, s_type, full_matrices=True, compute_uv=True):
908
908
if full_matrices :
909
909
u = _stacked_identity (
910
910
batch_shape_orig ,
911
- n ,
911
+ m ,
912
912
dtype = uv_type ,
913
913
usm_type = a_usm_type ,
914
914
sycl_queue = a_sycl_queue ,
915
915
)
916
916
vt = _stacked_identity (
917
917
batch_shape_orig ,
918
- m ,
918
+ n ,
919
919
dtype = uv_type ,
920
920
usm_type = a_usm_type ,
921
921
sycl_queue = a_sycl_queue ,
922
922
)
923
923
else :
924
924
u = dpnp .empty (
925
- batch_shape_orig + (n , 0 ),
925
+ batch_shape_orig + (m , 0 ),
926
926
dtype = uv_type ,
927
927
usm_type = a_usm_type ,
928
928
sycl_queue = a_sycl_queue ,
929
929
)
930
930
vt = dpnp .empty (
931
- batch_shape_orig + (0 , m ),
931
+ batch_shape_orig + (0 , n ),
932
932
dtype = uv_type ,
933
933
usm_type = a_usm_type ,
934
934
sycl_queue = a_sycl_queue ,
@@ -942,7 +942,7 @@ def dpnp_svd_batch(a, uv_type, s_type, full_matrices=True, compute_uv=True):
942
942
vt_matrices = [None ] * batch_size
943
943
for i in range (batch_size ):
944
944
if compute_uv :
945
- vt_matrices [i ], s_matrices [i ], u_matrices [i ] = dpnp_svd (
945
+ u_matrices [i ], s_matrices [i ], vt_matrices [i ] = dpnp_svd (
946
946
a [i ], full_matrices , compute_uv = True
947
947
)
948
948
else :
@@ -953,16 +953,16 @@ def dpnp_svd_batch(a, uv_type, s_type, full_matrices=True, compute_uv=True):
953
953
out_s = out_s .reshape (batch_shape_orig + out_s .shape [- 1 :])
954
954
955
955
if compute_uv :
956
- out_vt = dpnp .array (vt_matrices )
957
- out_u = dpnp .array (u_matrices )
956
+ out_u = dpnp .array (u_matrices , order = "F" )
957
+ out_vt = dpnp .array (vt_matrices , order = "F" )
958
958
if reshape :
959
959
return (
960
- out_vt .reshape (batch_shape_orig + out_vt .shape [- 2 :]),
961
- out_s ,
962
960
out_u .reshape (batch_shape_orig + out_u .shape [- 2 :]),
961
+ out_s ,
962
+ out_vt .reshape (batch_shape_orig + out_vt .shape [- 2 :]),
963
963
)
964
964
else :
965
- return out_vt , out_s , out_u
965
+ return out_u , out_s , out_vt
966
966
else :
967
967
return out_s
968
968
@@ -1010,7 +1010,7 @@ def dpnp_svd(a, full_matrices=True, compute_uv=True, hermitian=False):
1010
1010
1011
1011
a_usm_type = a .usm_type
1012
1012
a_sycl_queue = a .sycl_queue
1013
- n , m = a .shape
1013
+ m , n = a .shape
1014
1014
1015
1015
if m == 0 or n == 0 :
1016
1016
s = dpnp .empty (
@@ -1021,11 +1021,11 @@ def dpnp_svd(a, full_matrices=True, compute_uv=True, hermitian=False):
1021
1021
)
1022
1022
if compute_uv :
1023
1023
if full_matrices :
1024
- u_shape = (n ,)
1025
- vt_shape = (m ,)
1024
+ u_shape = (m ,)
1025
+ vt_shape = (n ,)
1026
1026
else :
1027
- u_shape = (n , 0 )
1028
- vt_shape = (0 , m )
1027
+ u_shape = (m , 0 )
1028
+ vt_shape = (0 , n )
1029
1029
1030
1030
u = dpnp .eye (
1031
1031
* u_shape ,
@@ -1043,35 +1043,28 @@ def dpnp_svd(a, full_matrices=True, compute_uv=True, hermitian=False):
1043
1043
else :
1044
1044
return s
1045
1045
1046
- # `a` must be transposed if m < n
1047
- if m >= n :
1048
- x = a
1049
- trans_flag = False
1050
- else :
1051
- m , n = a .shape
1052
- x = a .transpose ()
1053
- trans_flag = True
1054
-
1055
1046
# `a` must be copied because gesvd destroys the input matrix
1056
- a_h = dpnp .empty_like (x , order = "C" , dtype = uv_type )
1047
+ # oneMKL LAPACK gesvd overwrites `a` and assumes fortran-like array as input.
1048
+ # Allocate 'F' order memory for dpnp arrays to comply with these requirements.
1049
+ a_h = dpnp .empty_like (a , order = "F" , dtype = uv_type )
1057
1050
1058
- a_usm_arr = dpnp .get_usm_ndarray (x )
1051
+ a_usm_arr = dpnp .get_usm_ndarray (a )
1059
1052
1060
1053
# use DPCTL tensor function to fill the сopy of the input array
1061
1054
# from the input array
1062
1055
a_ht_copy_ev , a_copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
1063
1056
src = a_usm_arr , dst = a_h .get_array (), sycl_queue = a_sycl_queue
1064
1057
)
1065
1058
1066
- k = n # = min(m, n) where m >= n is ensured above
1059
+ k = min (m , n )
1067
1060
if compute_uv :
1068
1061
if full_matrices :
1069
1062
u_shape = (m , m )
1070
1063
vt_shape = (n , n )
1071
1064
jobu = ord ("A" )
1072
1065
jobvt = ord ("A" )
1073
1066
else :
1074
- u_shape = x . shape
1067
+ u_shape = ( m , k )
1075
1068
vt_shape = (k , n )
1076
1069
jobu = ord ("S" )
1077
1070
jobvt = ord ("S" )
@@ -1083,12 +1076,14 @@ def dpnp_svd(a, full_matrices=True, compute_uv=True, hermitian=False):
1083
1076
u_h = dpnp .empty (
1084
1077
u_shape ,
1085
1078
dtype = uv_type ,
1079
+ order = "F" ,
1086
1080
usm_type = a_usm_type ,
1087
1081
sycl_queue = a_sycl_queue ,
1088
1082
)
1089
1083
vt_h = dpnp .empty (
1090
1084
vt_shape ,
1091
1085
dtype = uv_type ,
1086
+ order = "F" ,
1092
1087
usm_type = a_usm_type ,
1093
1088
sycl_queue = a_sycl_queue ,
1094
1089
)
@@ -1111,9 +1106,6 @@ def dpnp_svd(a, full_matrices=True, compute_uv=True, hermitian=False):
1111
1106
a_ht_copy_ev .wait ()
1112
1107
1113
1108
if compute_uv :
1114
- if trans_flag :
1115
- return u_h .transpose (), s_h , vt_h .transpose ()
1116
- else :
1117
- return vt_h , s_h , u_h
1109
+ return u_h , s_h , vt_h
1118
1110
else :
1119
1111
return s_h
0 commit comments