Skip to content

Commit 45b4f5c

Browse files
Merge 10df422 into 92faa20
2 parents 92faa20 + 10df422 commit 45b4f5c

File tree

9 files changed

+774
-217
lines changed

9 files changed

+774
-217
lines changed

dpnp/backend/extensions/lapack/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ set(_module_src
3232
${CMAKE_CURRENT_SOURCE_DIR}/gesv.cpp
3333
${CMAKE_CURRENT_SOURCE_DIR}/gesv_batch.cpp
3434
${CMAKE_CURRENT_SOURCE_DIR}/gesvd.cpp
35+
${CMAKE_CURRENT_SOURCE_DIR}/gesvd_batch.cpp
3536
${CMAKE_CURRENT_SOURCE_DIR}/getrf.cpp
3637
${CMAKE_CURRENT_SOURCE_DIR}/getrf_batch.cpp
3738
${CMAKE_CURRENT_SOURCE_DIR}/getri_batch.cpp

dpnp/backend/extensions/lapack/gesvd.cpp

Lines changed: 25 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -26,21 +26,20 @@
2626
#include <pybind11/pybind11.h>
2727

2828
// dpctl tensor headers
29-
#include "utils/memory_overlap.hpp"
3029
#include "utils/type_utils.hpp"
3130

31+
#include "common_helpers.hpp"
3232
#include "gesvd.hpp"
33+
#include "gesvd_common_utils.hpp"
3334
#include "types_matrix.hpp"
3435

35-
#include "dpnp_utils.hpp"
36-
3736
namespace dpnp::extensions::lapack
3837
{
3938
namespace mkl_lapack = oneapi::mkl::lapack;
4039
namespace py = pybind11;
4140
namespace type_utils = dpctl::tensor::type_utils;
4241

43-
typedef sycl::event (*gesvd_impl_fn_ptr_t)(sycl::queue,
42+
typedef sycl::event (*gesvd_impl_fn_ptr_t)(sycl::queue &,
4443
const oneapi::mkl::jobsvd,
4544
const oneapi::mkl::jobsvd,
4645
const std::int64_t,
@@ -58,26 +57,8 @@ typedef sycl::event (*gesvd_impl_fn_ptr_t)(sycl::queue,
5857
static gesvd_impl_fn_ptr_t gesvd_dispatch_table[dpctl_td_ns::num_types]
5958
[dpctl_td_ns::num_types];
6059

61-
// Converts a given character code (ord) to the corresponding
62-
// oneapi::mkl::jobsvd enumeration value
63-
static oneapi::mkl::jobsvd process_job(std::int8_t job_val)
64-
{
65-
switch (job_val) {
66-
case 'A':
67-
return oneapi::mkl::jobsvd::vectors;
68-
case 'S':
69-
return oneapi::mkl::jobsvd::somevec;
70-
case 'O':
71-
return oneapi::mkl::jobsvd::vectorsina;
72-
case 'N':
73-
return oneapi::mkl::jobsvd::novec;
74-
default:
75-
throw std::invalid_argument("Unknown value for job");
76-
}
77-
}
78-
7960
template <typename T, typename RealT>
80-
static sycl::event gesvd_impl(sycl::queue exec_q,
61+
static sycl::event gesvd_impl(sycl::queue &exec_q,
8162
const oneapi::mkl::jobsvd jobu,
8263
const oneapi::mkl::jobsvd jobvt,
8364
const std::int64_t m,
@@ -102,16 +83,14 @@ static sycl::event gesvd_impl(sycl::queue exec_q,
10283

10384
const std::int64_t scratchpad_size = mkl_lapack::gesvd_scratchpad_size<T>(
10485
exec_q, jobu, jobvt, m, n, lda, ldu, ldvt);
105-
T *scratchpad = nullptr;
86+
87+
T *scratchpad = helper::alloc_scratchpad<T>(scratchpad_size, exec_q);
10688

10789
std::stringstream error_msg;
108-
std::int64_t info = 0;
10990
bool is_exception_caught = false;
11091

11192
sycl::event gesvd_event;
11293
try {
113-
scratchpad = sycl::malloc_device<T>(scratchpad_size, exec_q);
114-
11594
gesvd_event = mkl_lapack::gesvd(
11695
exec_q,
11796
jobu, // Character specifying how to compute the matrix U:
@@ -138,26 +117,7 @@ static sycl::event gesvd_impl(sycl::queue exec_q,
138117
scratchpad_size, depends);
139118
} catch (mkl_lapack::exception const &e) {
140119
is_exception_caught = true;
141-
info = e.info();
142-
if (info < 0) {
143-
error_msg << "Parameter number " << -info
144-
<< " had an illegal value.";
145-
}
146-
else if (info == scratchpad_size && e.detail() != 0) {
147-
error_msg
148-
<< "Insufficient scratchpad size. Required size is at least "
149-
<< e.detail();
150-
}
151-
else if (info > 0) {
152-
error_msg << "The algorithm computing SVD failed to converge; "
153-
<< info << " off-diagonal elements of an intermediate "
154-
<< "bidiagonal form did not converge to zero.\n";
155-
}
156-
else {
157-
error_msg << "Unexpected MKL exception caught during gesvd() "
158-
"call:\nreason: "
159-
<< e.what() << "\ninfo: " << e.info();
160-
}
120+
gesvd_utils::handle_lapack_exc(scratchpad_size, e, error_msg);
161121
} catch (sycl::exception const &e) {
162122
is_exception_caught = true;
163123
error_msg << "Unexpected SYCL exception caught during gesvd() call:\n"
@@ -182,7 +142,7 @@ static sycl::event gesvd_impl(sycl::queue exec_q,
182142
}
183143

184144
std::pair<sycl::event, sycl::event>
185-
gesvd(sycl::queue exec_q,
145+
gesvd(sycl::queue &exec_q,
186146
const std::int8_t jobu_val,
187147
const std::int8_t jobvt_val,
188148
dpctl::tensor::usm_ndarray a_array,
@@ -191,103 +151,26 @@ std::pair<sycl::event, sycl::event>
191151
dpctl::tensor::usm_ndarray out_vt,
192152
const std::vector<sycl::event> &depends)
193153
{
194-
const int a_array_nd = a_array.get_ndim();
195-
const int out_u_array_nd = out_u.get_ndim();
196-
const int out_s_array_nd = out_s.get_ndim();
197-
const int out_vt_array_nd = out_vt.get_ndim();
198-
199-
if (a_array_nd != 2) {
200-
throw py::value_error(
201-
"The input array has ndim=" + std::to_string(a_array_nd) +
202-
", but a 2-dimensional array is expected.");
203-
}
204-
205-
if (out_s_array_nd != 1) {
206-
throw py::value_error("The output array of singular values has ndim=" +
207-
std::to_string(out_s_array_nd) +
208-
", but a 1-dimensional array is expected.");
209-
}
210-
211-
if (jobu_val == 'N' && jobvt_val == 'N') {
212-
if (out_u_array_nd != 0) {
213-
throw py::value_error(
214-
"The output array of the left singular vectors has ndim=" +
215-
std::to_string(out_u_array_nd) +
216-
", but it is not used and should have ndim=0.");
217-
}
218-
if (out_vt_array_nd != 0) {
219-
throw py::value_error(
220-
"The output array of the right singular vectors has ndim=" +
221-
std::to_string(out_vt_array_nd) +
222-
", but it is not used and should have ndim=0.");
223-
}
224-
}
225-
else {
226-
if (out_u_array_nd != 2) {
227-
throw py::value_error(
228-
"The output array of the left singular vectors has ndim=" +
229-
std::to_string(out_u_array_nd) +
230-
", but a 2-dimensional array is expected.");
231-
}
232-
if (out_vt_array_nd != 2) {
233-
throw py::value_error(
234-
"The output array of the right singular vectors has ndim=" +
235-
std::to_string(out_vt_array_nd) +
236-
", but a 2-dimensional array is expected.");
237-
}
238-
}
239-
240-
// check compatibility of execution queue and allocation queue
241-
if (!dpctl::utils::queues_are_compatible(
242-
exec_q, {a_array.get_queue(), out_s.get_queue(), out_u.get_queue(),
243-
out_vt.get_queue()}))
244-
{
245-
throw std::runtime_error(
246-
"USM allocations are not compatible with the execution queue.");
247-
}
248-
249-
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
250-
if (overlap(a_array, out_s) || overlap(a_array, out_u) ||
251-
overlap(a_array, out_vt) || overlap(out_s, out_u) ||
252-
overlap(out_s, out_vt) || overlap(out_u, out_vt))
253-
{
254-
throw py::value_error("Arrays have overlapping segments of memory");
255-
}
256-
257-
bool is_a_array_f_contig = a_array.is_f_contiguous();
258-
if (!is_a_array_f_contig) {
259-
throw py::value_error("The input array must be F-contiguous");
260-
}
261-
262-
bool is_out_u_array_f_contig = out_u.is_f_contiguous();
263-
bool is_out_vt_array_f_contig = out_vt.is_f_contiguous();
264-
265-
if (!is_out_u_array_f_contig || !is_out_vt_array_f_contig) {
266-
throw py::value_error("The output arrays of the left and right "
267-
"singular vectors must be F-contiguous");
268-
}
269-
270-
bool is_out_s_array_c_contig = out_s.is_c_contiguous();
271-
bool is_out_s_array_f_contig = out_s.is_f_contiguous();
272-
273-
if (!is_out_s_array_c_contig || !is_out_s_array_f_contig) {
274-
throw py::value_error("The output array of singular values "
275-
"must be contiguous");
276-
}
154+
constexpr int expected_a_u_vt_ndim = 2;
155+
constexpr int expected_s_ndim = 1;
156+
157+
gesvd_utils::common_gesvd_checks(exec_q, a_array, out_s, out_u, out_vt,
158+
jobu_val, jobvt_val, expected_a_u_vt_ndim,
159+
expected_s_ndim);
160+
161+
// // Ensure `m` and 'n' are non-zero, otherwise return empty
162+
// // events
163+
// if (gesvd_utils::check_zeros_shape_gesvd(a_array, out_s, out_u, out_vt,
164+
// jobu_val, jobvt_val))
165+
// {
166+
// // nothing to do
167+
// return std::make_pair(sycl::event(), sycl::event());
168+
// }
277169

278170
auto array_types = dpctl_td_ns::usm_ndarray_types();
279171
int a_array_type_id =
280172
array_types.typenum_to_lookup_id(a_array.get_typenum());
281-
int out_u_type_id = array_types.typenum_to_lookup_id(out_u.get_typenum());
282173
int out_s_type_id = array_types.typenum_to_lookup_id(out_s.get_typenum());
283-
int out_vt_type_id = array_types.typenum_to_lookup_id(out_vt.get_typenum());
284-
285-
if (a_array_type_id != out_u_type_id || a_array_type_id != out_vt_type_id) {
286-
throw py::type_error(
287-
"Input array, output left singular vectors array, "
288-
"and outpuy right singular vectors array must have "
289-
"the same data type");
290-
}
291174

292175
gesvd_impl_fn_ptr_t gesvd_fn =
293176
gesvd_dispatch_table[a_array_type_id][out_s_type_id];
@@ -311,8 +194,8 @@ std::pair<sycl::event, sycl::event>
311194
const std::int64_t ldvt =
312195
std::max<std::size_t>(1UL, jobvt_val == 'S' ? (m > n ? n : m) : n);
313196

314-
const oneapi::mkl::jobsvd jobu = process_job(jobu_val);
315-
const oneapi::mkl::jobsvd jobvt = process_job(jobvt_val);
197+
const oneapi::mkl::jobsvd jobu = gesvd_utils::process_job(jobu_val);
198+
const oneapi::mkl::jobsvd jobvt = gesvd_utils::process_job(jobvt_val);
316199

317200
std::vector<sycl::event> host_task_events;
318201
sycl::event gesvd_ev =

dpnp/backend/extensions/lapack/gesvd.hpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
namespace dpnp::extensions::lapack
3434
{
3535
extern std::pair<sycl::event, sycl::event>
36-
gesvd(sycl::queue exec_q,
36+
gesvd(sycl::queue &exec_q,
3737
const std::int8_t jobu_val,
3838
const std::int8_t jobvt_val,
3939
dpctl::tensor::usm_ndarray a_array,
@@ -42,5 +42,16 @@ extern std::pair<sycl::event, sycl::event>
4242
dpctl::tensor::usm_ndarray out_vt,
4343
const std::vector<sycl::event> &depends);
4444

45+
extern std::pair<sycl::event, sycl::event>
46+
gesvd_batch(sycl::queue &exec_q,
47+
const std::int8_t jobu_val,
48+
const std::int8_t jobvt_val,
49+
dpctl::tensor::usm_ndarray a_array,
50+
dpctl::tensor::usm_ndarray out_s,
51+
dpctl::tensor::usm_ndarray out_u,
52+
dpctl::tensor::usm_ndarray out_vt,
53+
const std::vector<sycl::event> &depends);
54+
4555
extern void init_gesvd_dispatch_table(void);
56+
extern void init_gesvd_batch_dispatch_table(void);
4657
} // namespace dpnp::extensions::lapack

0 commit comments

Comments
 (0)