Skip to content

Commit 9dbb0b6

Browse files
committed
put now casts vals when the data type differs from x
Fixes `take` and `put` being used on non-empty axes with non-empty indices Also adds a note to `put` about race conditions for non-unique indices
1 parent 1aff08c commit 9dbb0b6

File tree

1 file changed

+40
-21
lines changed

1 file changed

+40
-21
lines changed

dpctl/tensor/_indexing_functions.py

Lines changed: 40 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import operator
1818

19-
import numpy as np
2019
from numpy.core.numeric import normalize_axis_index
2120

2221
import dpctl
@@ -47,15 +46,15 @@ def take(x, indices, /, *, axis=None, mode="wrap"):
4746
indices (usm_ndarray):
4847
One-dimensional array of indices.
4948
axis:
50-
The axis over which the values will be selected.
51-
If x is one-dimensional, this argument is optional.
52-
Default: `None`.
49+
The axis along which the values will be selected.
50+
If ``x`` is one-dimensional, this argument is optional.
51+
Default: ``None``.
5352
mode:
5453
How out-of-bounds indices will be handled.
55-
"wrap" - clamps indices to (-n <= i < n), then wraps
54+
``"wrap"`` - clamps indices to (-n <= i < n), then wraps
5655
negative indices.
57-
"clip" - clips indices to (0 <= i < n)
58-
Default: `"wrap"`.
56+
``"clip"`` - clips indices to (0 <= i < n)
57+
Default: ``"wrap"``.
5958
6059
Returns:
6160
usm_ndarray:
@@ -73,7 +72,7 @@ def take(x, indices, /, *, axis=None, mode="wrap"):
7372
type(indices)
7473
)
7574
)
76-
if not np.issubdtype(indices.dtype, np.integer):
75+
if indices.dtype.kind not in "ui":
7776
raise IndexError(
7877
"`indices` expected integer data type, got `{}`".format(
7978
indices.dtype
@@ -104,6 +103,9 @@ def take(x, indices, /, *, axis=None, mode="wrap"):
104103

105104
if x_ndim > 0:
106105
axis = normalize_axis_index(operator.index(axis), x_ndim)
106+
x_sh = x.shape
107+
if x_sh[axis] == 0 and indices.size != 0:
108+
raise IndexError("cannot take non-empty indices from an empty axis")
107109
res_shape = x.shape[:axis] + indices.shape + x.shape[axis + 1 :]
108110
else:
109111
if axis != 0:
@@ -130,19 +132,26 @@ def put(x, indices, vals, /, *, axis=None, mode="wrap"):
130132
The array the values will be put into.
131133
indices (usm_ndarray)
132134
One-dimensional array of indices.
135+
136+
Note that if indices are not unique, a race
137+
condition will result, and the value written to
138+
``x`` will not be deterministic.
139+
:py:func:`dpctl.tensor.unique` can be used to
140+
guarantee unique elements in ``indices``.
133141
vals:
134-
Array of values to be put into `x`.
135-
Must be broadcastable to the shape of `indices`.
142+
Array of values to be put into ``x``.
143+
Must be broadcastable to the result shape
144+
``x.shape[:axis] + indices.shape + x.shape[axis+1:]``.
136145
axis:
137-
The axis over which the values will be placed.
138-
If x is one-dimensional, this argument is optional.
139-
Default: `None`.
146+
The axis along which the values will be placed.
147+
If ``x`` is one-dimensional, this argument is optional.
148+
Default: ``None``.
140149
mode:
141150
How out-of-bounds indices will be handled.
142-
"wrap" - clamps indices to (-n <= i < n), then wraps
151+
``"wrap"`` - clamps indices to (-n <= i < n), then wraps
143152
negative indices.
144-
"clip" - clips indices to (0 <= i < n)
145-
Default: `"wrap"`.
153+
``"clip"`` - clips indices to (0 <= i < n)
154+
Default: ``"wrap"``.
146155
"""
147156
if not isinstance(x, dpt.usm_ndarray):
148157
raise TypeError(
@@ -168,7 +177,7 @@ def put(x, indices, vals, /, *, axis=None, mode="wrap"):
168177
raise ValueError(
169178
"`indices` expected a 1D array, got `{}`".format(indices.ndim)
170179
)
171-
if not np.issubdtype(indices.dtype, np.integer):
180+
if indices.dtype.kind not in "ui":
172181
raise IndexError(
173182
"`indices` expected integer data type, got `{}`".format(
174183
indices.dtype
@@ -195,7 +204,9 @@ def put(x, indices, vals, /, *, axis=None, mode="wrap"):
195204

196205
if x_ndim > 0:
197206
axis = normalize_axis_index(operator.index(axis), x_ndim)
198-
207+
x_sh = x.shape
208+
if x_sh[axis] == 0 and indices.size != 0:
209+
raise IndexError("cannot take non-empty indices from an empty axis")
199210
val_shape = x.shape[:axis] + indices.shape + x.shape[axis + 1 :]
200211
else:
201212
if axis != 0:
@@ -206,10 +217,18 @@ def put(x, indices, vals, /, *, axis=None, mode="wrap"):
206217
vals = dpt.asarray(
207218
vals, dtype=x.dtype, usm_type=vals_usm_type, sycl_queue=exec_q
208219
)
220+
# choose to throw here for consistency with `place`
221+
if vals.size == 0:
222+
raise ValueError(
223+
"cannot put into non-empty indices along an empty axis"
224+
)
225+
if vals.dtype == x.dtype:
226+
rhs = vals
227+
else:
228+
rhs = dpt.astype(vals, x.dtype)
229+
rhs = dpt.broadcast_to(rhs, val_shape)
209230

210-
vals = dpt.broadcast_to(vals, val_shape)
211-
212-
hev, _ = ti._put(x, (indices,), vals, axis, mode, sycl_queue=exec_q)
231+
hev, _ = ti._put(x, (indices,), rhs, axis, mode, sycl_queue=exec_q)
213232
hev.wait()
214233

215234

0 commit comments

Comments
 (0)