Skip to content

Commit c7770fd

Browse files
Update dpnp.linalg.svd() function (#1604)
* Draft commit of dpnp.linalg.svd impl * Pass empty arrays if compute_uv=False * Add logic for the input array n < m * Add a new cupy test_decomposition * Rename gesvd input parameters * Correspondence of passed parameters to gesvd signature * Correct initialization of result variables in dpnp_svd * Update test_decomposition * Add implementation of _dpnp_svd_batch * Add test_decomposition to the scope of public CI * Improve error handling for mkl_lapack::gesvd function * Declare detail variable * Use a_usm_type and a_sycl_queue variables * Add additional checks for gesvd function * Remove old dpnp_svd backend * Refresh test_svd in test_linalg * Add detailed comments for gesvd arguments * gesvd returns pair of events and uses dpctl.utils.keep_args_alive * Keep a lexicographical order * Update docstrings for svd * Add test_svd to test_usm_type * Add a new impl to get s_type * Add a description for _stacked_identity * Simplify dpnp_svd_batch * Update tests for dpnp.linalg.svd * Add hermitian argument support * Add test_svd_hermitian * Update svd docstrings * Tune tolerance * Update test_svd_errors * Update _common_type and _common_inexact_type * Remove passing n and m parameters to _gesvd * Simplify results return logic for dpnp_svd_batch * Update condition and random files in cupy/testing to use fix_random and repeat decorators * Rename cupy/testing/condition.py to .../_condition.py * Use self._tol in TestSvd * Update gesvd error handler * dpnp_svd works with F contiguous arrays * Add additional checks for output arrays * Impl parallel calculation in dpnp_svd_batch * Skip using @_condition.repeat in cupy tests * Add additional checks for output arrays * Update docstrings for svd * Use dpctl.SyclEvent.wait_for in dpnp_svd_batch * Add TODO : matching the order of returned arrays * Skip cupy tests on windows * Rename condition to _condition * Set setUpClass to skip cupy tests on cpu
1 parent 7180fe5 commit c7770fd

File tree

20 files changed

+1425
-326
lines changed

20 files changed

+1425
-326
lines changed

dpnp/backend/extensions/lapack/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ set(python_module_name _lapack_impl)
2828
set(_module_src
2929
${CMAKE_CURRENT_SOURCE_DIR}/lapack_py.cpp
3030
${CMAKE_CURRENT_SOURCE_DIR}/gesv.cpp
31+
${CMAKE_CURRENT_SOURCE_DIR}/gesvd.cpp
3132
${CMAKE_CURRENT_SOURCE_DIR}/getrf.cpp
3233
${CMAKE_CURRENT_SOURCE_DIR}/getrf_batch.cpp
3334
${CMAKE_CURRENT_SOURCE_DIR}/getri_batch.cpp
Lines changed: 359 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,359 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2023, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#include <pybind11/pybind11.h>
27+
28+
// dpctl tensor headers
29+
#include "utils/memory_overlap.hpp"
30+
#include "utils/type_utils.hpp"
31+
32+
#include "gesvd.hpp"
33+
#include "types_matrix.hpp"
34+
35+
#include "dpnp_utils.hpp"
36+
37+
namespace dpnp
38+
{
39+
namespace backend
40+
{
41+
namespace ext
42+
{
43+
namespace lapack
44+
{
45+
namespace mkl_lapack = oneapi::mkl::lapack;
46+
namespace py = pybind11;
47+
namespace type_utils = dpctl::tensor::type_utils;
48+
49+
typedef sycl::event (*gesvd_impl_fn_ptr_t)(sycl::queue,
50+
const oneapi::mkl::jobsvd,
51+
const oneapi::mkl::jobsvd,
52+
const std::int64_t,
53+
const std::int64_t,
54+
char *,
55+
const std::int64_t,
56+
char *,
57+
char *,
58+
const std::int64_t,
59+
char *,
60+
const std::int64_t,
61+
std::vector<sycl::event> &,
62+
const std::vector<sycl::event> &);
63+
64+
static gesvd_impl_fn_ptr_t gesvd_dispatch_table[dpctl_td_ns::num_types]
65+
[dpctl_td_ns::num_types];
66+
67+
// Converts a given character code (ord) to the corresponding
68+
// oneapi::mkl::jobsvd enumeration value
69+
static oneapi::mkl::jobsvd process_job(std::int8_t job_val)
70+
{
71+
switch (job_val) {
72+
case 'A':
73+
return oneapi::mkl::jobsvd::vectors;
74+
case 'S':
75+
return oneapi::mkl::jobsvd::somevec;
76+
case 'O':
77+
return oneapi::mkl::jobsvd::vectorsina;
78+
case 'N':
79+
return oneapi::mkl::jobsvd::novec;
80+
default:
81+
throw std::invalid_argument("Unknown value for job");
82+
}
83+
}
84+
85+
template <typename T, typename RealT>
86+
static sycl::event gesvd_impl(sycl::queue exec_q,
87+
const oneapi::mkl::jobsvd jobu,
88+
const oneapi::mkl::jobsvd jobvt,
89+
const std::int64_t m,
90+
const std::int64_t n,
91+
char *in_a,
92+
const std::int64_t lda,
93+
char *out_s,
94+
char *out_u,
95+
const std::int64_t ldu,
96+
char *out_vt,
97+
const std::int64_t ldvt,
98+
std::vector<sycl::event> &host_task_events,
99+
const std::vector<sycl::event> &depends)
100+
{
101+
type_utils::validate_type_for_device<T>(exec_q);
102+
type_utils::validate_type_for_device<RealT>(exec_q);
103+
104+
T *a = reinterpret_cast<T *>(in_a);
105+
RealT *s = reinterpret_cast<RealT *>(out_s);
106+
T *u = reinterpret_cast<T *>(out_u);
107+
T *vt = reinterpret_cast<T *>(out_vt);
108+
109+
const std::int64_t scratchpad_size = mkl_lapack::gesvd_scratchpad_size<T>(
110+
exec_q, jobu, jobvt, m, n, lda, ldu, ldvt);
111+
T *scratchpad = nullptr;
112+
113+
std::stringstream error_msg;
114+
std::int64_t info = 0;
115+
bool is_exception_caught = false;
116+
117+
sycl::event gesvd_event;
118+
try {
119+
scratchpad = sycl::malloc_device<T>(scratchpad_size, exec_q);
120+
121+
gesvd_event = mkl_lapack::gesvd(
122+
exec_q,
123+
jobu, // Character specifying how to compute the matrix U:
124+
// 'A' computes all columns of U,
125+
// 'S' computes the first min(m,n) columns of U,
126+
// 'O' overwrites A with the columns of U,
127+
// 'N' does not compute U.
128+
jobvt, // Character specifying how to compute the matrix VT:
129+
// 'A' computes all rows of VT,
130+
// 'S' computes the first min(m,n) rows of VT,
131+
// 'O' overwrites A with the rows of VT,
132+
// 'N' does not compute VT.
133+
m, // The number of rows in the input matrix A (0 <= m).
134+
n, // The number of columns in the input matrix A (0 <= n).
135+
a, // Pointer to the input matrix A of size (m x n).
136+
lda, // The leading dimension of A, must be at least max(1, m).
137+
s, // Pointer to the array containing the singular values.
138+
u, // Pointer to the matrix U in the singular value decomposition.
139+
ldu, // The leading dimension of U, must be at least max(1, m).
140+
vt, // Pointer to the matrix VT in the singular value decomposition.
141+
ldvt, // The leading dimension of VT, must be at least max(1, n).
142+
scratchpad, // Pointer to scratchpad memory to be used by MKL
143+
// routine for storing intermediate results.
144+
scratchpad_size, depends);
145+
} catch (mkl_lapack::exception const &e) {
146+
is_exception_caught = true;
147+
info = e.info();
148+
if (info < 0) {
149+
error_msg << "Parameter number " << -info
150+
<< " had an illegal value.";
151+
}
152+
else if (info == scratchpad_size && e.detail() != 0) {
153+
error_msg
154+
<< "Insufficient scratchpad size. Required size is at least "
155+
<< e.detail();
156+
}
157+
else if (info > 0) {
158+
error_msg << "The algorithm computing SVD failed to converge; "
159+
<< info << " off-diagonal elements of an intermediate "
160+
<< "bidiagonal form did not converge to zero.\n";
161+
}
162+
else {
163+
error_msg << "Unexpected MKL exception caught during gesvd() "
164+
"call:\nreason: "
165+
<< e.what() << "\ninfo: " << e.info();
166+
}
167+
} catch (sycl::exception const &e) {
168+
is_exception_caught = true;
169+
error_msg << "Unexpected SYCL exception caught during gesvd() call:\n"
170+
<< e.what();
171+
}
172+
173+
if (is_exception_caught) // an unexpected error occurs
174+
{
175+
if (scratchpad != nullptr) {
176+
sycl::free(scratchpad, exec_q);
177+
}
178+
throw std::runtime_error(error_msg.str());
179+
}
180+
181+
sycl::event clean_up_event = exec_q.submit([&](sycl::handler &cgh) {
182+
cgh.depends_on(gesvd_event);
183+
auto ctx = exec_q.get_context();
184+
cgh.host_task([ctx, scratchpad]() { sycl::free(scratchpad, ctx); });
185+
});
186+
host_task_events.push_back(clean_up_event);
187+
return gesvd_event;
188+
}
189+
190+
std::pair<sycl::event, sycl::event>
191+
gesvd(sycl::queue exec_q,
192+
const std::int8_t jobu_val,
193+
const std::int8_t jobvt_val,
194+
dpctl::tensor::usm_ndarray a_array,
195+
dpctl::tensor::usm_ndarray out_s,
196+
dpctl::tensor::usm_ndarray out_u,
197+
dpctl::tensor::usm_ndarray out_vt,
198+
const std::vector<sycl::event> &depends)
199+
{
200+
const int a_array_nd = a_array.get_ndim();
201+
const int out_u_array_nd = out_u.get_ndim();
202+
const int out_s_array_nd = out_s.get_ndim();
203+
const int out_vt_array_nd = out_vt.get_ndim();
204+
205+
if (a_array_nd != 2) {
206+
throw py::value_error(
207+
"The input array has ndim=" + std::to_string(a_array_nd) +
208+
", but a 2-dimensional array is expected.");
209+
}
210+
211+
if (out_s_array_nd != 1) {
212+
throw py::value_error("The output array of singular values has ndim=" +
213+
std::to_string(out_s_array_nd) +
214+
", but a 1-dimensional array is expected.");
215+
}
216+
217+
if (jobu_val == 'N' && jobvt_val == 'N') {
218+
if (out_u_array_nd != 0) {
219+
throw py::value_error(
220+
"The output array of the left singular vectors has ndim=" +
221+
std::to_string(out_u_array_nd) +
222+
", but it is not used and should have ndim=0.");
223+
}
224+
if (out_vt_array_nd != 0) {
225+
throw py::value_error(
226+
"The output array of the right singular vectors has ndim=" +
227+
std::to_string(out_vt_array_nd) +
228+
", but it is not used and should have ndim=0.");
229+
}
230+
}
231+
else {
232+
if (out_u_array_nd != 2) {
233+
throw py::value_error(
234+
"The output array of the left singular vectors has ndim=" +
235+
std::to_string(out_u_array_nd) +
236+
", but a 2-dimensional array is expected.");
237+
}
238+
if (out_vt_array_nd != 2) {
239+
throw py::value_error(
240+
"The output array of the right singular vectors has ndim=" +
241+
std::to_string(out_vt_array_nd) +
242+
", but a 2-dimensional array is expected.");
243+
}
244+
}
245+
246+
// check compatibility of execution queue and allocation queue
247+
if (!dpctl::utils::queues_are_compatible(
248+
exec_q, {a_array.get_queue(), out_s.get_queue(), out_u.get_queue(),
249+
out_vt.get_queue()}))
250+
{
251+
throw std::runtime_error(
252+
"USM allocations are not compatible with the execution queue.");
253+
}
254+
255+
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
256+
if (overlap(a_array, out_s) || overlap(a_array, out_u) ||
257+
overlap(a_array, out_vt) || overlap(out_s, out_u) ||
258+
overlap(out_s, out_vt) || overlap(out_u, out_vt))
259+
{
260+
throw py::value_error("Arrays have overlapping segments of memory");
261+
}
262+
263+
bool is_a_array_f_contig = a_array.is_f_contiguous();
264+
if (!is_a_array_f_contig) {
265+
throw py::value_error("The input array must be F-contiguous");
266+
}
267+
268+
bool is_out_u_array_f_contig = out_u.is_f_contiguous();
269+
bool is_out_vt_array_f_contig = out_vt.is_f_contiguous();
270+
271+
if (!is_out_u_array_f_contig || !is_out_vt_array_f_contig) {
272+
throw py::value_error("The output arrays of the left and right "
273+
"singular vectors must be F-contiguous");
274+
}
275+
276+
bool is_out_s_array_c_contig = out_s.is_c_contiguous();
277+
bool is_out_s_array_f_contig = out_s.is_f_contiguous();
278+
279+
if (!is_out_s_array_c_contig || !is_out_s_array_f_contig) {
280+
throw py::value_error("The output array of singular values "
281+
"must be contiguous");
282+
}
283+
284+
auto array_types = dpctl_td_ns::usm_ndarray_types();
285+
int a_array_type_id =
286+
array_types.typenum_to_lookup_id(a_array.get_typenum());
287+
int out_u_type_id = array_types.typenum_to_lookup_id(out_u.get_typenum());
288+
int out_s_type_id = array_types.typenum_to_lookup_id(out_s.get_typenum());
289+
int out_vt_type_id = array_types.typenum_to_lookup_id(out_vt.get_typenum());
290+
291+
if (a_array_type_id != out_u_type_id || a_array_type_id != out_vt_type_id) {
292+
throw py::type_error(
293+
"Input array, output left singular vectors array, "
294+
"and outpuy right singular vectors array must have "
295+
"the same data type");
296+
}
297+
298+
gesvd_impl_fn_ptr_t gesvd_fn =
299+
gesvd_dispatch_table[a_array_type_id][out_s_type_id];
300+
if (gesvd_fn == nullptr) {
301+
throw py::value_error(
302+
"No gesvd implementation is defined for the given pair "
303+
"of array type and output singular values type.");
304+
}
305+
306+
char *a_array_data = a_array.get_data();
307+
char *out_s_data = out_s.get_data();
308+
char *out_u_data = out_u.get_data();
309+
char *out_vt_data = out_vt.get_data();
310+
311+
const py::ssize_t *a_array_shape = a_array.get_shape_raw();
312+
const std::int64_t m = a_array_shape[0];
313+
const std::int64_t n = a_array_shape[1];
314+
315+
const std::int64_t lda = std::max<size_t>(1UL, m);
316+
const std::int64_t ldu = std::max<size_t>(1UL, m);
317+
const std::int64_t ldvt =
318+
std::max<std::size_t>(1UL, jobvt_val == 'S' ? (m > n ? n : m) : n);
319+
320+
const oneapi::mkl::jobsvd jobu = process_job(jobu_val);
321+
const oneapi::mkl::jobsvd jobvt = process_job(jobvt_val);
322+
323+
std::vector<sycl::event> host_task_events;
324+
sycl::event gesvd_ev =
325+
gesvd_fn(exec_q, jobu, jobvt, m, n, a_array_data, lda, out_s_data,
326+
out_u_data, ldu, out_vt_data, ldvt, host_task_events, depends);
327+
328+
sycl::event args_ev = dpctl::utils::keep_args_alive(
329+
exec_q, {a_array, out_s, out_u, out_vt}, host_task_events);
330+
331+
return std::make_pair(args_ev, gesvd_ev);
332+
}
333+
334+
template <typename fnT, typename T, typename RealT>
335+
struct GesvdContigFactory
336+
{
337+
fnT get()
338+
{
339+
if constexpr (types::GesvdTypePairSupportFactory<T, RealT>::is_defined)
340+
{
341+
return gesvd_impl<T, RealT>;
342+
}
343+
else {
344+
return nullptr;
345+
}
346+
}
347+
};
348+
349+
void init_gesvd_dispatch_table(void)
350+
{
351+
dpctl_td_ns::DispatchTableBuilder<gesvd_impl_fn_ptr_t, GesvdContigFactory,
352+
dpctl_td_ns::num_types>
353+
contig;
354+
contig.populate_dispatch_table(gesvd_dispatch_table);
355+
}
356+
} // namespace lapack
357+
} // namespace ext
358+
} // namespace backend
359+
} // namespace dpnp

0 commit comments

Comments
 (0)