Skip to content

Commit abeb59e

Browse files
Add mode keyword to _take_multi_index with default 0
1 parent 132d55d commit abeb59e

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -795,13 +795,18 @@ def _nonzero_impl(ary):
795795
return res
796796

797797

798-
def _take_multi_index(ary, inds, p):
798+
def _take_multi_index(ary, inds, p, mode=0):
799799
if not isinstance(ary, dpt.usm_ndarray):
800800
raise TypeError(
801801
f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
802802
)
803803
ary_nd = ary.ndim
804804
p = normalize_axis_index(operator.index(p), ary_nd)
805+
mode = operator.index(mode)
806+
if mode not in [0, 1]:
807+
raise ValueError(
808+
"Invalid value for mode keyword, only 0 or 1 is supported"
809+
)
805810
queues_ = [
806811
ary.sycl_queue,
807812
]
@@ -860,7 +865,7 @@ def _take_multi_index(ary, inds, p):
860865
ind=inds,
861866
dst=res,
862867
axis_start=p,
863-
mode=0,
868+
mode=mode,
864869
sycl_queue=exec_q,
865870
depends=dep_ev,
866871
)

0 commit comments

Comments
 (0)