16
16
17
17
import operator
18
18
19
- import numpy as np
20
19
from numpy .core .numeric import normalize_axis_index
21
20
22
21
import dpctl
@@ -47,15 +46,15 @@ def take(x, indices, /, *, axis=None, mode="wrap"):
47
46
indices (usm_ndarray):
48
47
One-dimensional array of indices.
49
48
axis:
50
- The axis over which the values will be selected.
51
- If x is one-dimensional, this argument is optional.
52
- Default: `None`.
49
+ The axis along which the values will be selected.
50
+ If ``x`` is one-dimensional, this argument is optional.
51
+ Default: `` None` `.
53
52
mode:
54
53
How out-of-bounds indices will be handled.
55
- "wrap" - clamps indices to (-n <= i < n), then wraps
54
+ `` "wrap"`` - clamps indices to (-n <= i < n), then wraps
56
55
negative indices.
57
- "clip" - clips indices to (0 <= i < n)
58
- Default: `"wrap"`.
56
+ `` "clip"`` - clips indices to (0 <= i < n)
57
+ Default: `` "wrap"` `.
59
58
60
59
Returns:
61
60
usm_ndarray:
@@ -73,7 +72,7 @@ def take(x, indices, /, *, axis=None, mode="wrap"):
73
72
type (indices )
74
73
)
75
74
)
76
- if not np . issubdtype ( indices .dtype , np . integer ) :
75
+ if indices .dtype . kind not in "ui" :
77
76
raise IndexError (
78
77
"`indices` expected integer data type, got `{}`" .format (
79
78
indices .dtype
@@ -104,6 +103,9 @@ def take(x, indices, /, *, axis=None, mode="wrap"):
104
103
105
104
if x_ndim > 0 :
106
105
axis = normalize_axis_index (operator .index (axis ), x_ndim )
106
+ x_sh = x .shape
107
+ if x_sh [axis ] == 0 and indices .size != 0 :
108
+ raise IndexError ("cannot take non-empty indices from an empty axis" )
107
109
res_shape = x .shape [:axis ] + indices .shape + x .shape [axis + 1 :]
108
110
else :
109
111
if axis != 0 :
@@ -130,19 +132,26 @@ def put(x, indices, vals, /, *, axis=None, mode="wrap"):
130
132
The array the values will be put into.
131
133
indices (usm_ndarray)
132
134
One-dimensional array of indices.
135
+
136
+ Note that if indices are not unique, a race
137
+ condition will result, and the value written to
138
+ ``x`` will not be deterministic.
139
+ :py:func:`dpctl.tensor.unique` can be used to
140
+ guarantee unique elements in ``indices``.
133
141
vals:
134
- Array of values to be put into `x`.
135
- Must be broadcastable to the shape of `indices`.
142
+ Array of values to be put into ``x``.
143
+ Must be broadcastable to the result shape
144
+ ``x.shape[:axis] + indices.shape + x.shape[axis+1:]``.
136
145
axis:
137
- The axis over which the values will be placed.
138
- If x is one-dimensional, this argument is optional.
139
- Default: `None`.
146
+ The axis along which the values will be placed.
147
+ If ``x`` is one-dimensional, this argument is optional.
148
+ Default: `` None` `.
140
149
mode:
141
150
How out-of-bounds indices will be handled.
142
- "wrap" - clamps indices to (-n <= i < n), then wraps
151
+ `` "wrap"`` - clamps indices to (-n <= i < n), then wraps
143
152
negative indices.
144
- "clip" - clips indices to (0 <= i < n)
145
- Default: `"wrap"`.
153
+ `` "clip"`` - clips indices to (0 <= i < n)
154
+ Default: `` "wrap"` `.
146
155
"""
147
156
if not isinstance (x , dpt .usm_ndarray ):
148
157
raise TypeError (
@@ -168,7 +177,7 @@ def put(x, indices, vals, /, *, axis=None, mode="wrap"):
168
177
raise ValueError (
169
178
"`indices` expected a 1D array, got `{}`" .format (indices .ndim )
170
179
)
171
- if not np . issubdtype ( indices .dtype , np . integer ) :
180
+ if indices .dtype . kind not in "ui" :
172
181
raise IndexError (
173
182
"`indices` expected integer data type, got `{}`" .format (
174
183
indices .dtype
@@ -195,7 +204,9 @@ def put(x, indices, vals, /, *, axis=None, mode="wrap"):
195
204
196
205
if x_ndim > 0 :
197
206
axis = normalize_axis_index (operator .index (axis ), x_ndim )
198
-
207
+ x_sh = x .shape
208
+ if x_sh [axis ] == 0 and indices .size != 0 :
209
+ raise IndexError ("cannot take non-empty indices from an empty axis" )
199
210
val_shape = x .shape [:axis ] + indices .shape + x .shape [axis + 1 :]
200
211
else :
201
212
if axis != 0 :
@@ -206,10 +217,18 @@ def put(x, indices, vals, /, *, axis=None, mode="wrap"):
206
217
vals = dpt .asarray (
207
218
vals , dtype = x .dtype , usm_type = vals_usm_type , sycl_queue = exec_q
208
219
)
220
+ # choose to throw here for consistency with `place`
221
+ if vals .size == 0 :
222
+ raise ValueError (
223
+ "cannot put into non-empty indices along an empty axis"
224
+ )
225
+ if vals .dtype == x .dtype :
226
+ rhs = vals
227
+ else :
228
+ rhs = dpt .astype (vals , x .dtype )
229
+ rhs = dpt .broadcast_to (rhs , val_shape )
209
230
210
- vals = dpt .broadcast_to (vals , val_shape )
211
-
212
- hev , _ = ti ._put (x , (indices ,), vals , axis , mode , sycl_queue = exec_q )
231
+ hev , _ = ti ._put (x , (indices ,), rhs , axis , mode , sycl_queue = exec_q )
213
232
hev .wait ()
214
233
215
234
0 commit comments