Skip to content

Commit aa534f8

Browse files
Implement syevd_batch and heevd_batch (#1936)
* Implement syevd_batch and heevd_batch * Move include dpctl type_utils header to source files * Add memory allocation check for scratchpad * Add more checks for scratchpad_size * Move includes * Allocate memory for w with expected shape * Applied review comments * Add common_evd_checks to reduce duplicate code * Remove host_task_events from syevd and heevd * Applied review comments * Use init_evd_dispatch_table instead of init_evd_batch_dispatch_table * Move init_evd_dispatch_table to evd_common_utils.hpp * Add helper function check_zeros_shape * Implement alloc_scratchpad function in evd_batch_common.hpp * Make round_up_mult inline * Add comment for check_zeros_shape * Make alloc_scratchpad inline
1 parent e33a82b commit aa534f8

15 files changed

+851
-146
lines changed

dpnp/backend/extensions/lapack/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,13 @@ set(_module_src
3636
${CMAKE_CURRENT_SOURCE_DIR}/getri_batch.cpp
3737
${CMAKE_CURRENT_SOURCE_DIR}/getrs.cpp
3838
${CMAKE_CURRENT_SOURCE_DIR}/heevd.cpp
39+
${CMAKE_CURRENT_SOURCE_DIR}/heevd_batch.cpp
3940
${CMAKE_CURRENT_SOURCE_DIR}/orgqr.cpp
4041
${CMAKE_CURRENT_SOURCE_DIR}/orgqr_batch.cpp
4142
${CMAKE_CURRENT_SOURCE_DIR}/potrf.cpp
4243
${CMAKE_CURRENT_SOURCE_DIR}/potrf_batch.cpp
4344
${CMAKE_CURRENT_SOURCE_DIR}/syevd.cpp
45+
${CMAKE_CURRENT_SOURCE_DIR}/syevd_batch.cpp
4446
${CMAKE_CURRENT_SOURCE_DIR}/ungqr.cpp
4547
${CMAKE_CURRENT_SOURCE_DIR}/ungqr_batch.cpp
4648
)

dpnp/backend/extensions/lapack/common_helpers.hpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,15 @@
2424
//*****************************************************************************
2525

2626
#pragma once
27+
#include <complex>
2728
#include <cstring>
29+
#include <pybind11/pybind11.h>
2830
#include <stdexcept>
2931

3032
namespace dpnp::extensions::lapack::helper
3133
{
34+
namespace py = pybind11;
35+
3236
template <typename T>
3337
struct value_type_of
3438
{
@@ -40,4 +44,23 @@ struct value_type_of<std::complex<T>>
4044
{
4145
using type = T;
4246
};
47+
48+
// Rounds up the number `value` to the nearest multiple of `mult`.
49+
template <typename intT>
50+
inline intT round_up_mult(intT value, intT mult)
51+
{
52+
intT q = (value + (mult - 1)) / mult;
53+
return q * mult;
54+
}
55+
56+
// Checks if the shape array has any non-zero dimension.
57+
inline bool check_zeros_shape(int ndim, const py::ssize_t *shape)
58+
{
59+
size_t src_nelems(1);
60+
61+
for (int i = 0; i < ndim; ++i) {
62+
src_nelems *= static_cast<size_t>(shape[i]);
63+
}
64+
return src_nelems == 0;
65+
}
4366
} // namespace dpnp::extensions::lapack::helper
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <oneapi/mkl.hpp>
29+
#include <pybind11/pybind11.h>
30+
31+
// dpctl tensor headers
32+
#include "utils/type_dispatch.hpp"
33+
34+
#include "common_helpers.hpp"
35+
#include "evd_common_utils.hpp"
36+
#include "types_matrix.hpp"
37+
38+
namespace dpnp::extensions::lapack::evd
39+
{
40+
typedef sycl::event (*evd_batch_impl_fn_ptr_t)(
41+
sycl::queue &,
42+
const oneapi::mkl::job,
43+
const oneapi::mkl::uplo,
44+
const std::int64_t,
45+
const std::int64_t,
46+
char *,
47+
char *,
48+
const std::vector<sycl::event> &);
49+
50+
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
51+
namespace py = pybind11;
52+
53+
template <typename dispatchT>
54+
std::pair<sycl::event, sycl::event>
55+
evd_batch_func(sycl::queue &exec_q,
56+
const std::int8_t jobz,
57+
const std::int8_t upper_lower,
58+
dpctl::tensor::usm_ndarray &eig_vecs,
59+
dpctl::tensor::usm_ndarray &eig_vals,
60+
const std::vector<sycl::event> &depends,
61+
const dispatchT &evd_batch_dispatch_table)
62+
{
63+
const int eig_vecs_nd = eig_vecs.get_ndim();
64+
65+
const py::ssize_t *eig_vecs_shape = eig_vecs.get_shape_raw();
66+
const py::ssize_t *eig_vals_shape = eig_vals.get_shape_raw();
67+
68+
constexpr int expected_eig_vecs_nd = 3;
69+
constexpr int expected_eig_vals_nd = 2;
70+
71+
common_evd_checks(exec_q, eig_vecs, eig_vals, eig_vecs_shape,
72+
expected_eig_vecs_nd, expected_eig_vals_nd);
73+
74+
if (eig_vecs_shape[2] != eig_vals_shape[0] ||
75+
eig_vecs_shape[0] != eig_vals_shape[1])
76+
{
77+
throw py::value_error(
78+
"The shape of 'eig_vals' must be (batch_size, n), "
79+
"where batch_size = " +
80+
std::to_string(eig_vecs_shape[0]) +
81+
" and n = " + std::to_string(eig_vecs_shape[1]));
82+
}
83+
84+
// Ensure `batch_size` and `n` are non-zero, otherwise return empty events
85+
if (helper::check_zeros_shape(eig_vecs_nd, eig_vecs_shape)) {
86+
// nothing to do
87+
return std::make_pair(sycl::event(), sycl::event());
88+
}
89+
90+
auto array_types = dpctl_td_ns::usm_ndarray_types();
91+
const int eig_vecs_type_id =
92+
array_types.typenum_to_lookup_id(eig_vecs.get_typenum());
93+
const int eig_vals_type_id =
94+
array_types.typenum_to_lookup_id(eig_vals.get_typenum());
95+
96+
evd_batch_impl_fn_ptr_t evd_batch_fn =
97+
evd_batch_dispatch_table[eig_vecs_type_id][eig_vals_type_id];
98+
if (evd_batch_fn == nullptr) {
99+
throw py::value_error(
100+
"Types of input vectors and result array are mismatched.");
101+
}
102+
103+
char *eig_vecs_data = eig_vecs.get_data();
104+
char *eig_vals_data = eig_vals.get_data();
105+
106+
const std::int64_t batch_size = eig_vecs_shape[2];
107+
const std::int64_t n = eig_vecs_shape[1];
108+
109+
const oneapi::mkl::job jobz_val = static_cast<oneapi::mkl::job>(jobz);
110+
const oneapi::mkl::uplo uplo_val =
111+
static_cast<oneapi::mkl::uplo>(upper_lower);
112+
113+
sycl::event evd_batch_ev =
114+
evd_batch_fn(exec_q, jobz_val, uplo_val, batch_size, n, eig_vecs_data,
115+
eig_vals_data, depends);
116+
117+
sycl::event ht_ev = dpctl::utils::keep_args_alive(
118+
exec_q, {eig_vecs, eig_vals}, {evd_batch_ev});
119+
120+
return std::make_pair(ht_ev, evd_batch_ev);
121+
}
122+
123+
template <typename T>
124+
inline T *alloc_scratchpad(std::int64_t scratchpad_size,
125+
std::int64_t n_linear_streams,
126+
sycl::queue &exec_q)
127+
{
128+
// Get padding size to ensure memory allocations are aligned to 256 bytes
129+
// for better performance
130+
const std::int64_t padding = 256 / sizeof(T);
131+
132+
if (scratchpad_size <= 0) {
133+
throw std::runtime_error(
134+
"Invalid scratchpad size: must be greater than zero."
135+
" Calculated scratchpad size: " +
136+
std::to_string(scratchpad_size));
137+
}
138+
139+
// Calculate the total scratchpad memory size needed for all linear
140+
// streams with proper alignment
141+
const size_t alloc_scratch_size =
142+
helper::round_up_mult(n_linear_streams * scratchpad_size, padding);
143+
144+
// Allocate memory for the total scratchpad
145+
T *scratchpad = sycl::malloc_device<T>(alloc_scratch_size, exec_q);
146+
if (!scratchpad) {
147+
throw std::runtime_error("Device allocation for scratchpad failed");
148+
}
149+
150+
return scratchpad;
151+
}
152+
} // namespace dpnp::extensions::lapack::evd

dpnp/backend/extensions/lapack/evd_common.hpp

Lines changed: 18 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -29,22 +29,22 @@
2929
#include <pybind11/pybind11.h>
3030

3131
// dpctl tensor headers
32-
#include "utils/memory_overlap.hpp"
33-
#include "utils/output_validation.hpp"
3432
#include "utils/type_dispatch.hpp"
35-
#include "utils/type_utils.hpp"
3633

34+
#include "common_helpers.hpp"
35+
#include "evd_common_utils.hpp"
3736
#include "types_matrix.hpp"
3837

3938
namespace dpnp::extensions::lapack::evd
4039
{
40+
using dpnp::extensions::lapack::helper::check_zeros_shape;
41+
4142
typedef sycl::event (*evd_impl_fn_ptr_t)(sycl::queue &,
4243
const oneapi::mkl::job,
4344
const oneapi::mkl::uplo,
4445
const std::int64_t,
4546
char *,
4647
char *,
47-
std::vector<sycl::event> &,
4848
const std::vector<sycl::event> &);
4949

5050
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
@@ -61,70 +61,30 @@ std::pair<sycl::event, sycl::event>
6161
const dispatchT &evd_dispatch_table)
6262
{
6363
const int eig_vecs_nd = eig_vecs.get_ndim();
64-
const int eig_vals_nd = eig_vals.get_ndim();
65-
66-
if (eig_vecs_nd != 2) {
67-
throw py::value_error("Unexpected ndim=" + std::to_string(eig_vecs_nd) +
68-
" of an output array with eigenvectors");
69-
}
70-
else if (eig_vals_nd != 1) {
71-
throw py::value_error("Unexpected ndim=" + std::to_string(eig_vals_nd) +
72-
" of an output array with eigenvalues");
73-
}
7464

7565
const py::ssize_t *eig_vecs_shape = eig_vecs.get_shape_raw();
7666
const py::ssize_t *eig_vals_shape = eig_vals.get_shape_raw();
7767

78-
if (eig_vecs_shape[0] != eig_vecs_shape[1]) {
79-
throw py::value_error("Output array with eigenvectors with be square");
80-
}
81-
else if (eig_vecs_shape[0] != eig_vals_shape[0]) {
82-
throw py::value_error(
83-
"Eigenvectors and eigenvalues have different shapes");
84-
}
68+
constexpr int expected_eig_vecs_nd = 2;
69+
constexpr int expected_eig_vals_nd = 1;
8570

86-
size_t src_nelems(1);
71+
common_evd_checks(exec_q, eig_vecs, eig_vals, eig_vecs_shape,
72+
expected_eig_vecs_nd, expected_eig_vals_nd);
8773

88-
for (int i = 0; i < eig_vecs_nd; ++i) {
89-
src_nelems *= static_cast<size_t>(eig_vecs_shape[i]);
74+
if (eig_vecs_shape[0] != eig_vals_shape[0]) {
75+
throw py::value_error(
76+
"Eigenvectors and eigenvalues have different shapes");
9077
}
9178

92-
if (src_nelems == 0) {
79+
if (check_zeros_shape(eig_vecs_nd, eig_vecs_shape)) {
9380
// nothing to do
9481
return std::make_pair(sycl::event(), sycl::event());
9582
}
9683

97-
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(eig_vecs);
98-
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(eig_vals);
99-
100-
// check compatibility of execution queue and allocation queue
101-
if (!dpctl::utils::queues_are_compatible(exec_q, {eig_vecs, eig_vals})) {
102-
throw py::value_error(
103-
"Execution queue is not compatible with allocation queues");
104-
}
105-
106-
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
107-
if (overlap(eig_vecs, eig_vals)) {
108-
throw py::value_error("Arrays with eigenvectors and eigenvalues are "
109-
"overlapping segments of memory");
110-
}
111-
112-
bool is_eig_vecs_f_contig = eig_vecs.is_f_contiguous();
113-
bool is_eig_vals_c_contig = eig_vals.is_c_contiguous();
114-
if (!is_eig_vecs_f_contig) {
115-
throw py::value_error(
116-
"An array with input matrix / output eigenvectors "
117-
"must be F-contiguous");
118-
}
119-
else if (!is_eig_vals_c_contig) {
120-
throw py::value_error(
121-
"An array with output eigenvalues must be C-contiguous");
122-
}
123-
12484
auto array_types = dpctl_td_ns::usm_ndarray_types();
125-
int eig_vecs_type_id =
85+
const int eig_vecs_type_id =
12686
array_types.typenum_to_lookup_id(eig_vecs.get_typenum());
127-
int eig_vals_type_id =
87+
const int eig_vals_type_id =
12888
array_types.typenum_to_lookup_id(eig_vals.get_typenum());
12989

13090
evd_impl_fn_ptr_t evd_fn =
@@ -142,25 +102,12 @@ std::pair<sycl::event, sycl::event>
142102
const oneapi::mkl::uplo uplo_val =
143103
static_cast<oneapi::mkl::uplo>(upper_lower);
144104

145-
std::vector<sycl::event> host_task_events;
146105
sycl::event evd_ev = evd_fn(exec_q, jobz_val, uplo_val, n, eig_vecs_data,
147-
eig_vals_data, host_task_events, depends);
106+
eig_vals_data, depends);
148107

149-
sycl::event args_ev = dpctl::utils::keep_args_alive(
150-
exec_q, {eig_vecs, eig_vals}, host_task_events);
108+
sycl::event ht_ev =
109+
dpctl::utils::keep_args_alive(exec_q, {eig_vecs, eig_vals}, {evd_ev});
151110

152-
return std::make_pair(args_ev, evd_ev);
153-
}
154-
155-
template <typename dispatchT,
156-
template <typename fnT, typename T, typename RealT>
157-
typename factoryT>
158-
void init_evd_dispatch_table(
159-
dispatchT evd_dispatch_table[][dpctl_td_ns::num_types])
160-
{
161-
dpctl_td_ns::DispatchTableBuilder<dispatchT, factoryT,
162-
dpctl_td_ns::num_types>
163-
contig;
164-
contig.populate_dispatch_table(evd_dispatch_table);
111+
return std::make_pair(ht_ev, evd_ev);
165112
}
166113
} // namespace dpnp::extensions::lapack::evd

0 commit comments

Comments
 (0)