Skip to content

Commit 347fbe9

Browse files
dpnp_svd works with F contiguous arrays
1 parent 1fcfb5e commit 347fbe9

File tree

2 files changed

+40
-43
lines changed

2 files changed

+40
-43
lines changed

dpnp/backend/extensions/lapack/gesvd.cpp

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -222,11 +222,13 @@ std::pair<sycl::event, sycl::event>
222222
throw py::value_error("Arrays have overlapping segments of memory");
223223
}
224224

225-
bool is_a_array_c_contig = a_array.is_c_contiguous();
226-
if (!is_a_array_c_contig) {
227-
throw py::value_error("The input array must be C-contiguous");
225+
bool is_a_array_f_contig = a_array.is_f_contiguous();
226+
if (!is_a_array_f_contig) {
227+
throw py::value_error("The input array must be F-contiguous");
228228
}
229229

230+
// TODO: add checks for output arrays
231+
230232
auto array_types = dpctl_td_ns::usm_ndarray_types();
231233
int a_array_type_id =
232234
array_types.typenum_to_lookup_id(a_array.get_typenum());
@@ -255,12 +257,15 @@ std::pair<sycl::event, sycl::event>
255257
char *out_vt_data = out_vt.get_data();
256258

257259
const py::ssize_t *a_array_shape = a_array.get_shape_raw();
258-
const std::int64_t n = a_array_shape[0];
259-
const std::int64_t m = a_array_shape[1];
260+
const std::int64_t m = a_array_shape[0];
261+
const std::int64_t n = a_array_shape[1];
260262

261263
const std::int64_t lda = std::max<size_t>(1UL, m);
262264
const std::int64_t ldu = std::max<size_t>(1UL, m);
263-
const std::int64_t ldvt = std::max<size_t>(1UL, n);
265+
const std::int64_t ldvt =
266+
std::max<std::size_t>(1UL, jobvt_val == 'S' ? (m > n ? n : m) : n);
267+
std::cout << "ldvt: " << ldvt << std::endl;
268+
// const std::int64_t ldvt = std::max<size_t>(1UL, n);
264269

265270
const oneapi::mkl::jobsvd jobu = process_job(jobu_val);
266271
const oneapi::mkl::jobsvd jobvt = process_job(jobvt_val);

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 29 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -864,7 +864,7 @@ def dpnp_svd_batch(a, uv_type, s_type, full_matrices=True, compute_uv=True):
864864
reshape = True
865865

866866
batch_size = a.shape[0]
867-
n, m = a.shape[-2:]
867+
m, n = a.shape[-2:]
868868

869869
if batch_size == 0:
870870
k = min(m, n)
@@ -876,11 +876,11 @@ def dpnp_svd_batch(a, uv_type, s_type, full_matrices=True, compute_uv=True):
876876
)
877877
if compute_uv:
878878
if full_matrices:
879-
u_shape = batch_shape_orig + (n, n)
880-
vt_shape = batch_shape_orig + (m, m)
879+
u_shape = batch_shape_orig + (m, m)
880+
vt_shape = batch_shape_orig + (n, n)
881881
else:
882-
u_shape = batch_shape_orig + (n, k)
883-
vt_shape = batch_shape_orig + (k, m)
882+
u_shape = batch_shape_orig + (m, k)
883+
vt_shape = batch_shape_orig + (k, n)
884884

885885
u = dpnp.empty(
886886
u_shape,
@@ -908,27 +908,27 @@ def dpnp_svd_batch(a, uv_type, s_type, full_matrices=True, compute_uv=True):
908908
if full_matrices:
909909
u = _stacked_identity(
910910
batch_shape_orig,
911-
n,
911+
m,
912912
dtype=uv_type,
913913
usm_type=a_usm_type,
914914
sycl_queue=a_sycl_queue,
915915
)
916916
vt = _stacked_identity(
917917
batch_shape_orig,
918-
m,
918+
n,
919919
dtype=uv_type,
920920
usm_type=a_usm_type,
921921
sycl_queue=a_sycl_queue,
922922
)
923923
else:
924924
u = dpnp.empty(
925-
batch_shape_orig + (n, 0),
925+
batch_shape_orig + (m, 0),
926926
dtype=uv_type,
927927
usm_type=a_usm_type,
928928
sycl_queue=a_sycl_queue,
929929
)
930930
vt = dpnp.empty(
931-
batch_shape_orig + (0, m),
931+
batch_shape_orig + (0, n),
932932
dtype=uv_type,
933933
usm_type=a_usm_type,
934934
sycl_queue=a_sycl_queue,
@@ -942,7 +942,7 @@ def dpnp_svd_batch(a, uv_type, s_type, full_matrices=True, compute_uv=True):
942942
vt_matrices = [None] * batch_size
943943
for i in range(batch_size):
944944
if compute_uv:
945-
vt_matrices[i], s_matrices[i], u_matrices[i] = dpnp_svd(
945+
u_matrices[i], s_matrices[i], vt_matrices[i] = dpnp_svd(
946946
a[i], full_matrices, compute_uv=True
947947
)
948948
else:
@@ -953,16 +953,16 @@ def dpnp_svd_batch(a, uv_type, s_type, full_matrices=True, compute_uv=True):
953953
out_s = out_s.reshape(batch_shape_orig + out_s.shape[-1:])
954954

955955
if compute_uv:
956-
out_vt = dpnp.array(vt_matrices)
957-
out_u = dpnp.array(u_matrices)
956+
out_u = dpnp.array(u_matrices, order="F")
957+
out_vt = dpnp.array(vt_matrices, order="F")
958958
if reshape:
959959
return (
960-
out_vt.reshape(batch_shape_orig + out_vt.shape[-2:]),
961-
out_s,
962960
out_u.reshape(batch_shape_orig + out_u.shape[-2:]),
961+
out_s,
962+
out_vt.reshape(batch_shape_orig + out_vt.shape[-2:]),
963963
)
964964
else:
965-
return out_vt, out_s, out_u
965+
return out_u, out_s, out_vt
966966
else:
967967
return out_s
968968

@@ -1010,7 +1010,7 @@ def dpnp_svd(a, full_matrices=True, compute_uv=True, hermitian=False):
10101010

10111011
a_usm_type = a.usm_type
10121012
a_sycl_queue = a.sycl_queue
1013-
n, m = a.shape
1013+
m, n = a.shape
10141014

10151015
if m == 0 or n == 0:
10161016
s = dpnp.empty(
@@ -1021,11 +1021,11 @@ def dpnp_svd(a, full_matrices=True, compute_uv=True, hermitian=False):
10211021
)
10221022
if compute_uv:
10231023
if full_matrices:
1024-
u_shape = (n,)
1025-
vt_shape = (m,)
1024+
u_shape = (m,)
1025+
vt_shape = (n,)
10261026
else:
1027-
u_shape = (n, 0)
1028-
vt_shape = (0, m)
1027+
u_shape = (m, 0)
1028+
vt_shape = (0, n)
10291029

10301030
u = dpnp.eye(
10311031
*u_shape,
@@ -1043,35 +1043,28 @@ def dpnp_svd(a, full_matrices=True, compute_uv=True, hermitian=False):
10431043
else:
10441044
return s
10451045

1046-
# `a` must be transposed if m < n
1047-
if m >= n:
1048-
x = a
1049-
trans_flag = False
1050-
else:
1051-
m, n = a.shape
1052-
x = a.transpose()
1053-
trans_flag = True
1054-
10551046
# `a` must be copied because gesvd destroys the input matrix
1056-
a_h = dpnp.empty_like(x, order="C", dtype=uv_type)
1047+
# oneMKL LAPACK gesvd overwrites `a` and assumes fortran-like array as input.
1048+
# Allocate 'F' order memory for dpnp arrays to comply with these requirements.
1049+
a_h = dpnp.empty_like(a, order="F", dtype=uv_type)
10571050

1058-
a_usm_arr = dpnp.get_usm_ndarray(x)
1051+
a_usm_arr = dpnp.get_usm_ndarray(a)
10591052

10601053
# use DPCTL tensor function to fill the сopy of the input array
10611054
# from the input array
10621055
a_ht_copy_ev, a_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
10631056
src=a_usm_arr, dst=a_h.get_array(), sycl_queue=a_sycl_queue
10641057
)
10651058

1066-
k = n # = min(m, n) where m >= n is ensured above
1059+
k = min(m, n)
10671060
if compute_uv:
10681061
if full_matrices:
10691062
u_shape = (m, m)
10701063
vt_shape = (n, n)
10711064
jobu = ord("A")
10721065
jobvt = ord("A")
10731066
else:
1074-
u_shape = x.shape
1067+
u_shape = (m, k)
10751068
vt_shape = (k, n)
10761069
jobu = ord("S")
10771070
jobvt = ord("S")
@@ -1083,12 +1076,14 @@ def dpnp_svd(a, full_matrices=True, compute_uv=True, hermitian=False):
10831076
u_h = dpnp.empty(
10841077
u_shape,
10851078
dtype=uv_type,
1079+
order="F",
10861080
usm_type=a_usm_type,
10871081
sycl_queue=a_sycl_queue,
10881082
)
10891083
vt_h = dpnp.empty(
10901084
vt_shape,
10911085
dtype=uv_type,
1086+
order="F",
10921087
usm_type=a_usm_type,
10931088
sycl_queue=a_sycl_queue,
10941089
)
@@ -1111,9 +1106,6 @@ def dpnp_svd(a, full_matrices=True, compute_uv=True, hermitian=False):
11111106
a_ht_copy_ev.wait()
11121107

11131108
if compute_uv:
1114-
if trans_flag:
1115-
return u_h.transpose(), s_h, vt_h.transpose()
1116-
else:
1117-
return vt_h, s_h, u_h
1109+
return u_h, s_h, vt_h
11181110
else:
11191111
return s_h

0 commit comments

Comments
 (0)