@@ -71,13 +71,11 @@ py_gemv(sycl::queue q,
71
71
" USM allocation is not bound to the context in execution queue." );
72
72
}
73
73
74
- int mat_flags = matrix.get_flags ();
75
- int v_flags = vector.get_flags ();
76
- int r_flags = result.get_flags ();
74
+ auto &api = dpctl::detail::dpctl_capi::get ();
77
75
78
- if (!((mat_flags & (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS )) &&
79
- (v_flags & (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS )) &&
80
- (r_flags & (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS ))))
76
+ if (!((matrix. is_c_contiguous ( )) &&
77
+ (vector. is_c_contiguous () || vector. is_f_contiguous ( )) &&
78
+ (result. is_c_contiguous () || result. is_f_contiguous ( ))))
81
79
{
82
80
throw std::runtime_error (" Arrays must be contiguous." );
83
81
}
@@ -87,8 +85,8 @@ py_gemv(sycl::queue q,
87
85
int r_typenum = result.get_typenum ();
88
86
89
87
if ((mat_typenum != v_typenum) || (r_typenum != v_typenum) ||
90
- !((v_typenum == UAR_DOUBLE ) || (v_typenum == UAR_FLOAT ) ||
91
- (v_typenum == UAR_CDOUBLE ) || (v_typenum == UAR_CFLOAT )))
88
+ !((v_typenum == api. UAR_DOUBLE_ ) || (v_typenum == api. UAR_FLOAT_ ) ||
89
+ (v_typenum == api. UAR_CDOUBLE_ ) || (v_typenum == api. UAR_CFLOAT_ )))
92
90
{
93
91
std::cout << " Found: [" << mat_typenum << " , " << v_typenum << " , "
94
92
<< r_typenum << " ]" << std::endl;
@@ -103,7 +101,7 @@ py_gemv(sycl::queue q,
103
101
char *r_typeless_ptr = result.get_data ();
104
102
105
103
sycl::event res_ev;
106
- if (v_typenum == UAR_DOUBLE ) {
104
+ if (v_typenum == api. UAR_DOUBLE_ ) {
107
105
using T = double ;
108
106
sycl::event gemv_ev = oneapi::mkl::blas::row_major::gemv (
109
107
q, oneapi::mkl::transpose::nontrans, n, m, T (1 ),
@@ -112,7 +110,7 @@ py_gemv(sycl::queue q,
112
110
reinterpret_cast <T *>(r_typeless_ptr), 1 , depends);
113
111
res_ev = gemv_ev;
114
112
}
115
- else if (v_typenum == UAR_FLOAT ) {
113
+ else if (v_typenum == api. UAR_FLOAT_ ) {
116
114
using T = float ;
117
115
sycl::event gemv_ev = oneapi::mkl::blas::row_major::gemv (
118
116
q, oneapi::mkl::transpose::nontrans, n, m, T (1 ),
@@ -121,7 +119,7 @@ py_gemv(sycl::queue q,
121
119
reinterpret_cast <T *>(r_typeless_ptr), 1 , depends);
122
120
res_ev = gemv_ev;
123
121
}
124
- else if (v_typenum == UAR_CDOUBLE ) {
122
+ else if (v_typenum == api. UAR_CDOUBLE_ ) {
125
123
using T = std::complex<double >;
126
124
sycl::event gemv_ev = oneapi::mkl::blas::row_major::gemv (
127
125
q, oneapi::mkl::transpose::nontrans, n, m, T (1 ),
@@ -130,7 +128,7 @@ py_gemv(sycl::queue q,
130
128
reinterpret_cast <T *>(r_typeless_ptr), 1 , depends);
131
129
res_ev = gemv_ev;
132
130
}
133
- else if (v_typenum == UAR_CFLOAT ) {
131
+ else if (v_typenum == api. UAR_CFLOAT_ ) {
134
132
using T = std::complex<float >;
135
133
sycl::event gemv_ev = oneapi::mkl::blas::row_major::gemv (
136
134
q, oneapi::mkl::transpose::nontrans, n, m, T (1 ),
@@ -185,21 +183,18 @@ py_sub(sycl::queue q,
185
183
throw std::runtime_error (" Vectors must have the same length" );
186
184
}
187
185
188
- if (q.get_context () != in_v1.get_queue ().get_context () ||
189
- q.get_context () != in_v2.get_queue ().get_context () ||
190
- q.get_context () != out_r.get_queue ().get_context ())
186
+ if (!dpctl::utils::queues_are_compatible (
187
+ q, {in_v1.get_queue (), in_v2.get_queue (), out_r.get_queue ()}))
191
188
{
192
189
throw std::runtime_error (
193
190
" USM allocation is not bound to the context in execution queue" );
194
191
}
195
192
196
- int in_v1_flags = in_v1.get_flags ();
197
- int in_v2_flags = in_v2.get_flags ();
198
- int out_r_flags = out_r.get_flags ();
193
+ auto &api = dpctl::detail::dpctl_capi::get ();
199
194
200
- if (!((in_v1_flags & (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS )) &&
201
- (in_v2_flags & (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS )) &&
202
- (out_r_flags & (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS ))))
195
+ if (!((in_v1. is_c_contiguous () || in_v1. is_f_contiguous ( )) &&
196
+ (in_v2. is_c_contiguous () || in_v2. is_f_contiguous ( )) &&
197
+ (out_r. is_c_contiguous () || out_r. is_f_contiguous ( ))))
203
198
{
204
199
throw std::runtime_error (" Vectors must be contiguous." );
205
200
}
@@ -209,8 +204,10 @@ py_sub(sycl::queue q,
209
204
int out_r_typenum = out_r.get_typenum ();
210
205
211
206
if ((in_v2_typenum != in_v1_typenum) || (out_r_typenum != in_v1_typenum) ||
212
- !((in_v1_typenum == UAR_DOUBLE) || (in_v1_typenum == UAR_FLOAT) ||
213
- (in_v1_typenum == UAR_CDOUBLE) || (in_v1_typenum == UAR_CFLOAT)))
207
+ !((in_v1_typenum == api.UAR_DOUBLE_ ) ||
208
+ (in_v1_typenum == api.UAR_FLOAT_ ) ||
209
+ (in_v1_typenum == api.UAR_CDOUBLE_ ) ||
210
+ (in_v1_typenum == api.UAR_CFLOAT_ )))
214
211
{
215
212
throw std::runtime_error (
216
213
" Only real and complex floating point arrays are supported." );
@@ -221,22 +218,22 @@ py_sub(sycl::queue q,
221
218
char *out_r_typeless_ptr = out_r.get_data ();
222
219
223
220
sycl::event res_ev;
224
- if (out_r_typenum == UAR_DOUBLE ) {
221
+ if (out_r_typenum == api. UAR_DOUBLE_ ) {
225
222
using T = double ;
226
223
res_ev = sub_impl<T>(q, n, in_v1_typeless_ptr, in_v2_typeless_ptr,
227
224
out_r_typeless_ptr, depends);
228
225
}
229
- else if (out_r_typenum == UAR_FLOAT ) {
226
+ else if (out_r_typenum == api. UAR_FLOAT_ ) {
230
227
using T = float ;
231
228
res_ev = sub_impl<T>(q, n, in_v1_typeless_ptr, in_v2_typeless_ptr,
232
229
out_r_typeless_ptr, depends);
233
230
}
234
- else if (out_r_typenum == UAR_CDOUBLE ) {
231
+ else if (out_r_typenum == api. UAR_CDOUBLE_ ) {
235
232
using T = std::complex<double >;
236
233
res_ev = sub_impl<T>(q, n, in_v1_typeless_ptr, in_v2_typeless_ptr,
237
234
out_r_typeless_ptr, depends);
238
235
}
239
- else if (out_r_typenum == UAR_CFLOAT ) {
236
+ else if (out_r_typenum == api. UAR_CFLOAT_ ) {
240
237
using T = std::complex<float >;
241
238
res_ev = sub_impl<T>(q, n, in_v1_typeless_ptr, in_v2_typeless_ptr,
242
239
out_r_typeless_ptr, depends);
@@ -294,18 +291,15 @@ py_axpby_inplace(sycl::queue q,
294
291
throw std::runtime_error (" Vectors must have the same length" );
295
292
}
296
293
297
- if (q.get_context () != x.get_queue ().get_context () ||
298
- q.get_context () != y.get_queue ().get_context ())
294
+ if (!dpctl::utils::queues_are_compatible (q, {x.get_queue (), y.get_queue ()}))
299
295
{
300
296
throw std::runtime_error (
301
297
" USM allocation is not bound to the context in execution queue" );
302
298
}
299
+ auto &api = dpctl::detail::dpctl_capi::get ();
303
300
304
- int x_flags = x.get_flags ();
305
- int y_flags = y.get_flags ();
306
-
307
- if (!((x_flags & (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS)) &&
308
- (y_flags & (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS))))
301
+ if (!((x.is_c_contiguous () || x.is_f_contiguous ()) &&
302
+ (y.is_c_contiguous () || y.is_f_contiguous ())))
309
303
{
310
304
throw std::runtime_error (" Vectors must be contiguous." );
311
305
}
@@ -314,8 +308,8 @@ py_axpby_inplace(sycl::queue q,
314
308
int y_typenum = y.get_typenum ();
315
309
316
310
if ((x_typenum != y_typenum) ||
317
- !((x_typenum == UAR_DOUBLE ) || (x_typenum == UAR_FLOAT ) ||
318
- (x_typenum == UAR_CDOUBLE ) || (x_typenum == UAR_CFLOAT )))
311
+ !((x_typenum == api. UAR_DOUBLE_ ) || (x_typenum == api. UAR_FLOAT_ ) ||
312
+ (x_typenum == api. UAR_CDOUBLE_ ) || (x_typenum == api. UAR_CFLOAT_ )))
319
313
{
320
314
throw std::runtime_error (
321
315
" Only real and complex floating point arrays are supported." );
@@ -325,22 +319,22 @@ py_axpby_inplace(sycl::queue q,
325
319
char *y_typeless_ptr = y.get_data ();
326
320
327
321
sycl::event res_ev;
328
- if (x_typenum == UAR_DOUBLE ) {
322
+ if (x_typenum == api. UAR_DOUBLE_ ) {
329
323
using T = double ;
330
324
res_ev = axpby_inplace_impl<T>(q, n, a, x_typeless_ptr, b,
331
325
y_typeless_ptr, depends);
332
326
}
333
- else if (x_typenum == UAR_FLOAT ) {
327
+ else if (x_typenum == api. UAR_FLOAT_ ) {
334
328
using T = float ;
335
329
res_ev = axpby_inplace_impl<T>(q, n, a, x_typeless_ptr, b,
336
330
y_typeless_ptr, depends);
337
331
}
338
- else if (x_typenum == UAR_CDOUBLE ) {
332
+ else if (x_typenum == api. UAR_CDOUBLE_ ) {
339
333
using T = std::complex<double >;
340
334
res_ev = axpby_inplace_impl<T>(q, n, a, x_typeless_ptr, b,
341
335
y_typeless_ptr, depends);
342
336
}
343
- else if (x_typenum == UAR_CFLOAT ) {
337
+ else if (x_typenum == api. UAR_CFLOAT_ ) {
344
338
using T = std::complex<float >;
345
339
res_ev = axpby_inplace_impl<T>(q, n, a, x_typeless_ptr, b,
346
340
y_typeless_ptr, depends);
@@ -393,18 +387,20 @@ py::object py_norm_squared_blocking(sycl::queue q,
393
387
394
388
int r_flags = r.get_flags ();
395
389
396
- if (!(r_flags & (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS ))) {
390
+ if (!(r. is_c_contiguous () || r. is_f_contiguous ( ))) {
397
391
throw std::runtime_error (" Vector must be contiguous." );
398
392
}
399
393
400
- if (q. get_context () != r.get_queue (). get_context ( )) {
394
+ if (! dpctl::utils::queues_are_compatible (q, { r.get_queue ()} )) {
401
395
throw std::runtime_error (
402
396
" USM allocation is not bound to the context in execution queue" );
403
397
}
404
398
399
+ auto &api = dpctl::detail::dpctl_capi::get ();
400
+
405
401
int r_typenum = r.get_typenum ();
406
- if ((r_typenum != UAR_DOUBLE ) && (r_typenum != UAR_FLOAT ) &&
407
- (r_typenum != UAR_CDOUBLE ) && (r_typenum != UAR_CFLOAT ))
402
+ if ((r_typenum != api. UAR_DOUBLE_ ) && (r_typenum != api. UAR_FLOAT_ ) &&
403
+ (r_typenum != api. UAR_CDOUBLE_ ) && (r_typenum != api. UAR_CFLOAT_ ))
408
404
{
409
405
throw std::runtime_error (
410
406
" Only real and complex floating point arrays are supported." );
@@ -413,23 +409,23 @@ py::object py_norm_squared_blocking(sycl::queue q,
413
409
const char *r_typeless_ptr = r.get_data ();
414
410
py::object res;
415
411
416
- if (r_typenum == UAR_DOUBLE ) {
412
+ if (r_typenum == api. UAR_DOUBLE_ ) {
417
413
using T = double ;
418
414
T n_sq = norm_squared_blocking_impl<T>(q, n, r_typeless_ptr, depends);
419
415
res = py::float_ (n_sq);
420
416
}
421
- else if (r_typenum == UAR_FLOAT ) {
417
+ else if (r_typenum == api. UAR_FLOAT_ ) {
422
418
using T = float ;
423
419
T n_sq = norm_squared_blocking_impl<T>(q, n, r_typeless_ptr, depends);
424
420
res = py::float_ (n_sq);
425
421
}
426
- else if (r_typenum == UAR_CDOUBLE ) {
422
+ else if (r_typenum == api. UAR_CDOUBLE_ ) {
427
423
using T = std::complex<double >;
428
424
double n_sq = complex_norm_squared_blocking_impl<double >(
429
425
q, n, r_typeless_ptr, depends);
430
426
res = py::float_ (n_sq);
431
427
}
432
- else if (r_typenum == UAR_CFLOAT ) {
428
+ else if (r_typenum == api. UAR_CFLOAT_ ) {
433
429
using T = std::complex<float >;
434
430
float n_sq = complex_norm_squared_blocking_impl<float >(
435
431
q, n, r_typeless_ptr, depends);
@@ -457,28 +453,27 @@ py::object py_dot_blocking(sycl::queue q,
457
453
throw std::runtime_error (" Length of vectors are not the same" );
458
454
}
459
455
460
- int v1_flags = v1.get_flags ();
461
- int v2_flags = v2.get_flags ();
462
-
463
- if (!(v1_flags & (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS)) ||
464
- !(v2_flags & (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS)))
456
+ if (!(v1.is_c_contiguous () || v1.is_f_contiguous ()) ||
457
+ !(v2.is_c_contiguous () || v2.is_f_contiguous ()))
465
458
{
466
459
throw std::runtime_error (" Vectors must be contiguous." );
467
460
}
468
461
469
- if (q. get_context () != v1. get_queue (). get_context () ||
470
- q. get_context () != v2.get_queue (). get_context ( ))
462
+ if (! dpctl::utils::queues_are_compatible (q,
463
+ {v1. get_queue (), v2.get_queue ()} ))
471
464
{
472
465
throw std::runtime_error (
473
466
" USM allocation is not bound to the context in execution queue" );
474
467
}
475
468
469
+ auto &api = dpctl::detail::dpctl_capi::get ();
470
+
476
471
int v1_typenum = v1.get_typenum ();
477
472
int v2_typenum = v2.get_typenum ();
478
473
479
474
if ((v1_typenum != v2_typenum) ||
480
- ((v1_typenum != UAR_DOUBLE ) && (v1_typenum != UAR_FLOAT ) &&
481
- (v1_typenum != UAR_CDOUBLE ) && (v1_typenum != UAR_CFLOAT )))
475
+ ((v1_typenum != api. UAR_DOUBLE_ ) && (v1_typenum != api. UAR_FLOAT_ ) &&
476
+ (v1_typenum != api. UAR_CDOUBLE_ ) && (v1_typenum != api. UAR_CFLOAT_ )))
482
477
{
483
478
throw py::value_error (
484
479
" Data types of vectors must be the same. "
@@ -489,7 +484,7 @@ py::object py_dot_blocking(sycl::queue q,
489
484
const char *v2_typeless_ptr = v2.get_data ();
490
485
py::object res;
491
486
492
- if (v1_typenum == UAR_DOUBLE ) {
487
+ if (v1_typenum == api. UAR_DOUBLE_ ) {
493
488
using T = double ;
494
489
T *res_usm = sycl::malloc_device<T>(1 , q);
495
490
sycl::event dot_ev = oneapi::mkl::blas::row_major::dot (
@@ -500,7 +495,7 @@ py::object py_dot_blocking(sycl::queue q,
500
495
sycl::free (res_usm, q);
501
496
res = py::float_ (res_v);
502
497
}
503
- else if (v1_typenum == UAR_FLOAT ) {
498
+ else if (v1_typenum == api. UAR_FLOAT_ ) {
504
499
using T = float ;
505
500
T *res_usm = sycl::malloc_device<T>(1 , q);
506
501
sycl::event dot_ev = oneapi::mkl::blas::row_major::dot (
@@ -511,7 +506,7 @@ py::object py_dot_blocking(sycl::queue q,
511
506
sycl::free (res_usm, q);
512
507
res = py::float_ (res_v);
513
508
}
514
- else if (v1_typenum == UAR_CDOUBLE ) {
509
+ else if (v1_typenum == api. UAR_CDOUBLE_ ) {
515
510
using T = std::complex<double >;
516
511
T *res_usm = sycl::malloc_device<T>(1 , q);
517
512
sycl::event dotc_ev = oneapi::mkl::blas::row_major::dotc (
@@ -522,7 +517,7 @@ py::object py_dot_blocking(sycl::queue q,
522
517
sycl::free (res_usm, q);
523
518
res = py::cast (res_v);
524
519
}
525
- else if (v1_typenum == UAR_CFLOAT ) {
520
+ else if (v1_typenum == api. UAR_CFLOAT_ ) {
526
521
using T = std::complex<float >;
527
522
T *res_usm = sycl::malloc_device<T>(1 , q);
528
523
sycl::event dotc_ev = oneapi::mkl::blas::row_major::dotc (
@@ -563,9 +558,8 @@ int py_cg_solve(sycl::queue exec_q,
563
558
" Dimensions of the matrix and vectors are not consistent." );
564
559
}
565
560
566
- bool all_contig = (Amat.get_flags () & USM_ARRAY_C_CONTIGUOUS) &&
567
- (bvec.get_flags () & USM_ARRAY_C_CONTIGUOUS) &&
568
- (xvec.get_flags () & USM_ARRAY_C_CONTIGUOUS);
561
+ bool all_contig = (Amat.is_c_contiguous ()) && (bvec.is_c_contiguous ()) &&
562
+ (xvec.is_c_contiguous ());
569
563
if (!all_contig) {
570
564
throw py::value_error (" All inputs must be C-contiguous" );
571
565
}
@@ -578,19 +572,20 @@ int py_cg_solve(sycl::queue exec_q,
578
572
throw py::value_error (" All arrays must have the same type" );
579
573
}
580
574
581
- if (exec_q.get_context () != Amat.get_queue ().get_context () ||
582
- exec_q.get_context () != bvec.get_queue ().get_context () ||
583
- exec_q.get_context () != xvec.get_queue ().get_context ())
575
+ if (!dpctl::utils::queues_are_compatible (
576
+ exec_q, {Amat.get_queue (), bvec.get_queue (), xvec.get_queue ()}))
584
577
{
585
578
throw std::runtime_error (
586
- " USM allocations are not bound to context in execution queue" );
579
+ " USM allocation queues are not the same as the execution queue" );
587
580
}
588
581
589
582
const char *A_ch = Amat.get_data ();
590
583
const char *b_ch = bvec.get_data ();
591
584
char *x_ch = xvec.get_data ();
592
585
593
- if (A_typenum == UAR_DOUBLE) {
586
+ auto &api = dpctl::detail::dpctl_capi::get ();
587
+
588
+ if (A_typenum == api.UAR_DOUBLE_ ) {
594
589
using T = double ;
595
590
int iters = cg_solver::cg_solve<T>(
596
591
exec_q, n0, reinterpret_cast <const T *>(A_ch),
@@ -599,7 +594,7 @@ int py_cg_solve(sycl::queue exec_q,
599
594
600
595
return iters;
601
596
}
602
- else if (A_typenum == UAR_FLOAT ) {
597
+ else if (A_typenum == api. UAR_FLOAT_ ) {
603
598
using T = float ;
604
599
int iters = cg_solver::cg_solve<T>(
605
600
exec_q, n0, reinterpret_cast <const T *>(A_ch),
@@ -616,9 +611,6 @@ int py_cg_solve(sycl::queue exec_q,
616
611
617
612
PYBIND11_MODULE (_onemkl, m)
618
613
{
619
- // Import the dpctl extensions
620
- import_dpctl ();
621
-
622
614
m.def (" gemv" , &py_gemv, " Uses oneMKL to compute dot(matrix, vector)" ,
623
615
py::arg (" exec_queue" ), py::arg (" Amatrix" ), py::arg (" xvec" ),
624
616
py::arg (" resvec" ), py::arg (" depends" ) = py::list ());
0 commit comments