@@ -230,6 +230,123 @@ std::pair<sycl::event, sycl::event>
230
230
return std::make_pair (args_ev, syevd_ev);
231
231
}
232
232
233
+ std::pair<sycl::event, sycl::event>
234
+ syevd_batch (sycl::queue exec_q,
235
+ const std::int8_t jobz,
236
+ const std::int8_t upper_lower,
237
+ dpctl::tensor::usm_ndarray eig_vecs,
238
+ dpctl::tensor::usm_ndarray eig_vals,
239
+ const std::vector<sycl::event> &depends)
240
+ {
241
+ const int eig_vecs_nd = eig_vecs.get_ndim ();
242
+ const int eig_vals_nd = eig_vals.get_ndim ();
243
+
244
+ if (eig_vecs_nd != 3 ) {
245
+ throw py::value_error (" Unexpected ndim=" + std::to_string (eig_vecs_nd) +
246
+ " of an output array with eigenvectors" );
247
+ }
248
+ else if (eig_vals_nd != 2 ) {
249
+ throw py::value_error (" Unexpected ndim=" + std::to_string (eig_vals_nd) +
250
+ " of an output array with eigenvalues" );
251
+ }
252
+
253
+ const py::ssize_t *eig_vecs_shape = eig_vecs.get_shape_raw ();
254
+ const py::ssize_t *eig_vals_shape = eig_vals.get_shape_raw ();
255
+
256
+ if (eig_vecs_shape[1 ] != eig_vecs_shape[2 ]) {
257
+ throw py::value_error (
258
+ " The last two dimensions of 'eig_vecs' must be the same." );
259
+ }
260
+ else if (eig_vecs_shape[0 ] != eig_vals_shape[0 ] ||
261
+ eig_vecs_shape[1 ] != eig_vals_shape[1 ])
262
+ {
263
+ throw py::value_error (
264
+ " The shape of 'eig_vals' must be (batch_size, n), "
265
+ " where batch_size = " +
266
+ std::to_string (eig_vecs_shape[0 ]) +
267
+ " and n = " + std::to_string (eig_vecs_shape[1 ]));
268
+ }
269
+
270
+ size_t src_nelems (1 );
271
+
272
+ for (int i = 0 ; i < eig_vecs_nd; ++i) {
273
+ src_nelems *= static_cast <size_t >(eig_vecs_shape[i]);
274
+ }
275
+
276
+ if (src_nelems == 0 ) {
277
+ // nothing to do
278
+ return std::make_pair (sycl::event (), sycl::event ());
279
+ }
280
+
281
+ // check compatibility of execution queue and allocation queue
282
+ if (!dpctl::utils::queues_are_compatible (exec_q, {eig_vecs, eig_vals})) {
283
+ throw py::value_error (
284
+ " Execution queue is not compatible with allocation queues" );
285
+ }
286
+
287
+ auto const &overlap = dpctl::tensor::overlap::MemoryOverlap ();
288
+ if (overlap (eig_vecs, eig_vals)) {
289
+ throw py::value_error (" Arrays with eigenvectors and eigenvalues are "
290
+ " overlapping segments of memory" );
291
+ }
292
+
293
+ bool is_eig_vecs_c_contig = eig_vecs.is_c_contiguous ();
294
+ bool is_eig_vals_c_contig = eig_vals.is_c_contiguous ();
295
+ if (!is_eig_vecs_c_contig) {
296
+ throw py::value_error (
297
+ " An array with input matrix / output eigenvectors "
298
+ " must be C-contiguous" );
299
+ }
300
+ else if (!is_eig_vals_c_contig) {
301
+ throw py::value_error (
302
+ " An array with output eigenvalues must be C-contiguous" );
303
+ }
304
+
305
+ auto array_types = dpctl_td_ns::usm_ndarray_types ();
306
+ int eig_vecs_type_id =
307
+ array_types.typenum_to_lookup_id (eig_vecs.get_typenum ());
308
+ int eig_vals_type_id =
309
+ array_types.typenum_to_lookup_id (eig_vals.get_typenum ());
310
+
311
+ if (eig_vecs_type_id != eig_vals_type_id) {
312
+ throw py::value_error (
313
+ " Types of eigenvectors and eigenvalues are mismatched" );
314
+ }
315
+
316
+ syevd_impl_fn_ptr_t syevd_fn = syevd_dispatch_vector[eig_vecs_type_id];
317
+ if (syevd_fn == nullptr ) {
318
+ throw py::value_error (" No syevd implementation defined for a type of "
319
+ " eigenvectors and eigenvalues" );
320
+ }
321
+
322
+ char *eig_vecs_data = eig_vecs.get_data ();
323
+ char *eig_vals_data = eig_vals.get_data ();
324
+
325
+ const std::int64_t batch_size = eig_vecs_shape[0 ];
326
+ const std::int64_t n = eig_vecs_shape[1 ];
327
+ int elemsize = eig_vecs.get_elemsize ();
328
+
329
+ const oneapi::mkl::job jobz_val = static_cast <oneapi::mkl::job>(jobz);
330
+ const oneapi::mkl::uplo uplo_val =
331
+ static_cast <oneapi::mkl::uplo>(upper_lower);
332
+
333
+ std::vector<sycl::event> host_task_events;
334
+
335
+ for (std::int64_t i = 0 ; i < batch_size; ++i) {
336
+ char *eig_vecs_batch = eig_vecs_data + i * n * n * elemsize;
337
+ char *eig_vals_batch = eig_vals_data + i * n * elemsize;
338
+
339
+ sycl::event syevd_ev =
340
+ syevd_fn (exec_q, jobz_val, uplo_val, n, eig_vecs_batch,
341
+ eig_vals_batch, host_task_events, depends);
342
+ }
343
+
344
+ sycl::event args_ev = dpctl::utils::keep_args_alive (
345
+ exec_q, {eig_vecs, eig_vals}, host_task_events);
346
+
347
+ return std::make_pair (args_ev, args_ev);
348
+ }
349
+
233
350
template <typename fnT, typename T>
234
351
struct SyevdContigFactory
235
352
{
0 commit comments