Skip to content

Commit accb135

Browse files
committed
Update integer advanced indexing for array API 2024
array API 2024 spec requires advanced indexing with integer arrays, and when advanced indexing behavior is triggered, integral scalars are to be converted to arrays and broadcast with the other indices
1 parent 63f5129 commit accb135

File tree

4 files changed

+117
-57
lines changed

4 files changed

+117
-57
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 61 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# limitations under the License.
1616
import builtins
1717
import operator
18+
from numbers import Integral
1819

1920
import numpy as np
2021

@@ -819,15 +820,26 @@ def _take_multi_index(ary, inds, p, mode=0):
819820
]
820821
if not isinstance(inds, (list, tuple)):
821822
inds = (inds,)
823+
any_usmarray = False
822824
for ind in inds:
823-
if not isinstance(ind, dpt.usm_ndarray):
824-
raise TypeError("all elements of `ind` expected to be usm_ndarrays")
825-
queues_.append(ind.sycl_queue)
826-
usm_types_.append(ind.usm_type)
827-
if ind.dtype.kind not in "ui":
828-
raise IndexError(
829-
"arrays used as indices must be of integer (or boolean) type"
825+
if isinstance(ind, dpt.usm_ndarray):
826+
any_usmarray = True
827+
if ind.dtype.kind not in "ui":
828+
raise IndexError(
829+
"arrays used as indices must be of integer (or boolean) "
830+
"type"
831+
)
832+
queues_.append(ind.sycl_queue)
833+
usm_types_.append(ind.usm_type)
834+
elif not isinstance(ind, Integral):
835+
raise TypeError(
836+
"all elements of `ind` expected to be usm_ndarrays "
837+
"or integers"
830838
)
839+
if not any_usmarray:
840+
raise TypeError(
841+
"at least one element of `ind` expected to be a usm_ndarray"
842+
)
831843
res_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
832844
exec_q = dpctl.utils.get_execution_queue(queues_)
833845
if exec_q is None:
@@ -838,6 +850,18 @@ def _take_multi_index(ary, inds, p, mode=0):
838850
"be associated with the same queue."
839851
)
840852
if len(inds) > 1:
853+
inds = tuple(
854+
map(
855+
lambda ind: (
856+
ind
857+
if isinstance(ind, dpt.usm_ndarray)
858+
else dpt.asarray(
859+
ind, usm_type=res_usm_type, sycl_queue=exec_q
860+
)
861+
),
862+
inds,
863+
)
864+
)
841865
ind_dt = dpt.result_type(*inds)
842866
# ind arrays have been checked to be of integer dtype
843867
if ind_dt.kind not in "ui":
@@ -968,15 +992,26 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
968992
]
969993
if not isinstance(inds, (list, tuple)):
970994
inds = (inds,)
995+
any_usmarray = False
971996
for ind in inds:
972-
if not isinstance(ind, dpt.usm_ndarray):
973-
raise TypeError("all elements of `ind` expected to be usm_ndarrays")
974-
queues_.append(ind.sycl_queue)
975-
usm_types_.append(ind.usm_type)
976-
if ind.dtype.kind not in "ui":
977-
raise IndexError(
978-
"arrays used as indices must be of integer (or boolean) type"
997+
if isinstance(ind, dpt.usm_ndarray):
998+
any_usmarray = True
999+
if ind.dtype.kind not in "ui":
1000+
raise IndexError(
1001+
"arrays used as indices must be of integer (or boolean) "
1002+
"type"
1003+
)
1004+
queues_.append(ind.sycl_queue)
1005+
usm_types_.append(ind.usm_type)
1006+
elif not isinstance(ind, Integral):
1007+
raise TypeError(
1008+
"all elements of `ind` expected to be usm_ndarrays "
1009+
"or integers"
9791010
)
1011+
if not any_usmarray:
1012+
raise TypeError(
1013+
"at least one element of `ind` expected to be a usm_ndarray"
1014+
)
9801015
vals_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
9811016
exec_q = dpctl.utils.get_execution_queue(queues_)
9821017
if exec_q is not None:
@@ -994,6 +1029,18 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
9941029
"be associated with the same queue."
9951030
)
9961031
if len(inds) > 1:
1032+
inds = tuple(
1033+
map(
1034+
lambda ind: (
1035+
ind
1036+
if isinstance(ind, dpt.usm_ndarray)
1037+
else dpt.asarray(
1038+
ind, usm_type=vals_usm_type, sycl_queue=exec_q
1039+
)
1040+
),
1041+
inds,
1042+
)
1043+
)
9971044
ind_dt = dpt.result_type(*inds)
9981045
# ind arrays have been checked to be of integer dtype
9991046
if ind_dt.kind not in "ui":

dpctl/tensor/_slicing.pxi

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,7 @@ cdef Py_ssize_t _slice_len(
4646
cdef bint _is_integral(object x) except *:
4747
"""Gives True if x is an integral slice spec"""
4848
if isinstance(x, usm_ndarray):
49-
if x.ndim > 0:
50-
return False
51-
if x.dtype.kind not in "ui":
52-
return False
53-
return True
49+
return False
5450
if isinstance(x, bool):
5551
return False
5652
if isinstance(x, int):
@@ -179,10 +175,12 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
179175
if array_streak_started:
180176
array_streak_interrupted = True
181177
elif _is_integral(i):
182-
explicit_index += 1
183178
axes_referenced += 1
184179
if array_streak_started:
185-
array_streak_interrupted = True
180+
# integers converted to arrays in this case
181+
array_count += 1
182+
else:
183+
explicit_index += 1
186184
elif isinstance(i, usm_ndarray):
187185
if not seen_arrays_yet:
188186
seen_arrays_yet = True
@@ -196,7 +194,7 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
196194
dt_k = i.dtype.kind
197195
if dt_k == "b" and i.ndim > 0:
198196
axes_referenced += i.ndim
199-
elif dt_k in "ui" and i.ndim > 0:
197+
elif dt_k in "ui":
200198
axes_referenced += 1
201199
else:
202200
raise IndexError(
@@ -260,20 +258,28 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
260258
new_strides.append(0)
261259
elif _is_integral(ind_i):
262260
ind_i = ind_i.__index__()
263-
if 0 <= ind_i < shape[k]:
261+
if advanced_start_pos_set:
262+
# integers converted to arrays in this case
263+
new_advanced_ind.append(ind_i)
264264
k_new = k + 1
265-
if not is_empty:
266-
new_offset = new_offset + ind_i * strides[k]
267-
k = k_new
268-
elif -shape[k] <= ind_i < 0:
269-
k_new = k + 1
270-
if not is_empty:
271-
new_offset = new_offset + (shape[k] + ind_i) * strides[k]
265+
new_shape.extend(shape[k:k_new])
266+
new_strides.extend(strides[k:k_new])
272267
k = k_new
273268
else:
274-
raise IndexError(
275-
("Index {0} is out of range for "
276-
"axes {1} with size {2}").format(ind_i, k, shape[k]))
269+
if 0 <= ind_i < shape[k]:
270+
k_new = k + 1
271+
if not is_empty:
272+
new_offset = new_offset + ind_i * strides[k]
273+
k = k_new
274+
elif -shape[k] <= ind_i < 0:
275+
k_new = k + 1
276+
if not is_empty:
277+
new_offset = new_offset + (shape[k] + ind_i) * strides[k]
278+
k = k_new
279+
else:
280+
raise IndexError(
281+
("Index {0} is out of range for "
282+
"axes {1} with size {2}").format(ind_i, k, shape[k]))
277283
elif isinstance(ind_i, usm_ndarray):
278284
if not advanced_start_pos_set:
279285
new_advanced_start_pos = len(new_shape)

dpctl/tensor/_usmarray.pyx

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,6 @@ cdef void _validate_and_use_stream(object stream, c_dpctl.SyclQueue self_queue)
161161
ev = self_queue.submit_barrier()
162162
stream.submit_barrier(dependent_events=[ev])
163163

164-
165164
cdef class usm_ndarray:
166165
""" usm_ndarray(shape, dtype=None, strides=None, buffer="device", \
167166
offset=0, order="C", buffer_ctor_kwargs=dict(), \
@@ -962,28 +961,30 @@ cdef class usm_ndarray:
962961
return res
963962

964963
from ._copy_utils import _extract_impl, _nonzero_impl, _take_multi_index
965-
if len(adv_ind) == 1 and adv_ind[0].dtype == dpt_bool:
966-
key_ = adv_ind[0]
967-
adv_ind_end_p = key_.ndim + adv_ind_start_p
968-
if adv_ind_end_p > res.ndim:
969-
raise IndexError("too many indices for the array")
970-
key_shape = key_.shape
971-
arr_shape = res.shape[adv_ind_start_p:adv_ind_end_p]
972-
for i in range(key_.ndim):
973-
if matching:
974-
if not key_shape[i] == arr_shape[i] and key_shape[i] > 0:
975-
matching = 0
976-
if not matching:
977-
raise IndexError("boolean index did not match indexed array in dimensions")
978-
res = _extract_impl(res, key_, axis=adv_ind_start_p)
979-
res.flags_ = _copy_writable(res.flags_, self.flags_)
980-
return res
981964

982-
if any(ind.dtype == dpt_bool for ind in adv_ind):
965+
# if len(adv_ind == 1), the (only) element is always an array
966+
if len(adv_ind) == 1 and adv_ind[0].dtype == dpt_bool:
967+
key_ = adv_ind[0]
968+
adv_ind_end_p = key_.ndim + adv_ind_start_p
969+
if adv_ind_end_p > res.ndim:
970+
raise IndexError("too many indices for the array")
971+
key_shape = key_.shape
972+
arr_shape = res.shape[adv_ind_start_p:adv_ind_end_p]
973+
for i in range(key_.ndim):
974+
if matching:
975+
if not key_shape[i] == arr_shape[i] and key_shape[i] > 0:
976+
matching = 0
977+
if not matching:
978+
raise IndexError("boolean index did not match indexed array in dimensions")
979+
res = _extract_impl(res, key_, axis=adv_ind_start_p)
980+
res.flags_ = _copy_writable(res.flags_, self.flags_)
981+
return res
982+
983+
if any((isinstance(ind, usm_ndarray) and ind.dtype == dpt_bool) for ind in adv_ind):
983984
adv_ind_int = list()
984985
for ind in adv_ind:
985-
if ind.dtype == dpt_bool:
986-
adv_ind_int.extend(_nonzero_impl(ind))
986+
if isinstance(ind, usm_ndarray) and ind.dtype == dpt_bool:
987+
adv_ind_int.extend(_nonzero_impl(ind))
987988
else:
988989
adv_ind_int.append(ind)
989990
res = _take_multi_index(res, tuple(adv_ind_int), adv_ind_start_p)
@@ -1433,10 +1434,10 @@ cdef class usm_ndarray:
14331434
_place_impl(Xv, adv_ind[0], rhs, axis=adv_ind_start_p)
14341435
return
14351436

1436-
if any(ind.dtype == dpt_bool for ind in adv_ind):
1437+
if any((isinstance(ind, usm_ndarray) and ind.dtype == dpt_bool) for ind in adv_ind):
14371438
adv_ind_int = list()
14381439
for ind in adv_ind:
1439-
if ind.dtype == dpt_bool:
1440+
if isinstance(ind, usm_ndarray) and ind.dtype == dpt_bool:
14401441
adv_ind_int.extend(_nonzero_impl(ind))
14411442
else:
14421443
adv_ind_int.append(ind)

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,8 +252,14 @@ def test_advanced_slice5():
252252
q = get_queue_or_skip()
253253
ii = dpt.asarray([1, 2], sycl_queue=q)
254254
x = _make_3d("i4", q)
255-
with pytest.raises(IndexError):
256-
x[ii, 0, ii]
255+
y = x[ii, 0, ii]
256+
assert isinstance(y, dpt.usm_ndarray)
257+
# 0 broadcast to [0, 0] per array API
258+
assert y.shape == ii.shape
259+
assert _all_equal(
260+
(x[ii[i], 0, ii[i]] for i in range(ii.shape[0])),
261+
(y[i] for i in range(ii.shape[0])),
262+
)
257263

258264

259265
def test_advanced_slice6():

0 commit comments

Comments
 (0)