@@ -1065,13 +1065,25 @@ def dpnp_svd_batch(a, uv_type, s_type, full_matrices=True, compute_uv=True):
1065
1065
u_matrices = [None ] * batch_size
1066
1066
s_matrices = [None ] * batch_size
1067
1067
vt_matrices = [None ] * batch_size
1068
+ a_ht_copy_ev = [None ] * batch_size
1069
+ ht_lapack_ev = [None ] * batch_size
1068
1070
for i in range (batch_size ):
1069
1071
if compute_uv :
1070
- u_matrices [i ], s_matrices [i ], vt_matrices [i ] = dpnp_svd (
1071
- a [i ], full_matrices , compute_uv = True
1072
- )
1072
+ (
1073
+ u_matrices [i ],
1074
+ s_matrices [i ],
1075
+ vt_matrices [i ],
1076
+ ht_lapack_ev [i ],
1077
+ a_ht_copy_ev [i ],
1078
+ ) = dpnp_svd (a [i ], full_matrices , compute_uv = True , batch_call = True )
1073
1079
else :
1074
- s_matrices [i ] = dpnp_svd (a [i ], full_matrices , compute_uv = False )
1080
+ s_matrices [i ], ht_lapack_ev [i ], a_ht_copy_ev [i ] = dpnp_svd (
1081
+ a [i ], full_matrices , compute_uv = False , batch_call = True
1082
+ )
1083
+
1084
+ for i in range (batch_size ):
1085
+ ht_lapack_ev [i ].wait ()
1086
+ a_ht_copy_ev [i ].wait ()
1075
1087
1076
1088
out_s = dpnp .array (s_matrices )
1077
1089
if reshape :
@@ -1092,9 +1104,11 @@ def dpnp_svd_batch(a, uv_type, s_type, full_matrices=True, compute_uv=True):
1092
1104
return out_s
1093
1105
1094
1106
1095
- def dpnp_svd (a , full_matrices = True , compute_uv = True , hermitian = False ):
1107
+ def dpnp_svd (
1108
+ a , full_matrices = True , compute_uv = True , hermitian = False , batch_call = False
1109
+ ):
1096
1110
"""
1097
- dpnp_svd(a, full_matrices=True, compute_uv=True, hermitian=False)
1111
+ dpnp_svd(a, full_matrices=True, compute_uv=True, hermitian=False, batch_call=False )
1098
1112
1099
1113
Return the singular value decomposition (SVD).
1100
1114
@@ -1226,6 +1240,12 @@ def dpnp_svd(a, full_matrices=True, compute_uv=True, hermitian=False):
1226
1240
[a_copy_ev ],
1227
1241
)
1228
1242
1243
+ if batch_call :
1244
+ if compute_uv :
1245
+ return u_h , s_h , vt_h , ht_lapack_ev , a_ht_copy_ev
1246
+ else :
1247
+ return s_h , ht_lapack_ev , a_ht_copy_ev
1248
+
1229
1249
ht_lapack_ev .wait ()
1230
1250
a_ht_copy_ev .wait ()
1231
1251
0 commit comments