Skip to content

Commit 31e8bbb

Browse files
Implement parallel calculation in dpnp_svd_batch
1 parent 5a4721f commit 31e8bbb

File tree

1 file changed

+26
-6
lines changed

1 file changed

+26
-6
lines changed

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 26 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1065,13 +1065,25 @@ def dpnp_svd_batch(a, uv_type, s_type, full_matrices=True, compute_uv=True):
10651065
u_matrices = [None] * batch_size
10661066
s_matrices = [None] * batch_size
10671067
vt_matrices = [None] * batch_size
1068+
a_ht_copy_ev = [None] * batch_size
1069+
ht_lapack_ev = [None] * batch_size
10681070
for i in range(batch_size):
10691071
if compute_uv:
1070-
u_matrices[i], s_matrices[i], vt_matrices[i] = dpnp_svd(
1071-
a[i], full_matrices, compute_uv=True
1072-
)
1072+
(
1073+
u_matrices[i],
1074+
s_matrices[i],
1075+
vt_matrices[i],
1076+
ht_lapack_ev[i],
1077+
a_ht_copy_ev[i],
1078+
) = dpnp_svd(a[i], full_matrices, compute_uv=True, batch_call=True)
10731079
else:
1074-
s_matrices[i] = dpnp_svd(a[i], full_matrices, compute_uv=False)
1080+
s_matrices[i], ht_lapack_ev[i], a_ht_copy_ev[i] = dpnp_svd(
1081+
a[i], full_matrices, compute_uv=False, batch_call=True
1082+
)
1083+
1084+
for i in range(batch_size):
1085+
ht_lapack_ev[i].wait()
1086+
a_ht_copy_ev[i].wait()
10751087

10761088
out_s = dpnp.array(s_matrices)
10771089
if reshape:
@@ -1092,9 +1104,11 @@ def dpnp_svd_batch(a, uv_type, s_type, full_matrices=True, compute_uv=True):
10921104
return out_s
10931105

10941106

1095-
def dpnp_svd(a, full_matrices=True, compute_uv=True, hermitian=False):
1107+
def dpnp_svd(
1108+
a, full_matrices=True, compute_uv=True, hermitian=False, batch_call=False
1109+
):
10961110
"""
1097-
dpnp_svd(a, full_matrices=True, compute_uv=True, hermitian=False)
1111+
dpnp_svd(a, full_matrices=True, compute_uv=True, hermitian=False, batch_call=False)
10981112
10991113
Return the singular value decomposition (SVD).
11001114
@@ -1226,6 +1240,12 @@ def dpnp_svd(a, full_matrices=True, compute_uv=True, hermitian=False):
12261240
[a_copy_ev],
12271241
)
12281242

1243+
if batch_call:
1244+
if compute_uv:
1245+
return u_h, s_h, vt_h, ht_lapack_ev, a_ht_copy_ev
1246+
else:
1247+
return s_h, ht_lapack_ev, a_ht_copy_ev
1248+
12291249
ht_lapack_ev.wait()
12301250
a_ht_copy_ev.wait()
12311251

0 commit comments

Comments
 (0)