Skip to content

Commit 798ecca

Browse files
Merge pull request #931 from IntelPython/cleanup-tensor-step2
[cleanup/tensor, part 2] Modularized tests for contiguity, retrieval of PyUSMArrayObject* and removed use of a global variable.
2 parents db68b36 + 7d2ab88 commit 798ecca

File tree

3 files changed

+69
-73
lines changed

3 files changed

+69
-73
lines changed

dpctl/.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
*.so
2-
*.cpp
2+
_*.cpp
33
*.cxx
44
*.c
55
*.h

dpctl/apis/include/dpctl4pybind11.hpp

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -439,8 +439,7 @@ class usm_ndarray : public py::object
439439

440440
char *get_data() const
441441
{
442-
PyObject *raw_o = this->ptr();
443-
PyUSMArrayObject *raw_ar = reinterpret_cast<PyUSMArrayObject *>(raw_o);
442+
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
444443

445444
return UsmNDArray_GetData(raw_ar);
446445
}
@@ -452,16 +451,14 @@ class usm_ndarray : public py::object
452451

453452
int get_ndim() const
454453
{
455-
PyObject *raw_o = this->ptr();
456-
PyUSMArrayObject *raw_ar = reinterpret_cast<PyUSMArrayObject *>(raw_o);
454+
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
457455

458456
return UsmNDArray_GetNDim(raw_ar);
459457
}
460458

461459
const py::ssize_t *get_shape_raw() const
462460
{
463-
PyObject *raw_o = this->ptr();
464-
PyUSMArrayObject *raw_ar = reinterpret_cast<PyUSMArrayObject *>(raw_o);
461+
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
465462

466463
return UsmNDArray_GetShape(raw_ar);
467464
}
@@ -474,16 +471,14 @@ class usm_ndarray : public py::object
474471

475472
const py::ssize_t *get_strides_raw() const
476473
{
477-
PyObject *raw_o = this->ptr();
478-
PyUSMArrayObject *raw_ar = reinterpret_cast<PyUSMArrayObject *>(raw_o);
474+
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
479475

480476
return UsmNDArray_GetStrides(raw_ar);
481477
}
482478

483479
py::ssize_t get_size() const
484480
{
485-
PyObject *raw_o = this->ptr();
486-
PyUSMArrayObject *raw_ar = reinterpret_cast<PyUSMArrayObject *>(raw_o);
481+
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
487482

488483
int ndim = UsmNDArray_GetNDim(raw_ar);
489484
const py::ssize_t *shape = UsmNDArray_GetShape(raw_ar);
@@ -499,8 +494,7 @@ class usm_ndarray : public py::object
499494

500495
std::pair<py::ssize_t, py::ssize_t> get_minmax_offsets() const
501496
{
502-
PyObject *raw_o = this->ptr();
503-
PyUSMArrayObject *raw_ar = reinterpret_cast<PyUSMArrayObject *>(raw_o);
497+
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
504498

505499
int nd = UsmNDArray_GetNDim(raw_ar);
506500
const py::ssize_t *shape = UsmNDArray_GetShape(raw_ar);
@@ -533,36 +527,50 @@ class usm_ndarray : public py::object
533527

534528
sycl::queue get_queue() const
535529
{
536-
PyObject *raw_o = this->ptr();
537-
PyUSMArrayObject *raw_ar = reinterpret_cast<PyUSMArrayObject *>(raw_o);
530+
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
538531

539532
DPCTLSyclQueueRef QRef = UsmNDArray_GetQueueRef(raw_ar);
540533
return *(reinterpret_cast<sycl::queue *>(QRef));
541534
}
542535

543536
int get_typenum() const
544537
{
545-
PyObject *raw_o = this->ptr();
546-
PyUSMArrayObject *raw_ar = reinterpret_cast<PyUSMArrayObject *>(raw_o);
538+
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
547539

548540
return UsmNDArray_GetTypenum(raw_ar);
549541
}
550542

551543
int get_flags() const
552544
{
553-
PyObject *raw_o = this->ptr();
554-
PyUSMArrayObject *raw_ar = reinterpret_cast<PyUSMArrayObject *>(raw_o);
545+
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
555546

556547
return UsmNDArray_GetFlags(raw_ar);
557548
}
558549

559550
int get_elemsize() const
560551
{
561-
PyObject *raw_o = this->ptr();
562-
PyUSMArrayObject *raw_ar = reinterpret_cast<PyUSMArrayObject *>(raw_o);
552+
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
563553

564554
return UsmNDArray_GetElementSize(raw_ar);
565555
}
556+
557+
bool is_c_contiguous() const
558+
{
559+
int flags = this->get_flags();
560+
return static_cast<bool>(flags & USM_ARRAY_C_CONTIGUOUS);
561+
}
562+
563+
bool is_f_contiguous() const
564+
{
565+
int flags = this->get_flags();
566+
return static_cast<bool>(flags & USM_ARRAY_F_CONTIGUOUS);
567+
}
568+
569+
private:
570+
PyUSMArrayObject *usm_array_ptr() const
571+
{
572+
return reinterpret_cast<PyUSMArrayObject *>(m_ptr);
573+
}
566574
};
567575

568576
} // end namespace tensor

dpctl/tensor/libtensor/source/tensor_py.cpp

Lines changed: 40 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,6 @@
4242

4343
namespace py = pybind11;
4444

45-
static dpctl::tensor::detail::usm_ndarray_types array_types;
46-
4745
namespace
4846
{
4947

@@ -301,6 +299,7 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src,
301299
int src_typenum = src.get_typenum();
302300
int dst_typenum = dst.get_typenum();
303301

302+
auto array_types = dpctl::tensor::detail::usm_ndarray_types::get();
304303
int src_type_id = array_types.typenum_to_lookup_id(src_typenum);
305304
int dst_type_id = array_types.typenum_to_lookup_id(dst_typenum);
306305

@@ -322,15 +321,16 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src,
322321
throw py::value_error("Arrays index overlapping segments of memory");
323322
}
324323

325-
int src_flags = src.get_flags();
326-
int dst_flags = dst.get_flags();
324+
bool is_src_c_contig = src.is_c_contiguous();
325+
bool is_src_f_contig = src.is_f_contiguous();
326+
327+
bool is_dst_c_contig = dst.is_c_contiguous();
328+
bool is_dst_f_contig = dst.is_f_contiguous();
327329

328330
// check for applicability of special cases:
329331
// (same type && (both C-contiguous || both F-contiguous))
330-
bool both_c_contig = ((src_flags & USM_ARRAY_C_CONTIGUOUS) &&
331-
(dst_flags & USM_ARRAY_C_CONTIGUOUS));
332-
bool both_f_contig = ((src_flags & USM_ARRAY_F_CONTIGUOUS) &&
333-
(dst_flags & USM_ARRAY_F_CONTIGUOUS));
332+
bool both_c_contig = (is_src_c_contig && is_dst_c_contig);
333+
bool both_f_contig = (is_src_f_contig && is_dst_f_contig);
334334
if (both_c_contig || both_f_contig) {
335335
if (src_type_id == dst_type_id) {
336336

@@ -360,12 +360,6 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src,
360360
int nd = src_nd;
361361
const py::ssize_t *shape = src_shape;
362362

363-
bool is_src_c_contig = ((src_flags & USM_ARRAY_C_CONTIGUOUS) != 0);
364-
bool is_src_f_contig = ((src_flags & USM_ARRAY_F_CONTIGUOUS) != 0);
365-
366-
bool is_dst_c_contig = ((dst_flags & USM_ARRAY_C_CONTIGUOUS) != 0);
367-
bool is_dst_f_contig = ((dst_flags & USM_ARRAY_F_CONTIGUOUS) != 0);
368-
369363
constexpr py::ssize_t src_itemsize = 1; // in elements
370364
constexpr py::ssize_t dst_itemsize = 1; // in elements
371365

@@ -550,6 +544,7 @@ copy_usm_ndarray_for_reshape(dpctl::tensor::usm_ndarray src,
550544
const py::ssize_t *src_shape = src.get_shape_raw();
551545
const py::ssize_t *dst_shape = dst.get_shape_raw();
552546

547+
auto array_types = dpctl::tensor::detail::usm_ndarray_types::get();
553548
int type_id = array_types.typenum_to_lookup_id(src_typenum);
554549

555550
auto fn = copy_for_reshape_generic_dispatch_vector[type_id];
@@ -576,14 +571,13 @@ copy_usm_ndarray_for_reshape(dpctl::tensor::usm_ndarray src,
576571

577572
const py::ssize_t *src_strides = src.get_strides_raw();
578573
if (src_strides == nullptr) {
579-
int src_flags = src.get_flags();
580-
if (src_flags & USM_ARRAY_C_CONTIGUOUS) {
574+
if (src.is_c_contiguous()) {
581575
const auto &src_contig_strides =
582576
c_contiguous_strides(src_nd, src_shape);
583577
std::copy(src_contig_strides.begin(), src_contig_strides.end(),
584578
packed_host_shapes_strides_shp->begin() + src_nd);
585579
}
586-
else if (src_flags & USM_ARRAY_F_CONTIGUOUS) {
580+
else if (src.is_f_contiguous()) {
587581
const auto &src_contig_strides =
588582
f_contiguous_strides(src_nd, src_shape);
589583
std::copy(src_contig_strides.begin(), src_contig_strides.end(),
@@ -602,15 +596,14 @@ copy_usm_ndarray_for_reshape(dpctl::tensor::usm_ndarray src,
602596

603597
const py::ssize_t *dst_strides = dst.get_strides_raw();
604598
if (dst_strides == nullptr) {
605-
int dst_flags = dst.get_flags();
606-
if (dst_flags & USM_ARRAY_C_CONTIGUOUS) {
599+
if (dst.is_c_contiguous()) {
607600
const auto &dst_contig_strides =
608601
c_contiguous_strides(dst_nd, dst_shape);
609602
std::copy(dst_contig_strides.begin(), dst_contig_strides.end(),
610603
packed_host_shapes_strides_shp->begin() + 2 * src_nd +
611604
dst_nd);
612605
}
613-
else if (dst_flags & USM_ARRAY_F_CONTIGUOUS) {
606+
else if (dst.is_f_contiguous()) {
614607
const auto &dst_contig_strides =
615608
f_contiguous_strides(dst_nd, dst_shape);
616609
std::copy(dst_contig_strides.begin(), dst_contig_strides.end(),
@@ -736,6 +729,7 @@ void copy_numpy_ndarray_into_usm_ndarray(
736729
py::detail::array_descriptor_proxy(npy_src.dtype().ptr())->type_num;
737730
int dst_typenum = dst.get_typenum();
738731

732+
auto array_types = dpctl::tensor::detail::usm_ndarray_types::get();
739733
int src_type_id = array_types.typenum_to_lookup_id(src_typenum);
740734
int dst_type_id = array_types.typenum_to_lookup_id(dst_typenum);
741735

@@ -744,14 +738,13 @@ void copy_numpy_ndarray_into_usm_ndarray(
744738
char *dst_data = dst.get_data();
745739

746740
int src_flags = npy_src.flags();
747-
int dst_flags = dst.get_flags();
748741

749742
// check for applicability of special cases:
750743
// (same type && (both C-contiguous || both F-contiguous))
751-
bool both_c_contig = ((src_flags & py::array::c_style) &&
752-
(dst_flags & USM_ARRAY_C_CONTIGUOUS));
753-
bool both_f_contig = ((src_flags & py::array::f_style) &&
754-
(dst_flags & USM_ARRAY_F_CONTIGUOUS));
744+
bool both_c_contig =
745+
((src_flags & py::array::c_style) && dst.is_c_contiguous());
746+
bool both_f_contig =
747+
((src_flags & py::array::f_style) && dst.is_f_contiguous());
755748
if (both_c_contig || both_f_contig) {
756749
if (src_type_id == dst_type_id) {
757750
int src_elem_size = npy_src.itemsize();
@@ -791,8 +784,8 @@ void copy_numpy_ndarray_into_usm_ndarray(
791784
bool is_src_c_contig = ((src_flags & py::array::c_style) != 0);
792785
bool is_src_f_contig = ((src_flags & py::array::f_style) != 0);
793786

794-
bool is_dst_c_contig = ((dst_flags & USM_ARRAY_C_CONTIGUOUS) != 0);
795-
bool is_dst_f_contig = ((dst_flags & USM_ARRAY_F_CONTIGUOUS) != 0);
787+
bool is_dst_c_contig = dst.is_c_contiguous();
788+
bool is_dst_f_contig = dst.is_f_contiguous();
796789

797790
// all args except itemsizes and is_?_contig bools can be modified by
798791
// reference
@@ -906,18 +899,18 @@ usm_ndarray_linear_sequence_step(py::object start,
906899
"usm_ndarray_linspace: Expecting 1D array to populate");
907900
}
908901

909-
int flags = dst.get_flags();
910-
if (!(flags & USM_ARRAY_C_CONTIGUOUS)) {
902+
if (!dst.is_c_contiguous()) {
911903
throw py::value_error(
912904
"usm_ndarray_linspace: Non-contiguous arrays are not supported");
913905
}
914906

915907
sycl::queue dst_q = dst.get_queue();
916-
if (dst_q != exec_q && dst_q.get_context() != exec_q.get_context()) {
908+
if (!dpctl::utils::queues_are_compatible(exec_q, {dst_q})) {
917909
throw py::value_error(
918-
"Execution queue context is not the same as allocation context");
910+
"Execution queue is not compatible with the allocation queue");
919911
}
920912

913+
auto array_types = dpctl::tensor::detail::usm_ndarray_types::get();
921914
int dst_typenum = dst.get_typenum();
922915
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
923916

@@ -955,18 +948,18 @@ usm_ndarray_linear_sequence_affine(py::object start,
955948
"usm_ndarray_linspace: Expecting 1D array to populate");
956949
}
957950

958-
int flags = dst.get_flags();
959-
if (!(flags & USM_ARRAY_C_CONTIGUOUS)) {
951+
if (!dst.is_c_contiguous()) {
960952
throw py::value_error(
961953
"usm_ndarray_linspace: Non-contiguous arrays are not supported");
962954
}
963955

964956
sycl::queue dst_q = dst.get_queue();
965-
if (dst_q != exec_q && dst_q.get_context() != exec_q.get_context()) {
957+
if (!dpctl::utils::queues_are_compatible(exec_q, {dst_q})) {
966958
throw py::value_error(
967959
"Execution queue context is not the same as allocation context");
968960
}
969961

962+
auto array_types = dpctl::tensor::detail::usm_ndarray_types::get();
970963
int dst_typenum = dst.get_typenum();
971964
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
972965

@@ -1010,23 +1003,20 @@ usm_ndarray_full(py::object py_value,
10101003
return std::make_pair(sycl::event(), sycl::event());
10111004
}
10121005

1013-
int dst_flags = dst.get_flags();
1014-
10151006
sycl::queue dst_q = dst.get_queue();
1016-
if (dst_q != exec_q && dst_q.get_context() != exec_q.get_context()) {
1007+
if (!dpctl::utils::queues_are_compatible(exec_q, {dst_q})) {
10171008
throw py::value_error(
1018-
"Execution queue context is not the same as allocation context");
1009+
"Execution queue is not compatible with the allocation queue");
10191010
}
10201011

1012+
auto array_types = dpctl::tensor::detail::usm_ndarray_types::get();
10211013
int dst_typenum = dst.get_typenum();
10221014
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
10231015

10241016
char *dst_data = dst.get_data();
10251017
sycl::event full_event;
10261018

1027-
if (dst_nelems == 1 || (dst_flags & USM_ARRAY_C_CONTIGUOUS) ||
1028-
(dst_flags & USM_ARRAY_F_CONTIGUOUS))
1029-
{
1019+
if (dst_nelems == 1 || dst.is_c_contiguous() || dst.is_f_contiguous()) {
10301020
auto fn = full_contig_dispatch_vector[dst_typeid];
10311021

10321022
sycl::event full_contig_event =
@@ -1068,6 +1058,7 @@ eye(py::ssize_t k,
10681058
"allocation queue");
10691059
}
10701060

1061+
auto array_types = dpctl::tensor::detail::usm_ndarray_types::get();
10711062
int dst_typenum = dst.get_typenum();
10721063
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
10731064

@@ -1079,8 +1070,8 @@ eye(py::ssize_t k,
10791070
return std::make_pair(sycl::event{}, sycl::event{});
10801071
}
10811072

1082-
bool is_dst_c_contig = ((dst.get_flags() & USM_ARRAY_C_CONTIGUOUS) != 0);
1083-
bool is_dst_f_contig = ((dst.get_flags() & USM_ARRAY_F_CONTIGUOUS) != 0);
1073+
bool is_dst_c_contig = dst.is_c_contiguous();
1074+
bool is_dst_f_contig = dst.is_f_contiguous();
10841075
if (!is_dst_c_contig && !is_dst_f_contig) {
10851076
throw py::value_error("USM array is not contiguous");
10861077
}
@@ -1182,6 +1173,8 @@ tri(sycl::queue &exec_q,
11821173
throw py::value_error("Arrays index overlapping segments of memory");
11831174
}
11841175

1176+
auto array_types = dpctl::tensor::detail::usm_ndarray_types::get();
1177+
11851178
int src_typenum = src.get_typenum();
11861179
int dst_typenum = dst.get_typenum();
11871180
int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
@@ -1203,9 +1196,8 @@ tri(sycl::queue &exec_q,
12031196
using shT = std::vector<py::ssize_t>;
12041197
shT src_strides(src_nd);
12051198

1206-
int src_flags = src.get_flags();
1207-
bool is_src_c_contig = ((src_flags & USM_ARRAY_C_CONTIGUOUS) != 0);
1208-
bool is_src_f_contig = ((src_flags & USM_ARRAY_F_CONTIGUOUS) != 0);
1199+
bool is_src_c_contig = src.is_c_contiguous();
1200+
bool is_src_f_contig = src.is_f_contiguous();
12091201

12101202
const py::ssize_t *src_strides_raw = src.get_strides_raw();
12111203
if (src_strides_raw == nullptr) {
@@ -1227,9 +1219,8 @@ tri(sycl::queue &exec_q,
12271219

12281220
shT dst_strides(src_nd);
12291221

1230-
int dst_flags = dst.get_flags();
1231-
bool is_dst_c_contig = ((dst_flags & USM_ARRAY_C_CONTIGUOUS) != 0);
1232-
bool is_dst_f_contig = ((dst_flags & USM_ARRAY_F_CONTIGUOUS) != 0);
1222+
bool is_dst_c_contig = dst.is_c_contiguous();
1223+
bool is_dst_f_contig = dst.is_f_contiguous();
12331224

12341225
const py::ssize_t *dst_strides_raw = dst.get_strides_raw();
12351226
if (dst_strides_raw == nullptr) {
@@ -1457,9 +1448,6 @@ PYBIND11_MODULE(_tensor_impl, m)
14571448
init_copy_for_reshape_dispatch_vector();
14581449
import_dpctl();
14591450

1460-
// populate types constants for type dispatching functions
1461-
array_types = dpctl::tensor::detail::usm_ndarray_types::get();
1462-
14631451
m.def(
14641452
"_contract_iter", &contract_iter,
14651453
"Simplifies iteration of array of given shape & stride. Returns "

0 commit comments

Comments
 (0)