Skip to content

Commit 7bca6ba

Browse files
Reduce implementation code
1 parent 19d0540 commit 7bca6ba

File tree

6 files changed

+117
-299
lines changed

6 files changed

+117
-299
lines changed

dpnp/backend/extensions/lapack/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ set(_module_src
4141
${CMAKE_CURRENT_SOURCE_DIR}/potrf.cpp
4242
${CMAKE_CURRENT_SOURCE_DIR}/potrf_batch.cpp
4343
${CMAKE_CURRENT_SOURCE_DIR}/syevd.cpp
44-
${CMAKE_CURRENT_SOURCE_DIR}/syevd_batch.cpp
4544
${CMAKE_CURRENT_SOURCE_DIR}/ungqr.cpp
4645
${CMAKE_CURRENT_SOURCE_DIR}/ungqr_batch.cpp
4746
)

dpnp/backend/extensions/lapack/lapack_py.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ void init_dispatch_vectors(void)
6060
lapack_ext::init_orgqr_dispatch_vector();
6161
lapack_ext::init_potrf_batch_dispatch_vector();
6262
lapack_ext::init_potrf_dispatch_vector();
63-
lapack_ext::init_syevd_batch_dispatch_vector();
6463
lapack_ext::init_syevd_dispatch_vector();
6564
lapack_ext::init_ungqr_batch_dispatch_vector();
6665
lapack_ext::init_ungqr_dispatch_vector();

dpnp/backend/extensions/lapack/syevd.cpp

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,123 @@ std::pair<sycl::event, sycl::event>
230230
return std::make_pair(args_ev, syevd_ev);
231231
}
232232

233+
std::pair<sycl::event, sycl::event>
234+
syevd_batch(sycl::queue exec_q,
235+
const std::int8_t jobz,
236+
const std::int8_t upper_lower,
237+
dpctl::tensor::usm_ndarray eig_vecs,
238+
dpctl::tensor::usm_ndarray eig_vals,
239+
const std::vector<sycl::event> &depends)
240+
{
241+
const int eig_vecs_nd = eig_vecs.get_ndim();
242+
const int eig_vals_nd = eig_vals.get_ndim();
243+
244+
if (eig_vecs_nd != 3) {
245+
throw py::value_error("Unexpected ndim=" + std::to_string(eig_vecs_nd) +
246+
" of an output array with eigenvectors");
247+
}
248+
else if (eig_vals_nd != 2) {
249+
throw py::value_error("Unexpected ndim=" + std::to_string(eig_vals_nd) +
250+
" of an output array with eigenvalues");
251+
}
252+
253+
const py::ssize_t *eig_vecs_shape = eig_vecs.get_shape_raw();
254+
const py::ssize_t *eig_vals_shape = eig_vals.get_shape_raw();
255+
256+
if (eig_vecs_shape[1] != eig_vecs_shape[2]) {
257+
throw py::value_error(
258+
"The last two dimensions of 'eig_vecs' must be the same.");
259+
}
260+
else if (eig_vecs_shape[0] != eig_vals_shape[0] ||
261+
eig_vecs_shape[1] != eig_vals_shape[1])
262+
{
263+
throw py::value_error(
264+
"The shape of 'eig_vals' must be (batch_size, n), "
265+
"where batch_size = " +
266+
std::to_string(eig_vecs_shape[0]) +
267+
" and n = " + std::to_string(eig_vecs_shape[1]));
268+
}
269+
270+
size_t src_nelems(1);
271+
272+
for (int i = 0; i < eig_vecs_nd; ++i) {
273+
src_nelems *= static_cast<size_t>(eig_vecs_shape[i]);
274+
}
275+
276+
if (src_nelems == 0) {
277+
// nothing to do
278+
return std::make_pair(sycl::event(), sycl::event());
279+
}
280+
281+
// check compatibility of execution queue and allocation queue
282+
if (!dpctl::utils::queues_are_compatible(exec_q, {eig_vecs, eig_vals})) {
283+
throw py::value_error(
284+
"Execution queue is not compatible with allocation queues");
285+
}
286+
287+
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
288+
if (overlap(eig_vecs, eig_vals)) {
289+
throw py::value_error("Arrays with eigenvectors and eigenvalues are "
290+
"overlapping segments of memory");
291+
}
292+
293+
bool is_eig_vecs_c_contig = eig_vecs.is_c_contiguous();
294+
bool is_eig_vals_c_contig = eig_vals.is_c_contiguous();
295+
if (!is_eig_vecs_c_contig) {
296+
throw py::value_error(
297+
"An array with input matrix / output eigenvectors "
298+
"must be C-contiguous");
299+
}
300+
else if (!is_eig_vals_c_contig) {
301+
throw py::value_error(
302+
"An array with output eigenvalues must be C-contiguous");
303+
}
304+
305+
auto array_types = dpctl_td_ns::usm_ndarray_types();
306+
int eig_vecs_type_id =
307+
array_types.typenum_to_lookup_id(eig_vecs.get_typenum());
308+
int eig_vals_type_id =
309+
array_types.typenum_to_lookup_id(eig_vals.get_typenum());
310+
311+
if (eig_vecs_type_id != eig_vals_type_id) {
312+
throw py::value_error(
313+
"Types of eigenvectors and eigenvalues are mismatched");
314+
}
315+
316+
syevd_impl_fn_ptr_t syevd_fn = syevd_dispatch_vector[eig_vecs_type_id];
317+
if (syevd_fn == nullptr) {
318+
throw py::value_error("No syevd implementation defined for a type of "
319+
"eigenvectors and eigenvalues");
320+
}
321+
322+
char *eig_vecs_data = eig_vecs.get_data();
323+
char *eig_vals_data = eig_vals.get_data();
324+
325+
const std::int64_t batch_size = eig_vecs_shape[0];
326+
const std::int64_t n = eig_vecs_shape[1];
327+
int elemsize = eig_vecs.get_elemsize();
328+
329+
const oneapi::mkl::job jobz_val = static_cast<oneapi::mkl::job>(jobz);
330+
const oneapi::mkl::uplo uplo_val =
331+
static_cast<oneapi::mkl::uplo>(upper_lower);
332+
333+
std::vector<sycl::event> host_task_events;
334+
335+
for (std::int64_t i = 0; i < batch_size; ++i) {
336+
char *eig_vecs_batch = eig_vecs_data + i * n * n * elemsize;
337+
char *eig_vals_batch = eig_vals_data + i * n * elemsize;
338+
339+
sycl::event syevd_ev =
340+
syevd_fn(exec_q, jobz_val, uplo_val, n, eig_vecs_batch,
341+
eig_vals_batch, host_task_events, depends);
342+
}
343+
344+
sycl::event args_ev = dpctl::utils::keep_args_alive(
345+
exec_q, {eig_vecs, eig_vals}, host_task_events);
346+
347+
return std::make_pair(args_ev, args_ev);
348+
}
349+
233350
template <typename fnT, typename T>
234351
struct SyevdContigFactory
235352
{

dpnp/backend/extensions/lapack/syevd.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@ extern std::pair<sycl::event, sycl::event>
5555
const std::vector<sycl::event> &depends = {});
5656

5757
extern void init_syevd_dispatch_vector(void);
58-
extern void init_syevd_batch_dispatch_vector(void);
5958
} // namespace lapack
6059
} // namespace ext
6160
} // namespace backend

0 commit comments

Comments
 (0)