Skip to content

Commit a496d53

Browse files
Introduced dpctl::tensor::usm_ndarray::is_c_contiguous and is_f_contiguous methods and used them
1 parent db68b36 commit a496d53

File tree

2 files changed

+43
-47
lines changed

2 files changed

+43
-47
lines changed

dpctl/apis/include/dpctl4pybind11.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -563,6 +563,18 @@ class usm_ndarray : public py::object
563563

564564
return UsmNDArray_GetElementSize(raw_ar);
565565
}
566+
567+
/// Reports whether the USM_ARRAY_C_CONTIGUOUS bit is set in the value
/// returned by get_flags() for this array.
/// @return true if the bit is set, false otherwise.
bool is_c_contiguous() const
568+
{
569+
int flags = this->get_flags();
570+
return static_cast<bool>(flags & USM_ARRAY_C_CONTIGUOUS);
571+
}
572+
573+
/// Reports whether the USM_ARRAY_F_CONTIGUOUS bit is set in the value
/// returned by get_flags() for this array.
/// @return true if the bit is set, false otherwise.
bool is_f_contiguous() const
574+
{
575+
int flags = this->get_flags();
576+
return static_cast<bool>(flags & USM_ARRAY_F_CONTIGUOUS);
577+
}
566578
};
567579

568580
} // end namespace tensor

dpctl/tensor/libtensor/source/tensor_py.cpp

Lines changed: 31 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -322,15 +322,16 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src,
322322
throw py::value_error("Arrays index overlapping segments of memory");
323323
}
324324

325-
int src_flags = src.get_flags();
326-
int dst_flags = dst.get_flags();
325+
bool is_src_c_contig = src.is_c_contiguous();
326+
bool is_src_f_contig = src.is_f_contiguous();
327+
328+
bool is_dst_c_contig = dst.is_c_contiguous();
329+
bool is_dst_f_contig = dst.is_f_contiguous();
327330

328331
// check for applicability of special cases:
329332
// (same type && (both C-contiguous || both F-contiguous)
330-
bool both_c_contig = ((src_flags & USM_ARRAY_C_CONTIGUOUS) &&
331-
(dst_flags & USM_ARRAY_C_CONTIGUOUS));
332-
bool both_f_contig = ((src_flags & USM_ARRAY_F_CONTIGUOUS) &&
333-
(dst_flags & USM_ARRAY_F_CONTIGUOUS));
333+
bool both_c_contig = (is_src_c_contig && is_dst_c_contig);
334+
bool both_f_contig = (is_src_f_contig && is_dst_f_contig);
334335
if (both_c_contig || both_f_contig) {
335336
if (src_type_id == dst_type_id) {
336337

@@ -360,12 +361,6 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src,
360361
int nd = src_nd;
361362
const py::ssize_t *shape = src_shape;
362363

363-
bool is_src_c_contig = ((src_flags & USM_ARRAY_C_CONTIGUOUS) != 0);
364-
bool is_src_f_contig = ((src_flags & USM_ARRAY_F_CONTIGUOUS) != 0);
365-
366-
bool is_dst_c_contig = ((dst_flags & USM_ARRAY_C_CONTIGUOUS) != 0);
367-
bool is_dst_f_contig = ((dst_flags & USM_ARRAY_F_CONTIGUOUS) != 0);
368-
369364
constexpr py::ssize_t src_itemsize = 1; // in elements
370365
constexpr py::ssize_t dst_itemsize = 1; // in elements
371366

@@ -576,14 +571,13 @@ copy_usm_ndarray_for_reshape(dpctl::tensor::usm_ndarray src,
576571

577572
const py::ssize_t *src_strides = src.get_strides_raw();
578573
if (src_strides == nullptr) {
579-
int src_flags = src.get_flags();
580-
if (src_flags & USM_ARRAY_C_CONTIGUOUS) {
574+
if (src.is_c_contiguous()) {
581575
const auto &src_contig_strides =
582576
c_contiguous_strides(src_nd, src_shape);
583577
std::copy(src_contig_strides.begin(), src_contig_strides.end(),
584578
packed_host_shapes_strides_shp->begin() + src_nd);
585579
}
586-
else if (src_flags & USM_ARRAY_F_CONTIGUOUS) {
580+
else if (src.is_f_contiguous()) {
587581
const auto &src_contig_strides =
588582
f_contiguous_strides(src_nd, src_shape);
589583
std::copy(src_contig_strides.begin(), src_contig_strides.end(),
@@ -602,15 +596,14 @@ copy_usm_ndarray_for_reshape(dpctl::tensor::usm_ndarray src,
602596

603597
const py::ssize_t *dst_strides = dst.get_strides_raw();
604598
if (dst_strides == nullptr) {
605-
int dst_flags = dst.get_flags();
606-
if (dst_flags & USM_ARRAY_C_CONTIGUOUS) {
599+
if (dst.is_c_contiguous()) {
607600
const auto &dst_contig_strides =
608601
c_contiguous_strides(dst_nd, dst_shape);
609602
std::copy(dst_contig_strides.begin(), dst_contig_strides.end(),
610603
packed_host_shapes_strides_shp->begin() + 2 * src_nd +
611604
dst_nd);
612605
}
613-
else if (dst_flags & USM_ARRAY_F_CONTIGUOUS) {
606+
else if (dst.is_f_contiguous()) {
614607
const auto &dst_contig_strides =
615608
f_contiguous_strides(dst_nd, dst_shape);
616609
std::copy(dst_contig_strides.begin(), dst_contig_strides.end(),
@@ -744,14 +737,13 @@ void copy_numpy_ndarray_into_usm_ndarray(
744737
char *dst_data = dst.get_data();
745738

746739
int src_flags = npy_src.flags();
747-
int dst_flags = dst.get_flags();
748740

749741
// check for applicability of special cases:
750742
// (same type && (both C-contiguous || both F-contiguous)
751-
bool both_c_contig = ((src_flags & py::array::c_style) &&
752-
(dst_flags & USM_ARRAY_C_CONTIGUOUS));
753-
bool both_f_contig = ((src_flags & py::array::f_style) &&
754-
(dst_flags & USM_ARRAY_F_CONTIGUOUS));
743+
bool both_c_contig =
744+
((src_flags & py::array::c_style) && dst.is_c_contiguous());
745+
bool both_f_contig =
746+
((src_flags & py::array::f_style) && dst.is_f_contiguous());
755747
if (both_c_contig || both_f_contig) {
756748
if (src_type_id == dst_type_id) {
757749
int src_elem_size = npy_src.itemsize();
@@ -791,8 +783,8 @@ void copy_numpy_ndarray_into_usm_ndarray(
791783
bool is_src_c_contig = ((src_flags & py::array::c_style) != 0);
792784
bool is_src_f_contig = ((src_flags & py::array::f_style) != 0);
793785

794-
bool is_dst_c_contig = ((dst_flags & USM_ARRAY_C_CONTIGUOUS) != 0);
795-
bool is_dst_f_contig = ((dst_flags & USM_ARRAY_F_CONTIGUOUS) != 0);
786+
bool is_dst_c_contig = dst.is_c_contiguous();
787+
bool is_dst_f_contig = dst.is_f_contiguous();
796788

797789
// all args except itemsizes and is_?_contig bools can be modified by
798790
// reference
@@ -906,16 +898,15 @@ usm_ndarray_linear_sequence_step(py::object start,
906898
"usm_ndarray_linspace: Expecting 1D array to populate");
907899
}
908900

909-
int flags = dst.get_flags();
910-
if (!(flags & USM_ARRAY_C_CONTIGUOUS)) {
901+
if (!dst.is_c_contiguous()) {
911902
throw py::value_error(
912903
"usm_ndarray_linspace: Non-contiguous arrays are not supported");
913904
}
914905

915906
sycl::queue dst_q = dst.get_queue();
916-
if (dst_q != exec_q && dst_q.get_context() != exec_q.get_context()) {
907+
if (!dpctl::utils::queues_are_compatible(exec_q, {dst_q})) {
917908
throw py::value_error(
918-
"Execution queue context is not the same as allocation context");
909+
"Execution queue is not compatible with the allocation queue");
919910
}
920911

921912
int dst_typenum = dst.get_typenum();
@@ -955,14 +946,13 @@ usm_ndarray_linear_sequence_affine(py::object start,
955946
"usm_ndarray_linspace: Expecting 1D array to populate");
956947
}
957948

958-
int flags = dst.get_flags();
959-
if (!(flags & USM_ARRAY_C_CONTIGUOUS)) {
949+
if (!dst.is_c_contiguous()) {
960950
throw py::value_error(
961951
"usm_ndarray_linspace: Non-contiguous arrays are not supported");
962952
}
963953

964954
sycl::queue dst_q = dst.get_queue();
965-
if (dst_q != exec_q && dst_q.get_context() != exec_q.get_context()) {
955+
if (!dpctl::utils::queues_are_compatible(exec_q, {dst_q})) {
966956
throw py::value_error(
967957
"Execution queue context is not the same as allocation context");
968958
}
@@ -1010,12 +1000,10 @@ usm_ndarray_full(py::object py_value,
10101000
return std::make_pair(sycl::event(), sycl::event());
10111001
}
10121002

1013-
int dst_flags = dst.get_flags();
1014-
10151003
sycl::queue dst_q = dst.get_queue();
1016-
if (dst_q != exec_q && dst_q.get_context() != exec_q.get_context()) {
1004+
if (!dpctl::utils::queues_are_compatible(exec_q, {dst_q})) {
10171005
throw py::value_error(
1018-
"Execution queue context is not the same as allocation context");
1006+
"Execution queue is not compatible with the allocation queue");
10191007
}
10201008

10211009
int dst_typenum = dst.get_typenum();
@@ -1024,9 +1012,7 @@ usm_ndarray_full(py::object py_value,
10241012
char *dst_data = dst.get_data();
10251013
sycl::event full_event;
10261014

1027-
if (dst_nelems == 1 || (dst_flags & USM_ARRAY_C_CONTIGUOUS) ||
1028-
(dst_flags & USM_ARRAY_F_CONTIGUOUS))
1029-
{
1015+
if (dst_nelems == 1 || dst.is_c_contiguous() || dst.is_f_contiguous()) {
10301016
auto fn = full_contig_dispatch_vector[dst_typeid];
10311017

10321018
sycl::event full_contig_event =
@@ -1079,8 +1065,8 @@ eye(py::ssize_t k,
10791065
return std::make_pair(sycl::event{}, sycl::event{});
10801066
}
10811067

1082-
bool is_dst_c_contig = ((dst.get_flags() & USM_ARRAY_C_CONTIGUOUS) != 0);
1083-
bool is_dst_f_contig = ((dst.get_flags() & USM_ARRAY_F_CONTIGUOUS) != 0);
1068+
bool is_dst_c_contig = dst.is_c_contiguous();
1069+
bool is_dst_f_contig = dst.is_f_contiguous();
10841070
if (!is_dst_c_contig && !is_dst_f_contig) {
10851071
throw py::value_error("USM array is not contiguous");
10861072
}
@@ -1203,9 +1189,8 @@ tri(sycl::queue &exec_q,
12031189
using shT = std::vector<py::ssize_t>;
12041190
shT src_strides(src_nd);
12051191

1206-
int src_flags = src.get_flags();
1207-
bool is_src_c_contig = ((src_flags & USM_ARRAY_C_CONTIGUOUS) != 0);
1208-
bool is_src_f_contig = ((src_flags & USM_ARRAY_F_CONTIGUOUS) != 0);
1192+
bool is_src_c_contig = src.is_c_contiguous();
1193+
bool is_src_f_contig = src.is_f_contiguous();
12091194

12101195
const py::ssize_t *src_strides_raw = src.get_strides_raw();
12111196
if (src_strides_raw == nullptr) {
@@ -1227,9 +1212,8 @@ tri(sycl::queue &exec_q,
12271212

12281213
shT dst_strides(src_nd);
12291214

1230-
int dst_flags = dst.get_flags();
1231-
bool is_dst_c_contig = ((dst_flags & USM_ARRAY_C_CONTIGUOUS) != 0);
1232-
bool is_dst_f_contig = ((dst_flags & USM_ARRAY_F_CONTIGUOUS) != 0);
1215+
bool is_dst_c_contig = dst.is_c_contiguous();
1216+
bool is_dst_f_contig = dst.is_f_contiguous();
12331217

12341218
const py::ssize_t *dst_strides_raw = dst.get_strides_raw();
12351219
if (dst_strides_raw == nullptr) {

0 commit comments

Comments
 (0)