Skip to content

Commit 1aff08c

Browse files
committed
Integer advanced indexing now promotes indices arrays
1 parent 4b0ec1a commit 1aff08c

File tree

1 file changed

+51
-34
lines changed

1 file changed

+51
-34
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 51 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -766,22 +766,26 @@ def _take_multi_index(ary, inds, p):
766766
raise TypeError(
767767
f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
768768
)
769+
ary_nd = ary.ndim
770+
p = normalize_axis_index(operator.index(p), ary_nd)
769771
queues_ = [
770772
ary.sycl_queue,
771773
]
772774
usm_types_ = [
773775
ary.usm_type,
774776
]
775-
if not isinstance(inds, list) and not isinstance(inds, tuple):
777+
if not isinstance(inds, (list, tuple)):
776778
inds = (inds,)
777-
all_integers = True
778779
for ind in inds:
779780
if not isinstance(ind, dpt.usm_ndarray):
780781
raise TypeError("all elements of `ind` expected to be usm_ndarrays")
781782
queues_.append(ind.sycl_queue)
782783
usm_types_.append(ind.usm_type)
783-
if all_integers:
784-
all_integers = ind.dtype.kind in "ui"
784+
if ind.dtype.kind not in "ui":
785+
raise IndexError(
786+
"arrays used as indices must be of integer (or boolean) type"
787+
)
788+
res_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
785789
exec_q = dpctl.utils.get_execution_queue(queues_)
786790
if exec_q is None:
787791
raise dpctl.utils.ExecutionPlacementError(
@@ -790,30 +794,36 @@ def _take_multi_index(ary, inds, p):
790794
"Use `usm_ndarray.to_device` method to migrate data to "
791795
"be associated with the same queue."
792796
)
793-
if not all_integers:
794-
raise IndexError(
795-
"arrays used as indices must be of integer (or boolean) type"
796-
)
797797
if len(inds) > 1:
798+
ind_dt = dpt.result_type(*inds)
799+
# ind arrays have been checked to be of integer dtype
800+
if ind_dt.kind not in "ui":
801+
raise ValueError(
802+
"cannot safely promote indices to an integer data type"
803+
)
804+
inds = tuple(
805+
map(
806+
lambda ind: ind
807+
if ind.dtype == ind_dt
808+
else dpt.astype(ind, ind_dt),
809+
inds,
810+
)
811+
)
798812
inds = dpt.broadcast_arrays(*inds)
813+
ind0 = inds[0]
799814
ary_sh = ary.shape
800-
ary_nd = ary.ndim
801-
p = normalize_axis_index(operator.index(p), ary_nd)
802815
p_end = p + len(inds)
803-
inds_sz = inds[0].size
816+
inds_sz = ind0.size
804817
if 0 in ary_sh[p : p_end + 1] and inds_sz != 0:
805818
raise IndexError("cannot take non-empty indices from an empty axis")
806-
res_shape = ary_sh[:p] + inds[0].shape + ary_sh[p + len(inds) :]
807-
res_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
819+
res_shape = ary_sh[:p] + ind0.shape + ary_sh[p_end:]
808820
res = dpt.empty(
809821
res_shape, dtype=ary.dtype, usm_type=res_usm_type, sycl_queue=exec_q
810822
)
811-
812823
hev, _ = ti._take(
813824
src=ary, ind=inds, dst=res, axis_start=p, mode=0, sycl_queue=exec_q
814825
)
815826
hev.wait()
816-
817827
return res
818828

819829

@@ -881,6 +891,8 @@ def _put_multi_index(ary, inds, p, vals):
881891
raise TypeError(
882892
f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
883893
)
894+
ary_nd = ary.ndim
895+
p = normalize_axis_index(operator.index(p), ary_nd)
884896
if isinstance(vals, dpt.usm_ndarray):
885897
queues_ = [ary.sycl_queue, vals.sycl_queue]
886898
usm_types_ = [ary.usm_type, vals.usm_type]
@@ -891,22 +903,19 @@ def _put_multi_index(ary, inds, p, vals):
891903
usm_types_ = [
892904
ary.usm_type,
893905
]
894-
if not isinstance(inds, list) and not isinstance(inds, tuple):
906+
if not isinstance(inds, (list, tuple)):
895907
inds = (inds,)
896-
all_integers = True
897908
for ind in inds:
898909
if not isinstance(ind, dpt.usm_ndarray):
899910
raise TypeError("all elements of `ind` expected to be usm_ndarrays")
900911
queues_.append(ind.sycl_queue)
901912
usm_types_.append(ind.usm_type)
902-
if all_integers:
903-
all_integers = ind.dtype.kind in "ui"
904-
if not all_integers:
905-
raise IndexError(
906-
"arrays used as indices must be of integer (or boolean) type"
907-
)
908-
exec_q = dpctl.utils.get_execution_queue(queues_)
913+
if ind.dtype.kind not in "ui":
914+
raise IndexError(
915+
"arrays used as indices must be of integer (or boolean) type"
916+
)
909917
vals_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
918+
exec_q = dpctl.utils.get_execution_queue(queues_)
910919
if exec_q is not None:
911920
if not isinstance(vals, dpt.usm_ndarray):
912921
vals = dpt.asarray(
@@ -922,19 +931,28 @@ def _put_multi_index(ary, inds, p, vals):
922931
"be associated with the same queue."
923932
)
924933
if len(inds) > 1:
934+
ind_dt = dpt.result_type(*inds)
935+
# ind arrays have been checked to be of integer dtype
936+
if ind_dt.kind not in "ui":
937+
raise ValueError(
938+
"cannot safely promote indices to an integer data type"
939+
)
940+
inds = tuple(
941+
map(
942+
lambda ind: ind
943+
if ind.dtype == ind_dt
944+
else dpt.astype(ind, ind_dt),
945+
inds,
946+
)
947+
)
925948
inds = dpt.broadcast_arrays(*inds)
949+
ind0 = inds[0]
926950
ary_sh = ary.shape
927-
ary_nd = ary.ndim
928-
p = normalize_axis_index(operator.index(p), ary_nd)
929951
p_end = p + len(inds)
930-
inds_sz = inds[0].size
952+
inds_sz = ind0.size
931953
if 0 in ary_sh[p : p_end + 1] and inds_sz != 0:
932-
raise IndexError(
933-
"cannot put elements at non-empty indices along empty axis"
934-
)
935-
expected_vals_shape = (
936-
ary.shape[:p] + inds[0].shape + ary.shape[p + len(inds) :]
937-
)
954+
raise IndexError("cannot put into non-empty indices from an empty axis")
955+
expected_vals_shape = ary_sh[:p] + ind0.shape + ary_sh[p_end:]
938956
if vals.dtype == ary.dtype:
939957
rhs = vals
940958
else:
@@ -944,5 +962,4 @@ def _put_multi_index(ary, inds, p, vals):
944962
dst=ary, ind=inds, val=rhs, axis_start=p, mode=0, sycl_queue=exec_q
945963
)
946964
hev.wait()
947-
948965
return

0 commit comments

Comments
 (0)