Skip to content

Commit 19d0540

Browse files
Implement syevd_batch via syevd call
1 parent 807dc14 commit 19d0540

File tree

6 files changed

+426
-53
lines changed

6 files changed

+426
-53
lines changed

dpnp/backend/extensions/lapack/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ set(_module_src
4141
${CMAKE_CURRENT_SOURCE_DIR}/potrf.cpp
4242
${CMAKE_CURRENT_SOURCE_DIR}/potrf_batch.cpp
4343
${CMAKE_CURRENT_SOURCE_DIR}/syevd.cpp
44+
${CMAKE_CURRENT_SOURCE_DIR}/syevd_batch.cpp
4445
${CMAKE_CURRENT_SOURCE_DIR}/ungqr.cpp
4546
${CMAKE_CURRENT_SOURCE_DIR}/ungqr_batch.cpp
4647
)

dpnp/backend/extensions/lapack/lapack_py.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ void init_dispatch_vectors(void)
6060
lapack_ext::init_orgqr_dispatch_vector();
6161
lapack_ext::init_potrf_batch_dispatch_vector();
6262
lapack_ext::init_potrf_dispatch_vector();
63+
lapack_ext::init_syevd_batch_dispatch_vector();
6364
lapack_ext::init_syevd_dispatch_vector();
6465
lapack_ext::init_ungqr_batch_dispatch_vector();
6566
lapack_ext::init_ungqr_dispatch_vector();
@@ -183,6 +184,14 @@ PYBIND11_MODULE(_lapack_impl, m)
183184
py::arg("eig_vecs"), py::arg("eig_vals"),
184185
py::arg("depends") = py::list());
185186

187+
m.def("_syevd_batch", &lapack_ext::syevd_batch,
188+
"Call `syevd` from OneMKL LAPACK library in a loop to return "
189+
"the eigenvalues and eigenvectors of a batch of real symmetric "
190+
"matrices",
191+
py::arg("sycl_queue"), py::arg("jobz"), py::arg("upper_lower"),
192+
py::arg("eig_vecs"), py::arg("eig_vals"),
193+
py::arg("depends") = py::list());
194+
186195
m.def("_ungqr_batch", &lapack_ext::ungqr_batch,
187196
"Call `_ungqr_batch` from OneMKL LAPACK library to return "
188197
"the complex unitary matrices matrix Qi of the QR factorization "

dpnp/backend/extensions/lapack/syevd.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,16 @@ extern std::pair<sycl::event, sycl::event>
4646
dpctl::tensor::usm_ndarray eig_vals,
4747
const std::vector<sycl::event> &depends = {});
4848

49+
extern std::pair<sycl::event, sycl::event>
50+
syevd_batch(sycl::queue exec_q,
51+
const std::int8_t jobz,
52+
const std::int8_t upper_lower,
53+
dpctl::tensor::usm_ndarray eig_vecs,
54+
dpctl::tensor::usm_ndarray eig_vals,
55+
const std::vector<sycl::event> &depends = {});
56+
4957
extern void init_syevd_dispatch_vector(void);
58+
extern void init_syevd_batch_dispatch_vector(void);
5059
} // namespace lapack
5160
} // namespace ext
5261
} // namespace backend
Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2023-2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#include <pybind11/pybind11.h>
27+
28+
// dpctl tensor headers
29+
#include "utils/memory_overlap.hpp"
30+
#include "utils/type_utils.hpp"
31+
// #include "copy_and_cast_usm_to_usm.hpp"
32+
33+
#include "syevd.hpp"
34+
#include "types_matrix.hpp"
35+
36+
#include "dpnp_utils.hpp"
37+
38+
namespace dpnp
39+
{
40+
namespace backend
41+
{
42+
namespace ext
43+
{
44+
namespace lapack
45+
{
46+
namespace mkl_lapack = oneapi::mkl::lapack;
47+
namespace py = pybind11;
48+
namespace type_utils = dpctl::tensor::type_utils;
49+
50+
typedef sycl::event (*syevd_batch_impl_fn_ptr_t)(
51+
sycl::queue,
52+
const oneapi::mkl::job,
53+
const oneapi::mkl::uplo,
54+
const std::int64_t,
55+
char *,
56+
char *,
57+
std::vector<sycl::event> &,
58+
const std::vector<sycl::event> &);
59+
60+
static syevd_batch_impl_fn_ptr_t
61+
syevd_batch_dispatch_vector[dpctl_td_ns::num_types];
62+
63+
template <typename T>
64+
static sycl::event syevd_batch_impl(sycl::queue exec_q,
65+
const oneapi::mkl::job jobz,
66+
const oneapi::mkl::uplo upper_lower,
67+
const std::int64_t n,
68+
char *in_a,
69+
char *out_w,
70+
std::vector<sycl::event> &host_task_events,
71+
const std::vector<sycl::event> &depends)
72+
{
73+
type_utils::validate_type_for_device<T>(exec_q);
74+
75+
T *a = reinterpret_cast<T *>(in_a);
76+
T *w = reinterpret_cast<T *>(out_w);
77+
78+
const std::int64_t lda = std::max<size_t>(1UL, n);
79+
const std::int64_t scratchpad_size =
80+
mkl_lapack::syevd_scratchpad_size<T>(exec_q, jobz, upper_lower, n, lda);
81+
T *scratchpad = nullptr;
82+
83+
std::stringstream error_msg;
84+
std::int64_t info = 0;
85+
86+
sycl::event syevd_event;
87+
try {
88+
scratchpad = sycl::malloc_device<T>(scratchpad_size, exec_q);
89+
90+
syevd_event = mkl_lapack::syevd(
91+
exec_q,
92+
jobz, // 'jobz == job::vec' means eigenvalues and eigenvectors are
93+
// computed.
94+
upper_lower, // 'upper_lower == job::upper' means the upper
95+
// triangular part of A, or the lower triangular
96+
// otherwise
97+
n, // The order of the matrix A (0 <= n)
98+
a, // Pointer to A, size (lda, *), where the 2nd dimension, must be
99+
// at least max(1, n) If 'jobz == job::vec', then on exit it will
100+
// contain the eigenvectors of A
101+
lda, // The leading dimension of a, must be at least max(1, n)
102+
w, // Pointer to array of size at least n, it will contain the
103+
// eigenvalues of A in ascending order
104+
scratchpad, // Pointer to scratchpad memory to be used by MKL
105+
// routine for storing intermediate results
106+
scratchpad_size, depends);
107+
} catch (mkl_lapack::exception const &e) {
108+
error_msg
109+
<< "Unexpected MKL exception caught during syevd() call:\nreason: "
110+
<< e.what() << "\ninfo: " << e.info();
111+
info = e.info();
112+
} catch (sycl::exception const &e) {
113+
error_msg << "Unexpected SYCL exception caught during syevd() call:\n"
114+
<< e.what();
115+
info = -1;
116+
}
117+
118+
if (info != 0) // an unexpected error occurs
119+
{
120+
if (scratchpad != nullptr) {
121+
sycl::free(scratchpad, exec_q);
122+
}
123+
throw std::runtime_error(error_msg.str());
124+
}
125+
126+
sycl::event clean_up_event = exec_q.submit([&](sycl::handler &cgh) {
127+
cgh.depends_on(syevd_event);
128+
auto ctx = exec_q.get_context();
129+
cgh.host_task([ctx, scratchpad]() { sycl::free(scratchpad, ctx); });
130+
});
131+
host_task_events.push_back(clean_up_event);
132+
return syevd_event;
133+
}
134+
135+
std::pair<sycl::event, sycl::event>
136+
syevd_batch(sycl::queue exec_q,
137+
const std::int8_t jobz,
138+
const std::int8_t upper_lower,
139+
dpctl::tensor::usm_ndarray eig_vecs,
140+
dpctl::tensor::usm_ndarray eig_vals,
141+
const std::vector<sycl::event> &depends)
142+
{
143+
const int eig_vecs_nd = eig_vecs.get_ndim();
144+
const int eig_vals_nd = eig_vals.get_ndim();
145+
146+
if (eig_vecs_nd != 3) {
147+
throw py::value_error("Unexpected ndim=" + std::to_string(eig_vecs_nd) +
148+
" of an output array with eigenvectors");
149+
}
150+
else if (eig_vals_nd != 2) {
151+
throw py::value_error("Unexpected ndim=" + std::to_string(eig_vals_nd) +
152+
" of an output array with eigenvalues");
153+
}
154+
155+
const py::ssize_t *eig_vecs_shape = eig_vecs.get_shape_raw();
156+
const py::ssize_t *eig_vals_shape = eig_vals.get_shape_raw();
157+
158+
if (eig_vecs_shape[1] != eig_vecs_shape[2]) {
159+
throw py::value_error(
160+
"The last two dimensions of 'eig_vecs' must be the same.");
161+
}
162+
else if (eig_vecs_shape[0] != eig_vals_shape[0] ||
163+
eig_vecs_shape[1] != eig_vals_shape[1])
164+
{
165+
throw py::value_error(
166+
"The shape of 'eig_vals' must be (batch_size, n), "
167+
"where batch_size = " +
168+
std::to_string(eig_vecs_shape[0]) +
169+
" and n = " + std::to_string(eig_vecs_shape[1]));
170+
}
171+
172+
size_t src_nelems(1);
173+
174+
for (int i = 0; i < eig_vecs_nd; ++i) {
175+
src_nelems *= static_cast<size_t>(eig_vecs_shape[i]);
176+
}
177+
178+
if (src_nelems == 0) {
179+
// nothing to do
180+
return std::make_pair(sycl::event(), sycl::event());
181+
}
182+
183+
// check compatibility of execution queue and allocation queue
184+
if (!dpctl::utils::queues_are_compatible(exec_q, {eig_vecs, eig_vals})) {
185+
throw py::value_error(
186+
"Execution queue is not compatible with allocation queues");
187+
}
188+
189+
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
190+
if (overlap(eig_vecs, eig_vals)) {
191+
throw py::value_error("Arrays with eigenvectors and eigenvalues are "
192+
"overlapping segments of memory");
193+
}
194+
195+
bool is_eig_vecs_c_contig = eig_vecs.is_c_contiguous();
196+
bool is_eig_vals_c_contig = eig_vals.is_c_contiguous();
197+
if (!is_eig_vecs_c_contig) {
198+
throw py::value_error(
199+
"An array with input matrix / output eigenvectors "
200+
"must be C-contiguous");
201+
}
202+
else if (!is_eig_vals_c_contig) {
203+
throw py::value_error(
204+
"An array with output eigenvalues must be C-contiguous");
205+
}
206+
207+
auto array_types = dpctl_td_ns::usm_ndarray_types();
208+
int eig_vecs_type_id =
209+
array_types.typenum_to_lookup_id(eig_vecs.get_typenum());
210+
int eig_vals_type_id =
211+
array_types.typenum_to_lookup_id(eig_vals.get_typenum());
212+
213+
if (eig_vecs_type_id != eig_vals_type_id) {
214+
throw py::value_error(
215+
"Types of eigenvectors and eigenvalues are mismatched");
216+
}
217+
218+
syevd_batch_impl_fn_ptr_t syevd_batch_fn =
219+
syevd_batch_dispatch_vector[eig_vecs_type_id];
220+
if (syevd_batch_fn == nullptr) {
221+
throw py::value_error("No syevd implementation defined for a type of "
222+
"eigenvectors and eigenvalues");
223+
}
224+
225+
char *eig_vecs_data = eig_vecs.get_data();
226+
char *eig_vals_data = eig_vals.get_data();
227+
228+
const std::int64_t batch_size = eig_vecs_shape[0];
229+
const std::int64_t n = eig_vecs_shape[1];
230+
int elemsize = eig_vecs.get_elemsize();
231+
232+
const oneapi::mkl::job jobz_val = static_cast<oneapi::mkl::job>(jobz);
233+
const oneapi::mkl::uplo uplo_val =
234+
static_cast<oneapi::mkl::uplo>(upper_lower);
235+
236+
std::vector<sycl::event> host_task_events;
237+
238+
for (std::int64_t i = 0; i < batch_size; ++i) {
239+
char *eig_vecs_batch = eig_vecs_data + i * n * n * elemsize;
240+
char *eig_vals_batch = eig_vals_data + i * n * elemsize;
241+
242+
sycl::event syevd_ev =
243+
syevd_batch_fn(exec_q, jobz_val, uplo_val, n, eig_vecs_batch,
244+
eig_vals_batch, host_task_events, depends);
245+
}
246+
247+
sycl::event args_ev = dpctl::utils::keep_args_alive(
248+
exec_q, {eig_vecs, eig_vals}, host_task_events);
249+
250+
return std::make_pair(args_ev, args_ev);
251+
}
252+
253+
template <typename fnT, typename T>
254+
struct SyevdBatchContigFactory
255+
{
256+
fnT get()
257+
{
258+
if constexpr (types::SyevdBatchTypePairSupportFactory<T>::is_defined) {
259+
return syevd_batch_impl<T>;
260+
}
261+
else {
262+
return nullptr;
263+
}
264+
}
265+
};
266+
267+
void init_syevd_batch_dispatch_vector(void)
268+
{
269+
dpctl_td_ns::DispatchVectorBuilder<syevd_batch_impl_fn_ptr_t,
270+
SyevdBatchContigFactory,
271+
dpctl_td_ns::num_types>
272+
contig;
273+
contig.populate_dispatch_vector(syevd_batch_dispatch_vector);
274+
}
275+
} // namespace lapack
276+
} // namespace ext
277+
} // namespace backend
278+
} // namespace dpnp

dpnp/backend/extensions/lapack/types_matrix.hpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,24 @@ struct SyevdTypePairSupportFactory
383383
dpctl_td_ns::NotDefinedEntry>::is_defined;
384384
};
385385

386+
/**
387+
* @brief A factory to define pairs of supported types for which
388+
* MKL LAPACK library provides support in oneapi::mkl::lapack::syevd<T>
389+
* function.
390+
*
391+
* @tparam T Type of array containing input matrix A and an output arrays with
392+
* eigenvectors and eigenvectors.
393+
*/
394+
template <typename T>
395+
struct SyevdBatchTypePairSupportFactory
396+
{
397+
static constexpr bool is_defined = std::disjunction<
398+
dpctl_td_ns::TypePairDefinedEntry<T, double, T, double>,
399+
dpctl_td_ns::TypePairDefinedEntry<T, float, T, float>,
400+
// fall-through
401+
dpctl_td_ns::NotDefinedEntry>::is_defined;
402+
};
403+
386404
/**
387405
* @brief A factory to define pairs of supported types for which
388406
* MKL LAPACK library provides support in oneapi::mkl::lapack::ungqr_batch<T>

0 commit comments

Comments
 (0)