Skip to content

Commit 594785c

Browse files
committed
Add support for the out keyword in tensor.take
1 parent e7b2b1b commit 594785c

File tree

1 file changed

+44
-6
lines changed

1 file changed

+44
-6
lines changed

dpctl/tensor/_indexing_functions.py

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def _get_indexing_mode(name):
4040
)
4141

4242

43-
def take(x, indices, /, *, axis=None, mode="wrap"):
43+
def take(x, indices, /, *, axis=None, out=None, mode="wrap"):
4444
"""take(x, indices, axis=None, out=None, mode="wrap")
4545
4646
Takes elements from an array along a given axis at given indices.
@@ -54,6 +54,9 @@ def take(x, indices, /, *, axis=None, mode="wrap"):
5454
The axis along which the values will be selected.
5555
If ``x`` is one-dimensional, this argument is optional.
5656
Default: ``None``.
57+
out (Optional[usm_ndarray]):
58+
Output array to populate. Array must have the correct
59+
shape and the expected data type.
5760
mode (str, optional):
5861
How out-of-bounds indices will be handled. Possible values
5962
are:
@@ -121,18 +124,53 @@ def take(x, indices, /, *, axis=None, mode="wrap"):
121124
raise ValueError("`axis` must be 0 for an array of dimension 0.")
122125
res_shape = indices.shape
123126

124-
res = dpt.empty(
125-
res_shape, dtype=x.dtype, usm_type=res_usm_type, sycl_queue=exec_q
126-
)
127+
dt = x.dtype
128+
129+
orig_out = out
130+
if out is not None:
131+
if not isinstance(out, dpt.usm_ndarray):
132+
raise TypeError(
133+
f"output array must be of usm_ndarray type, got {type(out)}"
134+
)
135+
if not out.flags.writable:
136+
raise ValueError("provided `out` array is read-only")
137+
138+
if out.shape != res_shape:
139+
raise ValueError(
140+
"The shape of input and output arrays are inconsistent. "
141+
f"Expected output shape is {res_shape}, got {out.shape}"
142+
)
143+
if dt != out.dtype:
144+
raise ValueError(
145+
f"Output array of type {dt} is needed, " f"got {out.dtype}"
146+
)
147+
if dpctl.utils.get_execution_queue((exec_q, out.sycl_queue)) is None:
148+
raise dpctl.utils.ExecutionPlacementError(
149+
"Input and output allocation queues are not compatible"
150+
)
151+
if ti._array_overlap(x, out):
152+
out = dpt.empty_like(out)
153+
else:
154+
out = dpt.empty(
155+
res_shape, dtype=dt, usm_type=res_usm_type, sycl_queue=exec_q
156+
)
127157

128158
_manager = dpctl.utils.SequentialOrderManager[exec_q]
129159
deps_ev = _manager.submitted_events
130160
hev, take_ev = ti._take(
131-
x, (indices,), res, axis, mode, sycl_queue=exec_q, depends=deps_ev
161+
x, (indices,), out, axis, mode, sycl_queue=exec_q, depends=deps_ev
132162
)
133163
_manager.add_event_pair(hev, take_ev)
134164

135-
return res
165+
if not (orig_out is None or out is orig_out):
166+
# Copy the out data from temporary buffer to original memory
167+
ht_e_cpy, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
168+
src=out, dst=orig_out, sycl_queue=exec_q, depends=[take_ev]
169+
)
170+
_manager.add_event_pair(ht_e_cpy, cpy_ev)
171+
out = orig_out
172+
173+
return out
136174

137175

138176
def put(x, indices, vals, /, *, axis=None, mode="wrap"):

0 commit comments

Comments
 (0)