1
1
// *****************************************************************************
2
- // Copyright (c) 2023- 2024, Intel Corporation
2
+ // Copyright (c) 2024, Intel Corporation
3
3
// All rights reserved.
4
4
//
5
5
// Redistribution and use in source and binary forms, with or without
23
23
// THE POSSIBILITY OF SUCH DAMAGE.
24
24
// *****************************************************************************
25
25
26
- #include < pybind11/pybind11.h>
27
-
28
- // dpctl tensor headers
29
- #include " utils/memory_overlap.hpp"
30
- #include " utils/type_utils.hpp"
31
-
32
26
#include " heevd.hpp"
33
- #include " types_matrix.hpp"
34
-
35
- #include " dpnp_utils.hpp"
36
27
37
28
namespace dpnp
38
29
{
@@ -43,23 +34,10 @@ namespace ext
43
34
namespace lapack
44
35
{
45
36
namespace mkl_lapack = oneapi::mkl::lapack;
46
- namespace py = pybind11;
47
37
namespace type_utils = dpctl::tensor::type_utils;
48
38
49
- typedef sycl::event (*heevd_impl_fn_ptr_t )(sycl::queue,
50
- const oneapi::mkl::job,
51
- const oneapi::mkl::uplo,
52
- const std::int64_t ,
53
- char *,
54
- char *,
55
- std::vector<sycl::event> &,
56
- const std::vector<sycl::event> &);
57
-
58
- static heevd_impl_fn_ptr_t heevd_dispatch_table[dpctl_td_ns::num_types]
59
- [dpctl_td_ns::num_types];
60
-
61
39
template <typename T, typename RealT>
62
- static sycl::event heevd_impl (sycl::queue exec_q,
40
+ static sycl::event heevd_impl (sycl::queue & exec_q,
63
41
const oneapi::mkl::job jobz,
64
42
const oneapi::mkl::uplo upper_lower,
65
43
const std::int64_t n,
@@ -128,104 +106,8 @@ static sycl::event heevd_impl(sycl::queue exec_q,
128
106
cgh.host_task ([ctx, scratchpad]() { sycl::free (scratchpad, ctx); });
129
107
});
130
108
host_task_events.push_back (clean_up_event);
131
- return heevd_event;
132
- }
133
-
134
- std::pair<sycl::event, sycl::event>
135
- heevd (sycl::queue exec_q,
136
- const std::int8_t jobz,
137
- const std::int8_t upper_lower,
138
- dpctl::tensor::usm_ndarray eig_vecs,
139
- dpctl::tensor::usm_ndarray eig_vals,
140
- const std::vector<sycl::event> &depends)
141
- {
142
- const int eig_vecs_nd = eig_vecs.get_ndim ();
143
- const int eig_vals_nd = eig_vals.get_ndim ();
144
-
145
- if (eig_vecs_nd != 2 ) {
146
- throw py::value_error (" Unexpected ndim=" + std::to_string (eig_vecs_nd) +
147
- " of an output array with eigenvectors" );
148
- }
149
- else if (eig_vals_nd != 1 ) {
150
- throw py::value_error (" Unexpected ndim=" + std::to_string (eig_vals_nd) +
151
- " of an output array with eigenvalues" );
152
- }
153
-
154
- const py::ssize_t *eig_vecs_shape = eig_vecs.get_shape_raw ();
155
- const py::ssize_t *eig_vals_shape = eig_vals.get_shape_raw ();
156
-
157
- if (eig_vecs_shape[0 ] != eig_vecs_shape[1 ]) {
158
- throw py::value_error (" Output array with eigenvectors with be square" );
159
- }
160
- else if (eig_vecs_shape[0 ] != eig_vals_shape[0 ]) {
161
- throw py::value_error (
162
- " Eigenvectors and eigenvalues have different shapes" );
163
- }
164
-
165
- size_t src_nelems (1 );
166
-
167
- for (int i = 0 ; i < eig_vecs_nd; ++i) {
168
- src_nelems *= static_cast <size_t >(eig_vecs_shape[i]);
169
- }
170
-
171
- if (src_nelems == 0 ) {
172
- // nothing to do
173
- return std::make_pair (sycl::event (), sycl::event ());
174
- }
175
-
176
- // check compatibility of execution queue and allocation queue
177
- if (!dpctl::utils::queues_are_compatible (exec_q, {eig_vecs, eig_vals})) {
178
- throw py::value_error (
179
- " Execution queue is not compatible with allocation queues" );
180
- }
181
-
182
- auto const &overlap = dpctl::tensor::overlap::MemoryOverlap ();
183
- if (overlap (eig_vecs, eig_vals)) {
184
- throw py::value_error (" Arrays with eigenvectors and eigenvalues are "
185
- " overlapping segments of memory" );
186
- }
187
-
188
- bool is_eig_vecs_f_contig = eig_vecs.is_f_contiguous ();
189
- bool is_eig_vals_c_contig = eig_vals.is_c_contiguous ();
190
- if (!is_eig_vecs_f_contig) {
191
- throw py::value_error (
192
- " An array with input matrix / output eigenvectors "
193
- " must be F-contiguous" );
194
- }
195
- else if (!is_eig_vals_c_contig) {
196
- throw py::value_error (
197
- " An array with output eigenvalues must be C-contiguous" );
198
- }
199
-
200
- auto array_types = dpctl_td_ns::usm_ndarray_types ();
201
- int eig_vecs_type_id =
202
- array_types.typenum_to_lookup_id (eig_vecs.get_typenum ());
203
- int eig_vals_type_id =
204
- array_types.typenum_to_lookup_id (eig_vals.get_typenum ());
205
-
206
- heevd_impl_fn_ptr_t heevd_fn =
207
- heevd_dispatch_table[eig_vecs_type_id][eig_vals_type_id];
208
- if (heevd_fn == nullptr ) {
209
- throw py::value_error (" No heevd implementation defined for a pair of "
210
- " type for eigenvectors and eigenvalues" );
211
- }
212
-
213
- char *eig_vecs_data = eig_vecs.get_data ();
214
- char *eig_vals_data = eig_vals.get_data ();
215
-
216
- const std::int64_t n = eig_vecs_shape[0 ];
217
- const oneapi::mkl::job jobz_val = static_cast <oneapi::mkl::job>(jobz);
218
- const oneapi::mkl::uplo uplo_val =
219
- static_cast <oneapi::mkl::uplo>(upper_lower);
220
109
221
- std::vector<sycl::event> host_task_events;
222
- sycl::event heevd_ev =
223
- heevd_fn (exec_q, jobz_val, uplo_val, n, eig_vecs_data, eig_vals_data,
224
- host_task_events, depends);
225
-
226
- sycl::event args_ev = dpctl::utils::keep_args_alive (
227
- exec_q, {eig_vecs, eig_vals}, host_task_events);
228
- return std::make_pair (args_ev, heevd_ev);
110
+ return heevd_event;
229
111
}
230
112
231
113
template <typename fnT, typename T, typename RealT>
@@ -243,12 +125,35 @@ struct HeevdContigFactory
243
125
}
244
126
};
245
127
246
- void init_heevd_dispatch_table (void )
128
+ using evd::evd_impl_fn_ptr_t ;
129
+
130
// init_heevd: registers the `_heevd` function on the module `m`.
// In this hunk it takes the place of the removed init_heevd_dispatch_table
// (the '-' lines): the dispatch table is now a function-local static that is
// filled here and then handed to evd::evd_func on every call.
//
// m: the module object on which m.def is invoked below.
+ void init_heevd (py::module_ m)
247
131
{
248
- dpctl_td_ns::DispatchTableBuilder<heevd_impl_fn_ptr_t , HeevdContigFactory,
249
- dpctl_td_ns::num_types>
250
- contig;
251
- contig.populate_dispatch_table (heevd_dispatch_table);
132
+ using arrayT = dpctl::tensor::usm_ndarray;
133
+ using event_vecT = std::vector<sycl::event>;
134
+
135
// num_types x num_types table of implementation function pointers. It is
// `static`, so it outlives this call and the lambda registered below can keep
// referring to it. NOTE(review): presumably indexed by the type ids of
// eig_vecs and eig_vals, as the removed heevd() did -- confirm in evd_func.
+ static evd_impl_fn_ptr_t heevd_dispatch_table[dpctl_td_ns::num_types]
136
+ [dpctl_td_ns::num_types];
137
+
138
+ {
139
// Fill the table through the shared evd helper, using HeevdContigFactory as
// the per-type-pair factory.
+ evd::init_evd_dispatch_table<evd_impl_fn_ptr_t , HeevdContigFactory>(
140
+ heevd_dispatch_table);
141
+
142
// Thin adapter: forwards its arguments unchanged to evd::evd_func and appends
// the heevd-specific dispatch table as the last argument.
+ auto heevd_pyapi = [&](sycl::queue &exec_q, const std::int8_t jobz,
143
+ const std::int8_t upper_lower, arrayT &eig_vecs,
144
+ arrayT &eig_vals,
145
+ const event_vecT &depends = {}) {
146
+ return evd::evd_func (exec_q, jobz, upper_lower, eig_vecs, eig_vals,
147
+ depends, heevd_dispatch_table);
148
+ };
149
+
150
// Register the adapter with named arguments; the `depends` argument is given
// an empty py::list() as its default value here.
+ m.def (" _heevd" , heevd_pyapi,
151
+ " Call `heevd` from OneMKL LAPACK library to return "
152
+ " the eigenvalues and eigenvectors of a complex Hermitian matrix" ,
153
+ py::arg (" sycl_queue" ), py::arg (" jobz" ), py::arg (" upper_lower" ),
154
+ py::arg (" eig_vecs" ), py::arg (" eig_vals" ),
155
+ py::arg (" depends" ) = py::list ());
156
+ }
252
157
}
253
158
} // namespace lapack
254
159
} // namespace ext
0 commit comments