Skip to content

Commit caded96

Browse files
Updated extension for changes in dpctl4pybind11.hpp
1 parent 6f44d42 commit caded96

File tree

1 file changed

+65
-73
lines changed

1 file changed

+65
-73
lines changed

examples/pybind11/onemkl_gemv/sycl_gemm/_onemkl.cpp

Lines changed: 65 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,11 @@ py_gemv(sycl::queue q,
7171
"USM allocation is not bound to the context in execution queue.");
7272
}
7373

74-
int mat_flags = matrix.get_flags();
75-
int v_flags = vector.get_flags();
76-
int r_flags = result.get_flags();
74+
auto &api = dpctl::detail::dpctl_capi::get();
7775

78-
if (!((mat_flags & (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS)) &&
79-
(v_flags & (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS)) &&
80-
(r_flags & (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS))))
76+
if (!((matrix.is_c_contiguous()) &&
77+
(vector.is_c_contiguous() || vector.is_f_contiguous()) &&
78+
(result.is_c_contiguous() || result.is_f_contiguous())))
8179
{
8280
throw std::runtime_error("Arrays must be contiguous.");
8381
}
@@ -87,8 +85,8 @@ py_gemv(sycl::queue q,
8785
int r_typenum = result.get_typenum();
8886

8987
if ((mat_typenum != v_typenum) || (r_typenum != v_typenum) ||
90-
!((v_typenum == UAR_DOUBLE) || (v_typenum == UAR_FLOAT) ||
91-
(v_typenum == UAR_CDOUBLE) || (v_typenum == UAR_CFLOAT)))
88+
!((v_typenum == api.UAR_DOUBLE_) || (v_typenum == api.UAR_FLOAT_) ||
89+
(v_typenum == api.UAR_CDOUBLE_) || (v_typenum == api.UAR_CFLOAT_)))
9290
{
9391
std::cout << "Found: [" << mat_typenum << ", " << v_typenum << ", "
9492
<< r_typenum << "]" << std::endl;
@@ -103,7 +101,7 @@ py_gemv(sycl::queue q,
103101
char *r_typeless_ptr = result.get_data();
104102

105103
sycl::event res_ev;
106-
if (v_typenum == UAR_DOUBLE) {
104+
if (v_typenum == api.UAR_DOUBLE_) {
107105
using T = double;
108106
sycl::event gemv_ev = oneapi::mkl::blas::row_major::gemv(
109107
q, oneapi::mkl::transpose::nontrans, n, m, T(1),
@@ -112,7 +110,7 @@ py_gemv(sycl::queue q,
112110
reinterpret_cast<T *>(r_typeless_ptr), 1, depends);
113111
res_ev = gemv_ev;
114112
}
115-
else if (v_typenum == UAR_FLOAT) {
113+
else if (v_typenum == api.UAR_FLOAT_) {
116114
using T = float;
117115
sycl::event gemv_ev = oneapi::mkl::blas::row_major::gemv(
118116
q, oneapi::mkl::transpose::nontrans, n, m, T(1),
@@ -121,7 +119,7 @@ py_gemv(sycl::queue q,
121119
reinterpret_cast<T *>(r_typeless_ptr), 1, depends);
122120
res_ev = gemv_ev;
123121
}
124-
else if (v_typenum == UAR_CDOUBLE) {
122+
else if (v_typenum == api.UAR_CDOUBLE_) {
125123
using T = std::complex<double>;
126124
sycl::event gemv_ev = oneapi::mkl::blas::row_major::gemv(
127125
q, oneapi::mkl::transpose::nontrans, n, m, T(1),
@@ -130,7 +128,7 @@ py_gemv(sycl::queue q,
130128
reinterpret_cast<T *>(r_typeless_ptr), 1, depends);
131129
res_ev = gemv_ev;
132130
}
133-
else if (v_typenum == UAR_CFLOAT) {
131+
else if (v_typenum == api.UAR_CFLOAT_) {
134132
using T = std::complex<float>;
135133
sycl::event gemv_ev = oneapi::mkl::blas::row_major::gemv(
136134
q, oneapi::mkl::transpose::nontrans, n, m, T(1),
@@ -185,21 +183,18 @@ py_sub(sycl::queue q,
185183
throw std::runtime_error("Vectors must have the same length");
186184
}
187185

188-
if (q.get_context() != in_v1.get_queue().get_context() ||
189-
q.get_context() != in_v2.get_queue().get_context() ||
190-
q.get_context() != out_r.get_queue().get_context())
186+
if (!dpctl::utils::queues_are_compatible(
187+
q, {in_v1.get_queue(), in_v2.get_queue(), out_r.get_queue()}))
191188
{
192189
throw std::runtime_error(
193190
"USM allocation is not bound to the context in execution queue");
194191
}
195192

196-
int in_v1_flags = in_v1.get_flags();
197-
int in_v2_flags = in_v2.get_flags();
198-
int out_r_flags = out_r.get_flags();
193+
auto &api = dpctl::detail::dpctl_capi::get();
199194

200-
if (!((in_v1_flags & (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS)) &&
201-
(in_v2_flags & (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS)) &&
202-
(out_r_flags & (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS))))
195+
if (!((in_v1.is_c_contiguous() || in_v1.is_f_contiguous()) &&
196+
(in_v2.is_c_contiguous() || in_v2.is_f_contiguous()) &&
197+
(out_r.is_c_contiguous() || out_r.is_f_contiguous())))
203198
{
204199
throw std::runtime_error("Vectors must be contiguous.");
205200
}
@@ -209,8 +204,10 @@ py_sub(sycl::queue q,
209204
int out_r_typenum = out_r.get_typenum();
210205

211206
if ((in_v2_typenum != in_v1_typenum) || (out_r_typenum != in_v1_typenum) ||
212-
!((in_v1_typenum == UAR_DOUBLE) || (in_v1_typenum == UAR_FLOAT) ||
213-
(in_v1_typenum == UAR_CDOUBLE) || (in_v1_typenum == UAR_CFLOAT)))
207+
!((in_v1_typenum == api.UAR_DOUBLE_) ||
208+
(in_v1_typenum == api.UAR_FLOAT_) ||
209+
(in_v1_typenum == api.UAR_CDOUBLE_) ||
210+
(in_v1_typenum == api.UAR_CFLOAT_)))
214211
{
215212
throw std::runtime_error(
216213
"Only real and complex floating point arrays are supported.");
@@ -221,22 +218,22 @@ py_sub(sycl::queue q,
221218
char *out_r_typeless_ptr = out_r.get_data();
222219

223220
sycl::event res_ev;
224-
if (out_r_typenum == UAR_DOUBLE) {
221+
if (out_r_typenum == api.UAR_DOUBLE_) {
225222
using T = double;
226223
res_ev = sub_impl<T>(q, n, in_v1_typeless_ptr, in_v2_typeless_ptr,
227224
out_r_typeless_ptr, depends);
228225
}
229-
else if (out_r_typenum == UAR_FLOAT) {
226+
else if (out_r_typenum == api.UAR_FLOAT_) {
230227
using T = float;
231228
res_ev = sub_impl<T>(q, n, in_v1_typeless_ptr, in_v2_typeless_ptr,
232229
out_r_typeless_ptr, depends);
233230
}
234-
else if (out_r_typenum == UAR_CDOUBLE) {
231+
else if (out_r_typenum == api.UAR_CDOUBLE_) {
235232
using T = std::complex<double>;
236233
res_ev = sub_impl<T>(q, n, in_v1_typeless_ptr, in_v2_typeless_ptr,
237234
out_r_typeless_ptr, depends);
238235
}
239-
else if (out_r_typenum == UAR_CFLOAT) {
236+
else if (out_r_typenum == api.UAR_CFLOAT_) {
240237
using T = std::complex<float>;
241238
res_ev = sub_impl<T>(q, n, in_v1_typeless_ptr, in_v2_typeless_ptr,
242239
out_r_typeless_ptr, depends);
@@ -294,18 +291,15 @@ py_axpby_inplace(sycl::queue q,
294291
throw std::runtime_error("Vectors must have the same length");
295292
}
296293

297-
if (q.get_context() != x.get_queue().get_context() ||
298-
q.get_context() != y.get_queue().get_context())
294+
if (!dpctl::utils::queues_are_compatible(q, {x.get_queue(), y.get_queue()}))
299295
{
300296
throw std::runtime_error(
301297
"USM allocation is not bound to the context in execution queue");
302298
}
299+
auto &api = dpctl::detail::dpctl_capi::get();
303300

304-
int x_flags = x.get_flags();
305-
int y_flags = y.get_flags();
306-
307-
if (!((x_flags & (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS)) &&
308-
(y_flags & (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS))))
301+
if (!((x.is_c_contiguous() || x.is_f_contiguous()) &&
302+
(y.is_c_contiguous() || y.is_f_contiguous())))
309303
{
310304
throw std::runtime_error("Vectors must be contiguous.");
311305
}
@@ -314,8 +308,8 @@ py_axpby_inplace(sycl::queue q,
314308
int y_typenum = y.get_typenum();
315309

316310
if ((x_typenum != y_typenum) ||
317-
!((x_typenum == UAR_DOUBLE) || (x_typenum == UAR_FLOAT) ||
318-
(x_typenum == UAR_CDOUBLE) || (x_typenum == UAR_CFLOAT)))
311+
!((x_typenum == api.UAR_DOUBLE_) || (x_typenum == api.UAR_FLOAT_) ||
312+
(x_typenum == api.UAR_CDOUBLE_) || (x_typenum == api.UAR_CFLOAT_)))
319313
{
320314
throw std::runtime_error(
321315
"Only real and complex floating point arrays are supported.");
@@ -325,22 +319,22 @@ py_axpby_inplace(sycl::queue q,
325319
char *y_typeless_ptr = y.get_data();
326320

327321
sycl::event res_ev;
328-
if (x_typenum == UAR_DOUBLE) {
322+
if (x_typenum == api.UAR_DOUBLE_) {
329323
using T = double;
330324
res_ev = axpby_inplace_impl<T>(q, n, a, x_typeless_ptr, b,
331325
y_typeless_ptr, depends);
332326
}
333-
else if (x_typenum == UAR_FLOAT) {
327+
else if (x_typenum == api.UAR_FLOAT_) {
334328
using T = float;
335329
res_ev = axpby_inplace_impl<T>(q, n, a, x_typeless_ptr, b,
336330
y_typeless_ptr, depends);
337331
}
338-
else if (x_typenum == UAR_CDOUBLE) {
332+
else if (x_typenum == api.UAR_CDOUBLE_) {
339333
using T = std::complex<double>;
340334
res_ev = axpby_inplace_impl<T>(q, n, a, x_typeless_ptr, b,
341335
y_typeless_ptr, depends);
342336
}
343-
else if (x_typenum == UAR_CFLOAT) {
337+
else if (x_typenum == api.UAR_CFLOAT_) {
344338
using T = std::complex<float>;
345339
res_ev = axpby_inplace_impl<T>(q, n, a, x_typeless_ptr, b,
346340
y_typeless_ptr, depends);
@@ -393,18 +387,20 @@ py::object py_norm_squared_blocking(sycl::queue q,
393387

394388
int r_flags = r.get_flags();
395389

396-
if (!(r_flags & (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS))) {
390+
if (!(r.is_c_contiguous() || r.is_f_contiguous())) {
397391
throw std::runtime_error("Vector must be contiguous.");
398392
}
399393

400-
if (q.get_context() != r.get_queue().get_context()) {
394+
if (!dpctl::utils::queues_are_compatible(q, {r.get_queue()})) {
401395
throw std::runtime_error(
402396
"USM allocation is not bound to the context in execution queue");
403397
}
404398

399+
auto &api = dpctl::detail::dpctl_capi::get();
400+
405401
int r_typenum = r.get_typenum();
406-
if ((r_typenum != UAR_DOUBLE) && (r_typenum != UAR_FLOAT) &&
407-
(r_typenum != UAR_CDOUBLE) && (r_typenum != UAR_CFLOAT))
402+
if ((r_typenum != api.UAR_DOUBLE_) && (r_typenum != api.UAR_FLOAT_) &&
403+
(r_typenum != api.UAR_CDOUBLE_) && (r_typenum != api.UAR_CFLOAT_))
408404
{
409405
throw std::runtime_error(
410406
"Only real and complex floating point arrays are supported.");
@@ -413,23 +409,23 @@ py::object py_norm_squared_blocking(sycl::queue q,
413409
const char *r_typeless_ptr = r.get_data();
414410
py::object res;
415411

416-
if (r_typenum == UAR_DOUBLE) {
412+
if (r_typenum == api.UAR_DOUBLE_) {
417413
using T = double;
418414
T n_sq = norm_squared_blocking_impl<T>(q, n, r_typeless_ptr, depends);
419415
res = py::float_(n_sq);
420416
}
421-
else if (r_typenum == UAR_FLOAT) {
417+
else if (r_typenum == api.UAR_FLOAT_) {
422418
using T = float;
423419
T n_sq = norm_squared_blocking_impl<T>(q, n, r_typeless_ptr, depends);
424420
res = py::float_(n_sq);
425421
}
426-
else if (r_typenum == UAR_CDOUBLE) {
422+
else if (r_typenum == api.UAR_CDOUBLE_) {
427423
using T = std::complex<double>;
428424
double n_sq = complex_norm_squared_blocking_impl<double>(
429425
q, n, r_typeless_ptr, depends);
430426
res = py::float_(n_sq);
431427
}
432-
else if (r_typenum == UAR_CFLOAT) {
428+
else if (r_typenum == api.UAR_CFLOAT_) {
433429
using T = std::complex<float>;
434430
float n_sq = complex_norm_squared_blocking_impl<float>(
435431
q, n, r_typeless_ptr, depends);
@@ -457,28 +453,27 @@ py::object py_dot_blocking(sycl::queue q,
457453
throw std::runtime_error("Length of vectors are not the same");
458454
}
459455

460-
int v1_flags = v1.get_flags();
461-
int v2_flags = v2.get_flags();
462-
463-
if (!(v1_flags & (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS)) ||
464-
!(v2_flags & (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS)))
456+
if (!(v1.is_c_contiguous() || v1.is_f_contiguous()) ||
457+
!(v2.is_c_contiguous() || v2.is_f_contiguous()))
465458
{
466459
throw std::runtime_error("Vectors must be contiguous.");
467460
}
468461

469-
if (q.get_context() != v1.get_queue().get_context() ||
470-
q.get_context() != v2.get_queue().get_context())
462+
if (!dpctl::utils::queues_are_compatible(q,
463+
{v1.get_queue(), v2.get_queue()}))
471464
{
472465
throw std::runtime_error(
473466
"USM allocation is not bound to the context in execution queue");
474467
}
475468

469+
auto &api = dpctl::detail::dpctl_capi::get();
470+
476471
int v1_typenum = v1.get_typenum();
477472
int v2_typenum = v2.get_typenum();
478473

479474
if ((v1_typenum != v2_typenum) ||
480-
((v1_typenum != UAR_DOUBLE) && (v1_typenum != UAR_FLOAT) &&
481-
(v1_typenum != UAR_CDOUBLE) && (v1_typenum != UAR_CFLOAT)))
475+
((v1_typenum != api.UAR_DOUBLE_) && (v1_typenum != api.UAR_FLOAT_) &&
476+
(v1_typenum != api.UAR_CDOUBLE_) && (v1_typenum != api.UAR_CFLOAT_)))
482477
{
483478
throw py::value_error(
484479
"Data types of vectors must be the same. "
@@ -489,7 +484,7 @@ py::object py_dot_blocking(sycl::queue q,
489484
const char *v2_typeless_ptr = v2.get_data();
490485
py::object res;
491486

492-
if (v1_typenum == UAR_DOUBLE) {
487+
if (v1_typenum == api.UAR_DOUBLE_) {
493488
using T = double;
494489
T *res_usm = sycl::malloc_device<T>(1, q);
495490
sycl::event dot_ev = oneapi::mkl::blas::row_major::dot(
@@ -500,7 +495,7 @@ py::object py_dot_blocking(sycl::queue q,
500495
sycl::free(res_usm, q);
501496
res = py::float_(res_v);
502497
}
503-
else if (v1_typenum == UAR_FLOAT) {
498+
else if (v1_typenum == api.UAR_FLOAT_) {
504499
using T = float;
505500
T *res_usm = sycl::malloc_device<T>(1, q);
506501
sycl::event dot_ev = oneapi::mkl::blas::row_major::dot(
@@ -511,7 +506,7 @@ py::object py_dot_blocking(sycl::queue q,
511506
sycl::free(res_usm, q);
512507
res = py::float_(res_v);
513508
}
514-
else if (v1_typenum == UAR_CDOUBLE) {
509+
else if (v1_typenum == api.UAR_CDOUBLE_) {
515510
using T = std::complex<double>;
516511
T *res_usm = sycl::malloc_device<T>(1, q);
517512
sycl::event dotc_ev = oneapi::mkl::blas::row_major::dotc(
@@ -522,7 +517,7 @@ py::object py_dot_blocking(sycl::queue q,
522517
sycl::free(res_usm, q);
523518
res = py::cast(res_v);
524519
}
525-
else if (v1_typenum == UAR_CFLOAT) {
520+
else if (v1_typenum == api.UAR_CFLOAT_) {
526521
using T = std::complex<float>;
527522
T *res_usm = sycl::malloc_device<T>(1, q);
528523
sycl::event dotc_ev = oneapi::mkl::blas::row_major::dotc(
@@ -563,9 +558,8 @@ int py_cg_solve(sycl::queue exec_q,
563558
"Dimensions of the matrix and vectors are not consistent.");
564559
}
565560

566-
bool all_contig = (Amat.get_flags() & USM_ARRAY_C_CONTIGUOUS) &&
567-
(bvec.get_flags() & USM_ARRAY_C_CONTIGUOUS) &&
568-
(xvec.get_flags() & USM_ARRAY_C_CONTIGUOUS);
561+
bool all_contig = (Amat.is_c_contiguous()) && (bvec.is_c_contiguous()) &&
562+
(xvec.is_c_contiguous());
569563
if (!all_contig) {
570564
throw py::value_error("All inputs must be C-contiguous");
571565
}
@@ -578,19 +572,20 @@ int py_cg_solve(sycl::queue exec_q,
578572
throw py::value_error("All arrays must have the same type");
579573
}
580574

581-
if (exec_q.get_context() != Amat.get_queue().get_context() ||
582-
exec_q.get_context() != bvec.get_queue().get_context() ||
583-
exec_q.get_context() != xvec.get_queue().get_context())
575+
if (!dpctl::utils::queues_are_compatible(
576+
exec_q, {Amat.get_queue(), bvec.get_queue(), xvec.get_queue()}))
584577
{
585578
throw std::runtime_error(
586-
"USM allocations are not bound to context in execution queue");
579+
"USM allocation queues are not the same as the execution queue");
587580
}
588581

589582
const char *A_ch = Amat.get_data();
590583
const char *b_ch = bvec.get_data();
591584
char *x_ch = xvec.get_data();
592585

593-
if (A_typenum == UAR_DOUBLE) {
586+
auto &api = dpctl::detail::dpctl_capi::get();
587+
588+
if (A_typenum == api.UAR_DOUBLE_) {
594589
using T = double;
595590
int iters = cg_solver::cg_solve<T>(
596591
exec_q, n0, reinterpret_cast<const T *>(A_ch),
@@ -599,7 +594,7 @@ int py_cg_solve(sycl::queue exec_q,
599594

600595
return iters;
601596
}
602-
else if (A_typenum == UAR_FLOAT) {
597+
else if (A_typenum == api.UAR_FLOAT_) {
603598
using T = float;
604599
int iters = cg_solver::cg_solve<T>(
605600
exec_q, n0, reinterpret_cast<const T *>(A_ch),
@@ -616,9 +611,6 @@ int py_cg_solve(sycl::queue exec_q,
616611

617612
PYBIND11_MODULE(_onemkl, m)
618613
{
619-
// Import the dpctl extensions
620-
import_dpctl();
621-
622614
m.def("gemv", &py_gemv, "Uses oneMKL to compute dot(matrix, vector)",
623615
py::arg("exec_queue"), py::arg("Amatrix"), py::arg("xvec"),
624616
py::arg("resvec"), py::arg("depends") = py::list());

0 commit comments

Comments
 (0)