Skip to content

Commit dd74a36

Browse files
Add additional checks for output arrays
1 parent 03d04d7 commit dd74a36

File tree

2 files changed

+16
-5
lines changed

2 files changed

+16
-5
lines changed

dpnp/backend/extensions/lapack/gesvd.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,21 @@ std::pair<sycl::event, sycl::event>
227227
throw py::value_error("The input array must be F-contiguous");
228228
}
229229

230-
// TODO: add checks for output arrays
230+
bool is_out_u_array_f_contig = out_u.is_f_contiguous();
231+
bool is_out_vt_array_f_contig = out_vt.is_f_contiguous();
232+
233+
if (!is_out_u_array_f_contig || !is_out_vt_array_f_contig) {
234+
throw py::value_error("The output arrays of the left and right "
235+
"singular vectors must be F-contiguous");
236+
}
237+
238+
bool is_out_s_array_c_contig = out_s.is_c_contiguous();
239+
bool is_out_s_array_f_contig = out_s.is_f_contiguous();
240+
241+
if (!is_out_s_array_c_contig || !is_out_s_array_f_contig) {
242+
throw py::value_error("The output array of singular values "
243+
"must be C or F-contiguous");
244+
}
231245

232246
auto array_types = dpctl_td_ns::usm_ndarray_types();
233247
int a_array_type_id =
@@ -264,8 +278,6 @@ std::pair<sycl::event, sycl::event>
264278
const std::int64_t ldu = std::max<size_t>(1UL, m);
265279
const std::int64_t ldvt =
266280
std::max<std::size_t>(1UL, jobvt_val == 'S' ? (m > n ? n : m) : n);
267-
std::cout << "ldvt: " << ldvt << std::endl;
268-
// const std::int64_t ldvt = std::max<size_t>(1UL, n);
269281

270282
const oneapi::mkl::jobsvd jobu = process_job(jobu_val);
271283
const oneapi::mkl::jobsvd jobvt = process_job(jobvt_val);

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1168,8 +1168,7 @@ def dpnp_svd(a, full_matrices=True, compute_uv=True, hermitian=False):
11681168
else:
11691169
return s
11701170

1171-
# `a` must be copied because gesvd destroys the input matrix
1172-
# oneMKL LAPACK gesvd overwrites `a` and assumes fortran-like array as input.
1171+
# oneMKL LAPACK gesvd destroys `a` and assumes fortran-like array as input.
11731172
# Allocate 'F' order memory for dpnp arrays to comply with these requirements.
11741173
a_h = dpnp.empty_like(a, order="F", dtype=uv_type)
11751174

0 commit comments

Comments
 (0)