Skip to content

Commit 8ac7f88

Browse files
Update lapack syevd/heevd implementations (#1891)
* Use dispatch_table for syevd * Implement common logic for lapack syevd/heevd * Remove junk code * Add source files for syevd and heevd --------- Co-authored-by: Anton <[email protected]>
1 parent 8ef7568 commit 8ac7f88

File tree

7 files changed

+266
-305
lines changed

7 files changed

+266
-305
lines changed
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <oneapi/mkl.hpp>
29+
#include <pybind11/pybind11.h>
30+
31+
// dpctl tensor headers
32+
#include "utils/memory_overlap.hpp"
33+
#include "utils/output_validation.hpp"
34+
#include "utils/type_dispatch.hpp"
35+
#include "utils/type_utils.hpp"
36+
37+
#include "types_matrix.hpp"
38+
39+
namespace dpnp
40+
{
41+
namespace backend
42+
{
43+
namespace ext
44+
{
45+
namespace lapack
46+
{
47+
namespace evd
48+
{
49+
// Pointer to a type-specific eigenvalue-decomposition implementation, as stored
// in the dispatch tables consumed by `evd_func`. `evd_func` passes, in order:
// the execution queue, the `job` value, the `uplo` value, the matrix order `n`,
// the raw data pointer of the eigenvectors array, the raw data pointer of the
// eigenvalues array, a vector that receives host-task events, and the events
// the computation depends on. The returned event is handed back to the caller
// of `evd_func` as the computation event.
using evd_impl_fn_ptr_t = sycl::event (*)(sycl::queue &,
                                          const oneapi::mkl::job,
                                          const oneapi::mkl::uplo,
                                          const std::int64_t,
                                          char *,
                                          char *,
                                          std::vector<sycl::event> &,
                                          const std::vector<sycl::event> &);
57+
58+
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
59+
namespace py = pybind11;
60+
61+
template <typename dispatchT>
62+
std::pair<sycl::event, sycl::event>
63+
evd_func(sycl::queue &exec_q,
64+
const std::int8_t jobz,
65+
const std::int8_t upper_lower,
66+
dpctl::tensor::usm_ndarray &eig_vecs,
67+
dpctl::tensor::usm_ndarray &eig_vals,
68+
const std::vector<sycl::event> &depends,
69+
const dispatchT &evd_dispatch_table)
70+
{
71+
const int eig_vecs_nd = eig_vecs.get_ndim();
72+
const int eig_vals_nd = eig_vals.get_ndim();
73+
74+
if (eig_vecs_nd != 2) {
75+
throw py::value_error("Unexpected ndim=" + std::to_string(eig_vecs_nd) +
76+
" of an output array with eigenvectors");
77+
}
78+
else if (eig_vals_nd != 1) {
79+
throw py::value_error("Unexpected ndim=" + std::to_string(eig_vals_nd) +
80+
" of an output array with eigenvalues");
81+
}
82+
83+
const py::ssize_t *eig_vecs_shape = eig_vecs.get_shape_raw();
84+
const py::ssize_t *eig_vals_shape = eig_vals.get_shape_raw();
85+
86+
if (eig_vecs_shape[0] != eig_vecs_shape[1]) {
87+
throw py::value_error("Output array with eigenvectors with be square");
88+
}
89+
else if (eig_vecs_shape[0] != eig_vals_shape[0]) {
90+
throw py::value_error(
91+
"Eigenvectors and eigenvalues have different shapes");
92+
}
93+
94+
size_t src_nelems(1);
95+
96+
for (int i = 0; i < eig_vecs_nd; ++i) {
97+
src_nelems *= static_cast<size_t>(eig_vecs_shape[i]);
98+
}
99+
100+
if (src_nelems == 0) {
101+
// nothing to do
102+
return std::make_pair(sycl::event(), sycl::event());
103+
}
104+
105+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(eig_vecs);
106+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(eig_vals);
107+
108+
// check compatibility of execution queue and allocation queue
109+
if (!dpctl::utils::queues_are_compatible(exec_q, {eig_vecs, eig_vals})) {
110+
throw py::value_error(
111+
"Execution queue is not compatible with allocation queues");
112+
}
113+
114+
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
115+
if (overlap(eig_vecs, eig_vals)) {
116+
throw py::value_error("Arrays with eigenvectors and eigenvalues are "
117+
"overlapping segments of memory");
118+
}
119+
120+
bool is_eig_vecs_f_contig = eig_vecs.is_f_contiguous();
121+
bool is_eig_vals_c_contig = eig_vals.is_c_contiguous();
122+
if (!is_eig_vecs_f_contig) {
123+
throw py::value_error(
124+
"An array with input matrix / output eigenvectors "
125+
"must be F-contiguous");
126+
}
127+
else if (!is_eig_vals_c_contig) {
128+
throw py::value_error(
129+
"An array with output eigenvalues must be C-contiguous");
130+
}
131+
132+
auto array_types = dpctl_td_ns::usm_ndarray_types();
133+
int eig_vecs_type_id =
134+
array_types.typenum_to_lookup_id(eig_vecs.get_typenum());
135+
int eig_vals_type_id =
136+
array_types.typenum_to_lookup_id(eig_vals.get_typenum());
137+
138+
evd_impl_fn_ptr_t evd_fn =
139+
evd_dispatch_table[eig_vecs_type_id][eig_vals_type_id];
140+
if (evd_fn == nullptr) {
141+
throw py::value_error(
142+
"Types of input vectors and result array are mismatched.");
143+
}
144+
145+
char *eig_vecs_data = eig_vecs.get_data();
146+
char *eig_vals_data = eig_vals.get_data();
147+
148+
const std::int64_t n = eig_vecs_shape[0];
149+
const oneapi::mkl::job jobz_val = static_cast<oneapi::mkl::job>(jobz);
150+
const oneapi::mkl::uplo uplo_val =
151+
static_cast<oneapi::mkl::uplo>(upper_lower);
152+
153+
std::vector<sycl::event> host_task_events;
154+
sycl::event evd_ev = evd_fn(exec_q, jobz_val, uplo_val, n, eig_vecs_data,
155+
eig_vals_data, host_task_events, depends);
156+
157+
sycl::event args_ev = dpctl::utils::keep_args_alive(
158+
exec_q, {eig_vecs, eig_vals}, host_task_events);
159+
160+
return std::make_pair(args_ev, evd_ev);
161+
}
162+
163+
template <typename dispatchT,
164+
template <typename fnT, typename T, typename RealT>
165+
typename factoryT>
166+
void init_evd_dispatch_table(
167+
dispatchT evd_dispatch_table[][dpctl_td_ns::num_types])
168+
{
169+
dpctl_td_ns::DispatchTableBuilder<dispatchT, factoryT,
170+
dpctl_td_ns::num_types>
171+
contig;
172+
contig.populate_dispatch_table(evd_dispatch_table);
173+
}
174+
} // namespace evd
175+
} // namespace lapack
176+
} // namespace ext
177+
} // namespace backend
178+
} // namespace dpnp

dpnp/backend/extensions/lapack/heevd.cpp

Lines changed: 31 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
//*****************************************************************************
2-
// Copyright (c) 2023-2024, Intel Corporation
2+
// Copyright (c) 2024, Intel Corporation
33
// All rights reserved.
44
//
55
// Redistribution and use in source and binary forms, with or without
@@ -23,16 +23,7 @@
2323
// THE POSSIBILITY OF SUCH DAMAGE.
2424
//*****************************************************************************
2525

26-
#include <pybind11/pybind11.h>
27-
28-
// dpctl tensor headers
29-
#include "utils/memory_overlap.hpp"
30-
#include "utils/type_utils.hpp"
31-
3226
#include "heevd.hpp"
33-
#include "types_matrix.hpp"
34-
35-
#include "dpnp_utils.hpp"
3627

3728
namespace dpnp
3829
{
@@ -43,23 +34,10 @@ namespace ext
4334
namespace lapack
4435
{
4536
namespace mkl_lapack = oneapi::mkl::lapack;
46-
namespace py = pybind11;
4737
namespace type_utils = dpctl::tensor::type_utils;
4838

49-
typedef sycl::event (*heevd_impl_fn_ptr_t)(sycl::queue,
50-
const oneapi::mkl::job,
51-
const oneapi::mkl::uplo,
52-
const std::int64_t,
53-
char *,
54-
char *,
55-
std::vector<sycl::event> &,
56-
const std::vector<sycl::event> &);
57-
58-
static heevd_impl_fn_ptr_t heevd_dispatch_table[dpctl_td_ns::num_types]
59-
[dpctl_td_ns::num_types];
60-
6139
template <typename T, typename RealT>
62-
static sycl::event heevd_impl(sycl::queue exec_q,
40+
static sycl::event heevd_impl(sycl::queue &exec_q,
6341
const oneapi::mkl::job jobz,
6442
const oneapi::mkl::uplo upper_lower,
6543
const std::int64_t n,
@@ -128,104 +106,8 @@ static sycl::event heevd_impl(sycl::queue exec_q,
128106
cgh.host_task([ctx, scratchpad]() { sycl::free(scratchpad, ctx); });
129107
});
130108
host_task_events.push_back(clean_up_event);
131-
return heevd_event;
132-
}
133-
134-
std::pair<sycl::event, sycl::event>
135-
heevd(sycl::queue exec_q,
136-
const std::int8_t jobz,
137-
const std::int8_t upper_lower,
138-
dpctl::tensor::usm_ndarray eig_vecs,
139-
dpctl::tensor::usm_ndarray eig_vals,
140-
const std::vector<sycl::event> &depends)
141-
{
142-
const int eig_vecs_nd = eig_vecs.get_ndim();
143-
const int eig_vals_nd = eig_vals.get_ndim();
144-
145-
if (eig_vecs_nd != 2) {
146-
throw py::value_error("Unexpected ndim=" + std::to_string(eig_vecs_nd) +
147-
" of an output array with eigenvectors");
148-
}
149-
else if (eig_vals_nd != 1) {
150-
throw py::value_error("Unexpected ndim=" + std::to_string(eig_vals_nd) +
151-
" of an output array with eigenvalues");
152-
}
153-
154-
const py::ssize_t *eig_vecs_shape = eig_vecs.get_shape_raw();
155-
const py::ssize_t *eig_vals_shape = eig_vals.get_shape_raw();
156-
157-
if (eig_vecs_shape[0] != eig_vecs_shape[1]) {
158-
throw py::value_error("Output array with eigenvectors with be square");
159-
}
160-
else if (eig_vecs_shape[0] != eig_vals_shape[0]) {
161-
throw py::value_error(
162-
"Eigenvectors and eigenvalues have different shapes");
163-
}
164-
165-
size_t src_nelems(1);
166-
167-
for (int i = 0; i < eig_vecs_nd; ++i) {
168-
src_nelems *= static_cast<size_t>(eig_vecs_shape[i]);
169-
}
170-
171-
if (src_nelems == 0) {
172-
// nothing to do
173-
return std::make_pair(sycl::event(), sycl::event());
174-
}
175-
176-
// check compatibility of execution queue and allocation queue
177-
if (!dpctl::utils::queues_are_compatible(exec_q, {eig_vecs, eig_vals})) {
178-
throw py::value_error(
179-
"Execution queue is not compatible with allocation queues");
180-
}
181-
182-
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
183-
if (overlap(eig_vecs, eig_vals)) {
184-
throw py::value_error("Arrays with eigenvectors and eigenvalues are "
185-
"overlapping segments of memory");
186-
}
187-
188-
bool is_eig_vecs_f_contig = eig_vecs.is_f_contiguous();
189-
bool is_eig_vals_c_contig = eig_vals.is_c_contiguous();
190-
if (!is_eig_vecs_f_contig) {
191-
throw py::value_error(
192-
"An array with input matrix / output eigenvectors "
193-
"must be F-contiguous");
194-
}
195-
else if (!is_eig_vals_c_contig) {
196-
throw py::value_error(
197-
"An array with output eigenvalues must be C-contiguous");
198-
}
199-
200-
auto array_types = dpctl_td_ns::usm_ndarray_types();
201-
int eig_vecs_type_id =
202-
array_types.typenum_to_lookup_id(eig_vecs.get_typenum());
203-
int eig_vals_type_id =
204-
array_types.typenum_to_lookup_id(eig_vals.get_typenum());
205-
206-
heevd_impl_fn_ptr_t heevd_fn =
207-
heevd_dispatch_table[eig_vecs_type_id][eig_vals_type_id];
208-
if (heevd_fn == nullptr) {
209-
throw py::value_error("No heevd implementation defined for a pair of "
210-
"type for eigenvectors and eigenvalues");
211-
}
212-
213-
char *eig_vecs_data = eig_vecs.get_data();
214-
char *eig_vals_data = eig_vals.get_data();
215-
216-
const std::int64_t n = eig_vecs_shape[0];
217-
const oneapi::mkl::job jobz_val = static_cast<oneapi::mkl::job>(jobz);
218-
const oneapi::mkl::uplo uplo_val =
219-
static_cast<oneapi::mkl::uplo>(upper_lower);
220109

221-
std::vector<sycl::event> host_task_events;
222-
sycl::event heevd_ev =
223-
heevd_fn(exec_q, jobz_val, uplo_val, n, eig_vecs_data, eig_vals_data,
224-
host_task_events, depends);
225-
226-
sycl::event args_ev = dpctl::utils::keep_args_alive(
227-
exec_q, {eig_vecs, eig_vals}, host_task_events);
228-
return std::make_pair(args_ev, heevd_ev);
110+
return heevd_event;
229111
}
230112

231113
template <typename fnT, typename T, typename RealT>
@@ -243,12 +125,35 @@ struct HeevdContigFactory
243125
}
244126
};
245127

246-
void init_heevd_dispatch_table(void)
128+
using evd::evd_impl_fn_ptr_t;
129+
130+
void init_heevd(py::module_ m)
247131
{
248-
dpctl_td_ns::DispatchTableBuilder<heevd_impl_fn_ptr_t, HeevdContigFactory,
249-
dpctl_td_ns::num_types>
250-
contig;
251-
contig.populate_dispatch_table(heevd_dispatch_table);
132+
using arrayT = dpctl::tensor::usm_ndarray;
133+
using event_vecT = std::vector<sycl::event>;
134+
135+
static evd_impl_fn_ptr_t heevd_dispatch_table[dpctl_td_ns::num_types]
136+
[dpctl_td_ns::num_types];
137+
138+
{
139+
evd::init_evd_dispatch_table<evd_impl_fn_ptr_t, HeevdContigFactory>(
140+
heevd_dispatch_table);
141+
142+
auto heevd_pyapi = [&](sycl::queue &exec_q, const std::int8_t jobz,
143+
const std::int8_t upper_lower, arrayT &eig_vecs,
144+
arrayT &eig_vals,
145+
const event_vecT &depends = {}) {
146+
return evd::evd_func(exec_q, jobz, upper_lower, eig_vecs, eig_vals,
147+
depends, heevd_dispatch_table);
148+
};
149+
150+
m.def("_heevd", heevd_pyapi,
151+
"Call `heevd` from OneMKL LAPACK library to return "
152+
"the eigenvalues and eigenvectors of a complex Hermitian matrix",
153+
py::arg("sycl_queue"), py::arg("jobz"), py::arg("upper_lower"),
154+
py::arg("eig_vecs"), py::arg("eig_vals"),
155+
py::arg("depends") = py::list());
156+
}
252157
}
253158
} // namespace lapack
254159
} // namespace ext

0 commit comments

Comments
 (0)