Skip to content

Commit a81927d

Browse files
author
Diptorup Deb
committed
Fixes bug in boxing a DpnpNdArray from parent.
- When boxing a dpnp.ndarray using the reference to the parent stored during unboxing, there is a validation step on the strides. However, because a dpnp.ndarray object does not store any stride information when the array is unit-strided, the strides need to be computed from the shape and validated against the strides in the Numba usmarystruct object. This PR fixes the stride calculation, which was previously done incorrectly.
1 parent cd5332f commit a81927d

File tree

2 files changed

+45
-9
lines changed

2 files changed

+45
-9
lines changed

numba_dpex/core/runtime/_dpexrt_python.c

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -901,7 +901,7 @@ static PyObject *box_from_arystruct_parent(usmarystruct_t *arystruct,
901901
int ndim,
902902
PyArray_Descr *descr)
903903
{
904-
int i = 0, exp = 0;
904+
int i = 0, j = 0, k = 0, exp = 0;
905905
npy_intp *p = NULL;
906906
npy_intp *shape = NULL, *strides = NULL;
907907
PyObject *array = arystruct->parent;
@@ -933,6 +933,8 @@ static PyObject *box_from_arystruct_parent(usmarystruct_t *arystruct,
933933
shape = UsmNDArray_GetShape(arrayobj);
934934
strides = UsmNDArray_GetStrides(arrayobj);
935935

936+
// Ensure the shape of the array to be boxed matches the shape of the
937+
// original parent.
936938
for (i = 0; i < ndim; i++, p++) {
937939
if (shape[i] != *p)
938940
return NULL;
@@ -942,20 +944,52 @@ static PyObject *box_from_arystruct_parent(usmarystruct_t *arystruct,
942944
itemsize = arystruct->itemsize;
943945
while (itemsize >>= 1)
944946
exp++;
945-
// dpctl stores strides as number of elements and Numba stores strides as
947+
948+
// Ensure the strides of the array to be boxed match the strides of the
949+
// original parent. Things to note:
950+
//
951+
// 1. dpctl only stores stride information if the array has a non-unit
952+
// stride. If the array is unit strided then dpctl does not populate the
953+
// stride attribute. To verify strides, we compute the strides from the
954+
// shape vector.
955+
//
956+
// 2. dpctl stores strides as number of elements and Numba stores strides as
946957
// bytes, for that reason we are multiplying stride by itemsize when
947-
// unboxing the external array.
958+
// unboxing the external array and dividing by itemsize when boxing the
959+
// array back.
960+
948961
if (strides) {
949-
if (strides[i] << exp != *p)
950-
return NULL;
962+
for (i = 0; i < ndim; ++i, ++p) {
963+
if (strides[i] << exp != *p) {
964+
DPEXRT_DEBUG(
965+
drt_debug_print("DPEXRT-DEBUG: Arrayobj cannot be boxed "
966+
"from parent as strides in the "
967+
"arystruct are not the same as "
968+
"the strides in the parent object. "
969+
"Expected stride = %d actual stride = %d\n",
970+
strides[i] << exp, *p));
971+
return NULL;
972+
}
973+
}
951974
}
952975
else {
953-
for (i = 1; i < ndim; ++i, ++p) {
954-
if (shape[i] != *p)
976+
npy_intp tmp;
977+
for (i = (ndim * 2) - 1; i >= ndim; --i, ++p) {
978+
tmp = 1;
979+
for (j = i, k = ndim - 1; j > ndim; --j, --k)
980+
tmp *= shape[k];
981+
tmp <<= exp;
982+
if (tmp != *p) {
983+
DPEXRT_DEBUG(
984+
drt_debug_print("DPEXRT-DEBUG: Arrayobj cannot be boxed "
985+
"from parent as strides in the "
986+
"arystruct are not the same as "
987+
"the strides in the parent object. "
988+
"Expected stride = %d actual stride = %d\n",
989+
tmp, *p));
955990
return NULL;
991+
}
956992
}
957-
if (*p != 1)
958-
return NULL;
959993
}
960994

961995
// At the end of boxing our Meminfo destructor gets called and that will

numba_dpex/tests/core/types/DpnpNdArray/test_boxing_unboxing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ def func(a):
3232
assert a.device == b.device
3333
assert a.strides == b.strides
3434
assert a.dtype == b.dtype
35+
# To ensure we are returning the original array when boxing
36+
assert id(a) == id(b)
3537

3638

3739
def test_stride_calc_at_unboxing():

0 commit comments

Comments
 (0)