Skip to content

Commit 4e6a219

Browse files
committed
address comments
1 parent 360f15a commit 4e6a219

File tree

2 files changed

+225
-248
lines changed

2 files changed

+225
-248
lines changed

dpnp/dpnp_iface_manipulation.py

Lines changed: 41 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,9 @@
4141
import math
4242
import operator
4343
import warnings
44-
from collections import namedtuple
44+
from typing import NamedTuple
4545

46+
import dpctl
4647
import dpctl.tensor as dpt
4748
import numpy
4849
from dpctl.tensor._numpy_helper import AxisError, normalize_axis_index
@@ -55,20 +56,20 @@
5556
from .dpnp_utils import get_usm_allocations
5657
from .dpnp_utils.dpnp_utils_pad import dpnp_pad
5758

58-
Parameters = namedtuple(
59-
"Parameters_insert_delete",
60-
[
61-
"a",
62-
"a_ndim",
63-
"order",
64-
"axis",
65-
"slobj",
66-
"n",
67-
"a_shape",
68-
"exec_q",
69-
"usm_type",
70-
],
71-
)
59+
60+
class InsertDeleteParams(NamedTuple):
61+
"""Parameters used for ``dpnp.delete`` and ``dpnp.insert``."""
62+
63+
a: dpnp_array
64+
a_ndim: int
65+
order: str
66+
axis: int
67+
slobj: list
68+
n: int
69+
a_shape: list
70+
exec_q: dpctl.SyclQueue
71+
usm_type: str
72+
7273

7374
__all__ = [
7475
"append",
@@ -139,7 +140,7 @@ def _check_stack_arrays(arrays):
139140
def _delete_with_slice(params, obj, axis):
140141
"""Utility function for ``dpnp.delete`` when obj is slice."""
141142

142-
a, a_ndim, order, axis, slobj, n, a_shape, exec_q, usm_type = params
143+
a, a_ndim, order, axis, slobj, n, newshape, exec_q, usm_type = params
143144

144145
start, stop, step = obj.indices(n)
145146
xr = range(start, stop, step)
@@ -154,14 +155,8 @@ def _delete_with_slice(params, obj, axis):
154155
start = xr[-1]
155156
stop = xr[0] + 1
156157

157-
a_shape[axis] -= num_del
158-
new = dpnp.empty(
159-
a_shape,
160-
dtype=a.dtype,
161-
order=order,
162-
sycl_queue=exec_q,
163-
usm_type=usm_type,
164-
)
158+
newshape[axis] -= num_del
159+
new = dpnp.empty_like(a, order=order, shape=newshape)
165160
# copy initial chunk
166161
if start == 0:
167162
pass
@@ -200,7 +195,7 @@ def _delete_with_slice(params, obj, axis):
200195
def _delete_without_slice(params, obj, axis, single_value):
201196
"""Utility function for ``dpnp.delete`` when obj is int or array of int."""
202197

203-
a, a_ndim, order, axis, slobj, n, a_shape, exec_q, usm_type = params
198+
a, a_ndim, order, axis, slobj, n, newshape, exec_q, usm_type = params
204199

205200
if single_value:
206201
# optimization for a single value
@@ -211,14 +206,8 @@ def _delete_without_slice(params, obj, axis, single_value):
211206
)
212207
if obj < 0:
213208
obj += n
214-
a_shape[axis] -= 1
215-
new = dpnp.empty(
216-
a_shape,
217-
dtype=a.dtype,
218-
order=order,
219-
sycl_queue=exec_q,
220-
usm_type=usm_type,
221-
)
209+
newshape[axis] -= 1
210+
new = dpnp.empty_like(a, order=order, shape=newshape)
222211
slobj[axis] = slice(None, obj)
223212
new[tuple(slobj)] = a[tuple(slobj)]
224213
slobj[axis] = slice(obj, None)
@@ -264,18 +253,9 @@ def _calc_parameters(a, axis, obj, values=None):
264253
n = a.shape[axis]
265254
a_shape = list(a.shape)
266255

267-
if dpnp.is_supported_array_type(obj) and dpnp.is_supported_array_type(
268-
values
269-
):
270-
usm_type, exec_q = get_usm_allocations([a, obj, values])
271-
elif dpnp.is_supported_array_type(values):
272-
usm_type, exec_q = get_usm_allocations([a, values])
273-
elif dpnp.is_supported_array_type(obj):
274-
usm_type, exec_q = get_usm_allocations([a, obj])
275-
else:
276-
usm_type, exec_q = a.usm_type, a.sycl_queue
256+
usm_type, exec_q = get_usm_allocations([a, obj, values])
277257

278-
return Parameters(
258+
return InsertDeleteParams(
279259
a, a_ndim, order, axis, slobj, n, a_shape, exec_q, usm_type
280260
)
281261

@@ -287,7 +267,7 @@ def _insert_array_indices(parameters, indices, values, obj):
287267
288268
"""
289269

290-
a, a_ndim, order, axis, slobj, n, a_shape, exec_q, usm_type = parameters
270+
a, a_ndim, order, axis, slobj, n, newshape, exec_q, usm_type = parameters
291271

292272
is_array = isinstance(obj, (dpnp_array, numpy.ndarray, dpt.usm_ndarray))
293273
if indices.size == 0 and not is_array:
@@ -302,19 +282,13 @@ def _insert_array_indices(parameters, indices, values, obj):
302282
numnew, dtype=indices.dtype, sycl_queue=exec_q, usm_type=usm_type
303283
)
304284

305-
a_shape[axis] += numnew
285+
newshape[axis] += numnew
306286
old_mask = dpnp.ones(
307-
a_shape[axis], dtype=dpnp.bool, sycl_queue=exec_q, usm_type=usm_type
287+
newshape[axis], dtype=dpnp.bool, sycl_queue=exec_q, usm_type=usm_type
308288
)
309289
old_mask[indices] = False
310290

311-
new = dpnp.empty(
312-
a_shape,
313-
dtype=a.dtype,
314-
order=order,
315-
sycl_queue=exec_q,
316-
usm_type=usm_type,
317-
)
291+
new = dpnp.empty_like(a, order=order, shape=newshape)
318292
slobj2 = [slice(None)] * a_ndim
319293
slobj[axis] = indices
320294
slobj2[axis] = old_mask
@@ -331,24 +305,27 @@ def _insert_singleton_index(parameters, indices, values, obj):
331305
332306
"""
333307

334-
a, a_ndim, order, axis, slobj, n, a_shape, exec_q, usm_type = parameters
308+
a, a_ndim, order, axis, slobj, n, newshape, exec_q, usm_type = parameters
335309

336310
# In dpnp, `.item()` calls `.wait()`, so it is preferred to avoid it
337311
# When possible (i.e. for numpy arrays, lists, etc), it is preferred
338312
# to use `.item()` on a NumPy array
339-
if isinstance(obj, (slice, dpnp_array, dpt.usm_ndarray)):
313+
if isinstance(obj, (dpnp_array, dpt.usm_ndarray)):
340314
index = indices.item()
341315
else:
316+
if isinstance(obj, slice):
317+
obj = numpy.arange(*obj.indices(n), dtype=dpnp.intp)
342318
index = numpy.asarray(obj).item()
343319

344320
if index < -n or index > n:
345321
raise IndexError(
346-
f"index {index} is out of bounds for axis {axis} " f"with size {n}"
322+
f"index {index} is out of bounds for axis {axis} with size {n}"
347323
)
348324
if index < 0:
349325
index += n
350326

351-
# There are some object array corner cases here, that cannot be avoided
327+
# Need to change the dtype of values to input array dtype and update
328+
# its shape to make ``input_arr[..., index, ...] = values`` legal
352329
values = dpnp.array(
353330
values,
354331
copy=None,
@@ -361,15 +338,11 @@ def _insert_singleton_index(parameters, indices, values, obj):
361338
# numpy.insert behave differently if obj is an scalar or an array
362339
# with one element, so, this change is needed to align with NumPy
363340
values = dpnp.moveaxis(values, 0, axis)
341+
364342
numnew = values.shape[axis]
365-
a_shape[axis] += numnew
366-
new = dpnp.empty(
367-
a_shape,
368-
dtype=a.dtype,
369-
order=order,
370-
sycl_queue=exec_q,
371-
usm_type=usm_type,
372-
)
343+
newshape[axis] += numnew
344+
new = dpnp.empty_like(a, order=order, shape=newshape)
345+
373346
slobj[axis] = slice(None, index)
374347
new[tuple(slobj)] = a[tuple(slobj)]
375348
slobj[axis] = slice(index, index + numnew)
@@ -2229,8 +2202,8 @@ def insert(arr, obj, values, axis=None):
22292202
)
22302203
else:
22312204
# need to copy obj, because indices will be changed in-place
2232-
indices = dpnp.array(
2233-
obj, copy=True, sycl_queue=params.exec_q, usm_type=params.usm_type
2205+
indices = dpnp.copy(
2206+
obj, sycl_queue=params.exec_q, usm_type=params.usm_type
22342207
)
22352208
if indices.dtype == dpnp.bool:
22362209
warnings.warn(

0 commit comments

Comments
 (0)