Skip to content

Commit b89de2a

Browse files
Modularized retrieval of usm_ndarray_ptr
Use -O1 when compiling tensor_py for now to work around suspected issue with loading of C-API functions.
1 parent a496d53 commit b89de2a

File tree

2 files changed

+18
-20
lines changed

2 files changed

+18
-20
lines changed

dpctl/apis/include/dpctl4pybind11.hpp

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -439,8 +439,7 @@ class usm_ndarray : public py::object
439439

440440
char *get_data() const
441441
{
442-
PyObject *raw_o = this->ptr();
443-
PyUSMArrayObject *raw_ar = reinterpret_cast<PyUSMArrayObject *>(raw_o);
442+
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
444443

445444
return UsmNDArray_GetData(raw_ar);
446445
}
@@ -452,16 +451,14 @@ class usm_ndarray : public py::object
452451

453452
int get_ndim() const
454453
{
455-
PyObject *raw_o = this->ptr();
456-
PyUSMArrayObject *raw_ar = reinterpret_cast<PyUSMArrayObject *>(raw_o);
454+
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
457455

458456
return UsmNDArray_GetNDim(raw_ar);
459457
}
460458

461459
const py::ssize_t *get_shape_raw() const
462460
{
463-
PyObject *raw_o = this->ptr();
464-
PyUSMArrayObject *raw_ar = reinterpret_cast<PyUSMArrayObject *>(raw_o);
461+
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
465462

466463
return UsmNDArray_GetShape(raw_ar);
467464
}
@@ -474,16 +471,14 @@ class usm_ndarray : public py::object
474471

475472
const py::ssize_t *get_strides_raw() const
476473
{
477-
PyObject *raw_o = this->ptr();
478-
PyUSMArrayObject *raw_ar = reinterpret_cast<PyUSMArrayObject *>(raw_o);
474+
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
479475

480476
return UsmNDArray_GetStrides(raw_ar);
481477
}
482478

483479
py::ssize_t get_size() const
484480
{
485-
PyObject *raw_o = this->ptr();
486-
PyUSMArrayObject *raw_ar = reinterpret_cast<PyUSMArrayObject *>(raw_o);
481+
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
487482

488483
int ndim = UsmNDArray_GetNDim(raw_ar);
489484
const py::ssize_t *shape = UsmNDArray_GetShape(raw_ar);
@@ -499,8 +494,7 @@ class usm_ndarray : public py::object
499494

500495
std::pair<py::ssize_t, py::ssize_t> get_minmax_offsets() const
501496
{
502-
PyObject *raw_o = this->ptr();
503-
PyUSMArrayObject *raw_ar = reinterpret_cast<PyUSMArrayObject *>(raw_o);
497+
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
504498

505499
int nd = UsmNDArray_GetNDim(raw_ar);
506500
const py::ssize_t *shape = UsmNDArray_GetShape(raw_ar);
@@ -533,33 +527,29 @@ class usm_ndarray : public py::object
533527

534528
sycl::queue get_queue() const
535529
{
536-
PyObject *raw_o = this->ptr();
537-
PyUSMArrayObject *raw_ar = reinterpret_cast<PyUSMArrayObject *>(raw_o);
530+
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
538531

539532
DPCTLSyclQueueRef QRef = UsmNDArray_GetQueueRef(raw_ar);
540533
return *(reinterpret_cast<sycl::queue *>(QRef));
541534
}
542535

543536
int get_typenum() const
544537
{
545-
PyObject *raw_o = this->ptr();
546-
PyUSMArrayObject *raw_ar = reinterpret_cast<PyUSMArrayObject *>(raw_o);
538+
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
547539

548540
return UsmNDArray_GetTypenum(raw_ar);
549541
}
550542

551543
int get_flags() const
552544
{
553-
PyObject *raw_o = this->ptr();
554-
PyUSMArrayObject *raw_ar = reinterpret_cast<PyUSMArrayObject *>(raw_o);
545+
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
555546

556547
return UsmNDArray_GetFlags(raw_ar);
557548
}
558549

559550
int get_elemsize() const
560551
{
561-
PyObject *raw_o = this->ptr();
562-
PyUSMArrayObject *raw_ar = reinterpret_cast<PyUSMArrayObject *>(raw_o);
552+
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
563553

564554
return UsmNDArray_GetElementSize(raw_ar);
565555
}
@@ -575,6 +565,12 @@ class usm_ndarray : public py::object
575565
int flags = this->get_flags();
576566
return static_cast<bool>(flags & USM_ARRAY_F_CONTIGUOUS);
577567
}
568+
569+
private:
570+
PyUSMArrayObject *usm_array_ptr() const
571+
{
572+
return reinterpret_cast<PyUSMArrayObject *>(m_ptr);
573+
}
578574
};
579575

580576
} // end namespace tensor

dpctl/tensor/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ set(python_module_name _tensor_impl)
1919
pybind11_add_module(${python_module_name} MODULE
2020
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_py.cpp
2121
)
22+
# FIXME: remove this once issue with dpctl_capi loading is fixed
23+
target_compile_options(${python_module_name} PRIVATE -O1)
2224
target_link_options(${python_module_name} PRIVATE -fsycl-device-code-split=per_kernel)
2325
target_include_directories(${python_module_name}
2426
PRIVATE

0 commit comments

Comments
 (0)