@@ -40,7 +40,7 @@ def _get_indexing_mode(name):
40
40
)
41
41
42
42
43
- def take (x , indices , / , * , axis = None , mode = "wrap" ):
43
+ def take (x , indices , / , * , axis = None , out = None , mode = "wrap" ):
44
44
"""take(x, indices, axis=None, out=None, mode="wrap")
45
45
46
46
Takes elements from an array along a given axis at given indices.
@@ -54,6 +54,9 @@ def take(x, indices, /, *, axis=None, mode="wrap"):
54
54
The axis along which the values will be selected.
55
55
If ``x`` is one-dimensional, this argument is optional.
56
56
Default: ``None``.
57
+ out (Optional[usm_ndarray]):
58
+ Output array to populate. Array must have the correct
59
+ shape and the same data type as ``x``.
57
60
mode (str, optional):
58
61
How out-of-bounds indices will be handled. Possible values
59
62
are:
@@ -121,18 +124,53 @@ def take(x, indices, /, *, axis=None, mode="wrap"):
121
124
raise ValueError ("`axis` must be 0 for an array of dimension 0." )
122
125
res_shape = indices .shape
123
126
124
- res = dpt .empty (
125
- res_shape , dtype = x .dtype , usm_type = res_usm_type , sycl_queue = exec_q
126
- )
127
+ dt = x .dtype
128
+
129
+ orig_out = out
130
+ if out is not None :
131
+ if not isinstance (out , dpt .usm_ndarray ):
132
+ raise TypeError (
133
+ f"output array must be of usm_ndarray type, got { type (out )} "
134
+ )
135
+ if not out .flags .writable :
136
+ raise ValueError ("provided `out` array is read-only" )
137
+
138
+ if out .shape != res_shape :
139
+ raise ValueError (
140
+ "The shape of input and output arrays are inconsistent. "
141
+ f"Expected output shape is { res_shape } , got { out .shape } "
142
+ )
143
+ if dt != out .dtype :
144
+ raise ValueError (
145
+ f"Output array of type { dt } is needed, " f"got { out .dtype } "
146
+ )
147
+ if dpctl .utils .get_execution_queue ((exec_q , out .sycl_queue )) is None :
148
+ raise dpctl .utils .ExecutionPlacementError (
149
+ "Input and output allocation queues are not compatible"
150
+ )
151
+ if ti ._array_overlap (x , out ):
152
+ out = dpt .empty_like (out )
153
+ else :
154
+ out = dpt .empty (
155
+ res_shape , dtype = dt , usm_type = res_usm_type , sycl_queue = exec_q
156
+ )
127
157
128
158
_manager = dpctl .utils .SequentialOrderManager [exec_q ]
129
159
deps_ev = _manager .submitted_events
130
160
hev , take_ev = ti ._take (
131
- x , (indices ,), res , axis , mode , sycl_queue = exec_q , depends = deps_ev
161
+ x , (indices ,), out , axis , mode , sycl_queue = exec_q , depends = deps_ev
132
162
)
133
163
_manager .add_event_pair (hev , take_ev )
134
164
135
- return res
165
+ if not (orig_out is None or out is orig_out ):
166
+ # Copy the out data from temporary buffer to original memory
167
+ ht_e_cpy , cpy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
168
+ src = out , dst = orig_out , sycl_queue = exec_q , depends = [take_ev ]
169
+ )
170
+ _manager .add_event_pair (ht_e_cpy , cpy_ev )
171
+ out = orig_out
172
+
173
+ return out
136
174
137
175
138
176
def put (x , indices , vals , / , * , axis = None , mode = "wrap" ):
0 commit comments