Skip to content

Commit 4b0ec1a

Browse files
committed
Tweaks to advanced integer indexing
Setting items in an array now casts the right-hand side to the array data type when the data types differ Setting and getting from an empty axis with non-empty indices now throws `IndexError`
1 parent f5c6610 commit 4b0ec1a

File tree

2 files changed

+61
-24
lines changed

2 files changed

+61
-24
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 53 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -763,7 +763,9 @@ def _nonzero_impl(ary):
763763

764764
def _take_multi_index(ary, inds, p):
765765
if not isinstance(ary, dpt.usm_ndarray):
766-
raise TypeError
766+
raise TypeError(
767+
f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
768+
)
767769
queues_ = [
768770
ary.sycl_queue,
769771
]
@@ -774,23 +776,34 @@ def _take_multi_index(ary, inds, p):
774776
inds = (inds,)
775777
all_integers = True
776778
for ind in inds:
779+
if not isinstance(ind, dpt.usm_ndarray):
780+
raise TypeError("all elements of `ind` expected to be usm_ndarrays")
777781
queues_.append(ind.sycl_queue)
778782
usm_types_.append(ind.usm_type)
779783
if all_integers:
780784
all_integers = ind.dtype.kind in "ui"
781785
exec_q = dpctl.utils.get_execution_queue(queues_)
782786
if exec_q is None:
783-
raise dpctl.utils.ExecutionPlacementError("")
787+
raise dpctl.utils.ExecutionPlacementError(
788+
"Can not automatically determine where to allocate the "
789+
"result or performance execution. "
790+
"Use `usm_ndarray.to_device` method to migrate data to "
791+
"be associated with the same queue."
792+
)
784793
if not all_integers:
785794
raise IndexError(
786795
"arrays used as indices must be of integer (or boolean) type"
787796
)
788797
if len(inds) > 1:
789798
inds = dpt.broadcast_arrays(*inds)
790-
ary_ndim = ary.ndim
791-
p = normalize_axis_index(operator.index(p), ary_ndim)
792-
793-
res_shape = ary.shape[:p] + inds[0].shape + ary.shape[p + len(inds) :]
799+
ary_sh = ary.shape
800+
ary_nd = ary.ndim
801+
p = normalize_axis_index(operator.index(p), ary_nd)
802+
p_end = p + len(inds)
803+
inds_sz = inds[0].size
804+
if 0 in ary_sh[p : p_end + 1] and inds_sz != 0:
805+
raise IndexError("cannot take non-empty indices from an empty axis")
806+
res_shape = ary_sh[:p] + inds[0].shape + ary_sh[p + len(inds) :]
794807
res_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
795808
res = dpt.empty(
796809
res_shape, dtype=ary.dtype, usm_type=res_usm_type, sycl_queue=exec_q
@@ -864,6 +877,10 @@ def _place_impl(ary, ary_mask, vals, axis=0):
864877

865878

866879
def _put_multi_index(ary, inds, p, vals):
880+
if not isinstance(ary, dpt.usm_ndarray):
881+
raise TypeError(
882+
f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
883+
)
867884
if isinstance(vals, dpt.usm_ndarray):
868885
queues_ = [ary.sycl_queue, vals.sycl_queue]
869886
usm_types_ = [ary.usm_type, vals.usm_type]
@@ -879,40 +896,52 @@ def _put_multi_index(ary, inds, p, vals):
879896
all_integers = True
880897
for ind in inds:
881898
if not isinstance(ind, dpt.usm_ndarray):
882-
raise TypeError
899+
raise TypeError("all elements of `ind` expected to be usm_ndarrays")
883900
queues_.append(ind.sycl_queue)
884901
usm_types_.append(ind.usm_type)
885902
if all_integers:
886903
all_integers = ind.dtype.kind in "ui"
904+
if not all_integers:
905+
raise IndexError(
906+
"arrays used as indices must be of integer (or boolean) type"
907+
)
887908
exec_q = dpctl.utils.get_execution_queue(queues_)
909+
vals_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
910+
if exec_q is not None:
911+
if not isinstance(vals, dpt.usm_ndarray):
912+
vals = dpt.asarray(
913+
vals, dtype=ary.dtype, usm_type=vals_usm_type, sycl_queue=exec_q
914+
)
915+
else:
916+
exec_q = dpctl.utils.get_execution_queue((exec_q, vals.sycl_queue))
888917
if exec_q is None:
889918
raise dpctl.utils.ExecutionPlacementError(
890919
"Can not automatically determine where to allocate the "
891920
"result or performance execution. "
892921
"Use `usm_ndarray.to_device` method to migrate data to "
893922
"be associated with the same queue."
894923
)
895-
if not all_integers:
896-
raise IndexError(
897-
"arrays used as indices must be of integer (or boolean) type"
898-
)
899924
if len(inds) > 1:
900925
inds = dpt.broadcast_arrays(*inds)
901-
ary_ndim = ary.ndim
902-
903-
p = normalize_axis_index(operator.index(p), ary_ndim)
904-
vals_shape = ary.shape[:p] + inds[0].shape + ary.shape[p + len(inds) :]
905-
906-
vals_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
907-
if not isinstance(vals, dpt.usm_ndarray):
908-
vals = dpt.asarray(
909-
vals, dtype=ary.dtype, usm_type=vals_usm_type, sycl_queue=exec_q
926+
ary_sh = ary.shape
927+
ary_nd = ary.ndim
928+
p = normalize_axis_index(operator.index(p), ary_nd)
929+
p_end = p + len(inds)
930+
inds_sz = inds[0].size
931+
if 0 in ary_sh[p : p_end + 1] and inds_sz != 0:
932+
raise IndexError(
933+
"cannot put elements at non-empty indices along empty axis"
910934
)
911-
912-
vals = dpt.broadcast_to(vals, vals_shape)
913-
935+
expected_vals_shape = (
936+
ary.shape[:p] + inds[0].shape + ary.shape[p + len(inds) :]
937+
)
938+
if vals.dtype == ary.dtype:
939+
rhs = vals
940+
else:
941+
rhs = dpt.astype(vals, ary.dtype)
942+
rhs = dpt.broadcast_to(rhs, expected_vals_shape)
914943
hev, _ = ti._put(
915-
dst=ary, ind=inds, val=vals, axis_start=p, mode=0, sycl_queue=exec_q
944+
dst=ary, ind=inds, val=rhs, axis_start=p, mode=0, sycl_queue=exec_q
916945
)
917946
hev.wait()
918947

dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,10 @@ usm_ndarray_take(const dpctl::tensor::usm_ndarray &src,
414414
ind_offsets.push_back(py::ssize_t(0));
415415
}
416416

417+
if (ind_nelems == 0) {
418+
return std::make_pair(sycl::event{}, sycl::event{});
419+
}
420+
417421
char **packed_ind_ptrs = sycl::malloc_device<char *>(k, exec_q);
418422

419423
if (packed_ind_ptrs == nullptr) {
@@ -717,6 +721,10 @@ usm_ndarray_put(const dpctl::tensor::usm_ndarray &dst,
717721
ind_offsets.push_back(py::ssize_t(0));
718722
}
719723

724+
if (ind_nelems == 0) {
725+
return std::make_pair(sycl::event{}, sycl::event{});
726+
}
727+
720728
char **packed_ind_ptrs = sycl::malloc_device<char *>(k, exec_q);
721729

722730
if (packed_ind_ptrs == nullptr) {

0 commit comments

Comments
 (0)