Skip to content

Commit e96454e

Browse files
Implement heevd_batch via heevd call
1 parent 7bca6ba commit e96454e

File tree

4 files changed

+142
-79
lines changed

4 files changed

+142
-79
lines changed

dpnp/backend/extensions/lapack/heevd.cpp

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,120 @@ std::pair<sycl::event, sycl::event>
228228
return std::make_pair(args_ev, heevd_ev);
229229
}
230230

231+
std::pair<sycl::event, sycl::event>
232+
heevd_batch(sycl::queue exec_q,
233+
const std::int8_t jobz,
234+
const std::int8_t upper_lower,
235+
dpctl::tensor::usm_ndarray eig_vecs,
236+
dpctl::tensor::usm_ndarray eig_vals,
237+
const std::vector<sycl::event> &depends)
238+
{
239+
const int eig_vecs_nd = eig_vecs.get_ndim();
240+
const int eig_vals_nd = eig_vals.get_ndim();
241+
242+
if (eig_vecs_nd != 3) {
243+
throw py::value_error("Unexpected ndim=" + std::to_string(eig_vecs_nd) +
244+
" of an output array with eigenvectors");
245+
}
246+
else if (eig_vals_nd != 2) {
247+
throw py::value_error("Unexpected ndim=" + std::to_string(eig_vals_nd) +
248+
" of an output array with eigenvalues");
249+
}
250+
251+
const py::ssize_t *eig_vecs_shape = eig_vecs.get_shape_raw();
252+
const py::ssize_t *eig_vals_shape = eig_vals.get_shape_raw();
253+
254+
if (eig_vecs_shape[1] != eig_vecs_shape[2]) {
255+
throw py::value_error(
256+
"The last two dimensions of 'eig_vecs' must be the same.");
257+
}
258+
else if (eig_vecs_shape[0] != eig_vals_shape[0] ||
259+
eig_vecs_shape[1] != eig_vals_shape[1])
260+
{
261+
throw py::value_error(
262+
"The shape of 'eig_vals' must be (batch_size, n), "
263+
"where batch_size = " +
264+
std::to_string(eig_vecs_shape[0]) +
265+
" and n = " + std::to_string(eig_vecs_shape[1]));
266+
}
267+
268+
size_t src_nelems(1);
269+
270+
for (int i = 0; i < eig_vecs_nd; ++i) {
271+
src_nelems *= static_cast<size_t>(eig_vecs_shape[i]);
272+
}
273+
274+
if (src_nelems == 0) {
275+
// nothing to do
276+
return std::make_pair(sycl::event(), sycl::event());
277+
}
278+
279+
// check compatibility of execution queue and allocation queue
280+
if (!dpctl::utils::queues_are_compatible(exec_q, {eig_vecs, eig_vals})) {
281+
throw py::value_error(
282+
"Execution queue is not compatible with allocation queues");
283+
}
284+
285+
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
286+
if (overlap(eig_vecs, eig_vals)) {
287+
throw py::value_error("Arrays with eigenvectors and eigenvalues are "
288+
"overlapping segments of memory");
289+
}
290+
291+
bool is_eig_vecs_c_contig = eig_vecs.is_c_contiguous();
292+
bool is_eig_vals_c_contig = eig_vals.is_c_contiguous();
293+
if (!is_eig_vecs_c_contig) {
294+
throw py::value_error(
295+
"An array with input matrix / output eigenvectors "
296+
"must be C-contiguous");
297+
}
298+
else if (!is_eig_vals_c_contig) {
299+
throw py::value_error(
300+
"An array with output eigenvalues must be C-contiguous");
301+
}
302+
303+
auto array_types = dpctl_td_ns::usm_ndarray_types();
304+
int eig_vecs_type_id =
305+
array_types.typenum_to_lookup_id(eig_vecs.get_typenum());
306+
int eig_vals_type_id =
307+
array_types.typenum_to_lookup_id(eig_vals.get_typenum());
308+
309+
heevd_impl_fn_ptr_t heevd_fn =
310+
heevd_dispatch_table[eig_vecs_type_id][eig_vals_type_id];
311+
if (heevd_fn == nullptr) {
312+
throw py::value_error("No heevd implementation defined for a pair of "
313+
"type for eigenvectors and eigenvalues");
314+
}
315+
316+
char *eig_vecs_data = eig_vecs.get_data();
317+
char *eig_vals_data = eig_vals.get_data();
318+
319+
const std::int64_t batch_size = eig_vecs_shape[0];
320+
const std::int64_t n = eig_vecs_shape[1];
321+
int vecs_elemsize = eig_vecs.get_elemsize();
322+
int vals_elemsize = eig_vals.get_elemsize();
323+
324+
const oneapi::mkl::job jobz_val = static_cast<oneapi::mkl::job>(jobz);
325+
const oneapi::mkl::uplo uplo_val =
326+
static_cast<oneapi::mkl::uplo>(upper_lower);
327+
328+
std::vector<sycl::event> host_task_events;
329+
330+
for (std::int64_t i = 0; i < batch_size; ++i) {
331+
char *eig_vecs_batch = eig_vecs_data + i * n * n * vecs_elemsize;
332+
char *eig_vals_batch = eig_vals_data + i * n * vals_elemsize;
333+
334+
sycl::event heevd_ev =
335+
heevd_fn(exec_q, jobz_val, uplo_val, n, eig_vecs_batch,
336+
eig_vals_batch, host_task_events, depends);
337+
}
338+
339+
sycl::event args_ev = dpctl::utils::keep_args_alive(
340+
exec_q, {eig_vecs, eig_vals}, host_task_events);
341+
342+
return std::make_pair(args_ev, args_ev);
343+
}
344+
231345
template <typename fnT, typename T, typename RealT>
232346
struct HeevdContigFactory
233347
{

dpnp/backend/extensions/lapack/heevd.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,14 @@ extern std::pair<sycl::event, sycl::event>
4646
dpctl::tensor::usm_ndarray eig_vals,
4747
const std::vector<sycl::event> &depends = {});
4848

49+
// Batched variant of heevd (defined in heevd.cpp): calls the heevd
// implementation in a loop over a batch of matrices stored in `eig_vecs`
// (3D, shape (batch_size, n, n)) and writes the eigenvalues to `eig_vals`
// (2D, shape (batch_size, n)). `depends` defaults to an empty list of events.
extern std::pair<sycl::event, sycl::event>
50+
heevd_batch(sycl::queue exec_q,
51+
const std::int8_t jobz,
52+
const std::int8_t upper_lower,
53+
dpctl::tensor::usm_ndarray eig_vecs,
54+
dpctl::tensor::usm_ndarray eig_vals,
55+
const std::vector<sycl::event> &depends = {});
56+
4957
extern void init_heevd_dispatch_table(void);
5058
} // namespace lapack
5159
} // namespace ext

dpnp/backend/extensions/lapack/lapack_py.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,14 @@ PYBIND11_MODULE(_lapack_impl, m)
146146
py::arg("eig_vecs"), py::arg("eig_vals"),
147147
py::arg("depends") = py::list());
148148

149+
m.def("_heevd_batch", &lapack_ext::heevd_batch,
150+
"Call `heevd` from OneMKL LAPACK library in a loop to return "
151+
"the eigenvalues and eigenvectors of a batch of complex Hermitian "
152+
"matrices",
153+
py::arg("sycl_queue"), py::arg("jobz"), py::arg("upper_lower"),
154+
py::arg("eig_vecs"), py::arg("eig_vals"),
155+
py::arg("depends") = py::list());
156+
149157
m.def("_orgqr_batch", &lapack_ext::orgqr_batch,
150158
"Call `_orgqr_batch` from OneMKL LAPACK library to return "
151159
"the real orthogonal matrix Qi of the QR factorization "

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 12 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -103,93 +103,26 @@ def _batched_eigh(a, UPLO, eigen_mode, w_type, v_type):
103103
jobz = _jobz[eigen_mode]
104104
uplo = _upper_lower[UPLO]
105105

106-
# Get LAPACK function (_syevd for real or _heevd for complex data types)
106+
# Get LAPACK function (_syevd_batch for real or _heevd_batch
107+
# for complex data types)
107108
# to compute all eigenvalues and, optionally, all eigenvectors
108109
lapack_func = (
109-
"_heevd" if dpnp.issubdtype(v_type, dpnp.complexfloating) else "_syevd"
110+
"_heevd_batch"
111+
if dpnp.issubdtype(v_type, dpnp.complexfloating)
112+
else "_syevd_batch"
110113
)
111114

112115
a_sycl_queue = a.sycl_queue
113-
114-
new = True
115-
116-
if not new or lapack_func == "_heevd":
117-
is_cpu_device = a.sycl_device.has_aspect_cpu
118-
orig_shape = a.shape
119-
# get 3d input array by reshape
120-
a = a.reshape(-1, orig_shape[-2], orig_shape[-1])
121-
a_usm_arr = dpnp.get_usm_ndarray(a)
122-
123-
# allocate a memory for dpnp array of eigenvalues
124-
w = dpnp.empty_like(
125-
a,
126-
shape=orig_shape[:-1],
127-
dtype=w_type,
128-
)
129-
w_orig_shape = w.shape
130-
# get 2d dpnp array with eigenvalues by reshape
131-
w = w.reshape(-1, w_orig_shape[-1])
132-
133-
a_order = "C" if a.flags.c_contiguous else "F"
134-
135-
# need to loop over the 1st dimension to get eigenvalues and
136-
# eigenvectors of 3d matrix A
137-
batch_size = a.shape[0]
138-
eig_vecs = [None] * batch_size
139-
ht_list_ev = [None] * batch_size * 2
140-
for i in range(batch_size):
141-
# oneMKL LAPACK assumes fortran-like array as input, so
142-
# allocate a memory with 'F' order for dpnp array of eigenvectors
143-
eig_vecs[i] = dpnp.empty_like(a[i], order="F", dtype=v_type)
144-
145-
# use DPCTL tensor function to fill the array of eigenvectors with
146-
# content of input array
147-
ht_list_ev[2 * i], copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
148-
src=a_usm_arr[i],
149-
dst=eig_vecs[i].get_array(),
150-
sycl_queue=a_sycl_queue,
151-
)
152-
153-
# TODO: Remove this w/a when MKLD-17201 is solved.
154-
# Waiting for a host task executing an OneMKL LAPACK syevd/heevd
155-
# call on CPU causes deadlock due to serialization of all host tasks
156-
# in the queue.
157-
# We need to wait for each host tasks before calling _seyvd and
158-
# _heevd to avoid deadlock.
159-
if is_cpu_device:
160-
ht_list_ev[2 * i].wait()
161-
162-
# call LAPACK extension function to get eigenvalues and
163-
# eigenvectors of a portion of matrix A
164-
ht_list_ev[2 * i + 1], _ = getattr(li, lapack_func)(
165-
a_sycl_queue,
166-
jobz,
167-
uplo,
168-
eig_vecs[i].get_array(),
169-
w[i].get_array(),
170-
depends=[copy_ev],
171-
)
172-
173-
dpctl.SyclEvent.wait_for(ht_list_ev)
174-
175-
w = w.reshape(w_orig_shape)
176-
177-
if eigen_mode == "V":
178-
# combine the list of eigenvectors into a single array
179-
v = dpnp.array(eig_vecs, order=a_order).reshape(orig_shape)
180-
return w, v
181-
return w
182-
183116
a_orig_shape = a.shape
184117
# get 3d input array by reshape
185118
a = a.reshape(-1, a_orig_shape[-2], a_orig_shape[-1])
186119

187-
# oneMKL LAPACK syevd overwrites `a` and
188-
# assumes fortran-like array as input.
189-
# To use C-contiguous arrays, we transpose the last two dimensions
190-
# before passing to syevd.
191-
# This transposition is effective because each batch
192-
# in the input array `a` is square.
120+
# oneMKL LAPACK syevd/heevd overwrites `a` and assumes fortran-like array
121+
# as input.
122+
# To use C-contiguous arrays, we transpose the last two dimensions before
123+
# passing to syevd/heevd.
124+
# This transposition is effective because each batch in the input array `a`
125+
# is square.
193126
a = a.transpose((0, 2, 1))
194127
a_usm_arr = dpnp.get_usm_ndarray(a)
195128

@@ -212,7 +145,7 @@ def _batched_eigh(a, UPLO, eigen_mode, w_type, v_type):
212145
# get 2d dpnp array with eigenvalues by reshape
213146
w = w.reshape(-1, w_orig_shape[-1])
214147

215-
ht_ev, _ = li._syevd_batch(
148+
ht_ev, _ = getattr(li, lapack_func)(
216149
a_sycl_queue,
217150
jobz,
218151
uplo,

0 commit comments

Comments
 (0)