41
41
import math
42
42
import operator
43
43
import warnings
44
- from collections import namedtuple
44
+ from typing import NamedTuple
45
45
46
+ import dpctl
46
47
import dpctl .tensor as dpt
47
48
import numpy
48
49
from dpctl .tensor ._numpy_helper import AxisError , normalize_axis_index
55
56
from .dpnp_utils import get_usm_allocations
56
57
from .dpnp_utils .dpnp_utils_pad import dpnp_pad
57
58
58
- Parameters = namedtuple (
59
- "Parameters_insert_delete" ,
60
- [
61
- "a" ,
62
- "a_ndim" ,
63
- "order" ,
64
- "axis" ,
65
- "slobj" ,
66
- "n" ,
67
- "a_shape" ,
68
- "exec_q" ,
69
- "usm_type" ,
70
- ],
71
- )
59
+
60
+ class InsertDeleteParams ( NamedTuple ):
61
+ """Parameters used for ``dpnp.delete`` and ``dpnp.insert``."""
62
+
63
+ a : dpnp_array
64
+ a_ndim : int
65
+ order : str
66
+ axis : int
67
+ slobj : list
68
+ n : int
69
+ a_shape : list
70
+ exec_q : dpctl . SyclQueue
71
+ usm_type : str
72
+
72
73
73
74
__all__ = [
74
75
"append" ,
@@ -139,7 +140,7 @@ def _check_stack_arrays(arrays):
139
140
def _delete_with_slice (params , obj , axis ):
140
141
"""Utility function for ``dpnp.delete`` when obj is slice."""
141
142
142
- a , a_ndim , order , axis , slobj , n , a_shape , exec_q , usm_type = params
143
+ a , a_ndim , order , axis , slobj , n , newshape , exec_q , usm_type = params
143
144
144
145
start , stop , step = obj .indices (n )
145
146
xr = range (start , stop , step )
@@ -154,14 +155,8 @@ def _delete_with_slice(params, obj, axis):
154
155
start = xr [- 1 ]
155
156
stop = xr [0 ] + 1
156
157
157
- a_shape [axis ] -= num_del
158
- new = dpnp .empty (
159
- a_shape ,
160
- dtype = a .dtype ,
161
- order = order ,
162
- sycl_queue = exec_q ,
163
- usm_type = usm_type ,
164
- )
158
+ newshape [axis ] -= num_del
159
+ new = dpnp .empty_like (a , order = order , shape = newshape )
165
160
# copy initial chunk
166
161
if start == 0 :
167
162
pass
@@ -200,7 +195,7 @@ def _delete_with_slice(params, obj, axis):
200
195
def _delete_without_slice (params , obj , axis , single_value ):
201
196
"""Utility function for ``dpnp.delete`` when obj is int or array of int."""
202
197
203
- a , a_ndim , order , axis , slobj , n , a_shape , exec_q , usm_type = params
198
+ a , a_ndim , order , axis , slobj , n , newshape , exec_q , usm_type = params
204
199
205
200
if single_value :
206
201
# optimization for a single value
@@ -211,14 +206,8 @@ def _delete_without_slice(params, obj, axis, single_value):
211
206
)
212
207
if obj < 0 :
213
208
obj += n
214
- a_shape [axis ] -= 1
215
- new = dpnp .empty (
216
- a_shape ,
217
- dtype = a .dtype ,
218
- order = order ,
219
- sycl_queue = exec_q ,
220
- usm_type = usm_type ,
221
- )
209
+ newshape [axis ] -= 1
210
+ new = dpnp .empty_like (a , order = order , shape = newshape )
222
211
slobj [axis ] = slice (None , obj )
223
212
new [tuple (slobj )] = a [tuple (slobj )]
224
213
slobj [axis ] = slice (obj , None )
@@ -264,18 +253,9 @@ def _calc_parameters(a, axis, obj, values=None):
264
253
n = a .shape [axis ]
265
254
a_shape = list (a .shape )
266
255
267
- if dpnp .is_supported_array_type (obj ) and dpnp .is_supported_array_type (
268
- values
269
- ):
270
- usm_type , exec_q = get_usm_allocations ([a , obj , values ])
271
- elif dpnp .is_supported_array_type (values ):
272
- usm_type , exec_q = get_usm_allocations ([a , values ])
273
- elif dpnp .is_supported_array_type (obj ):
274
- usm_type , exec_q = get_usm_allocations ([a , obj ])
275
- else :
276
- usm_type , exec_q = a .usm_type , a .sycl_queue
256
+ usm_type , exec_q = get_usm_allocations ([a , obj , values ])
277
257
278
- return Parameters (
258
+ return InsertDeleteParams (
279
259
a , a_ndim , order , axis , slobj , n , a_shape , exec_q , usm_type
280
260
)
281
261
@@ -287,7 +267,7 @@ def _insert_array_indices(parameters, indices, values, obj):
287
267
288
268
"""
289
269
290
- a , a_ndim , order , axis , slobj , n , a_shape , exec_q , usm_type = parameters
270
+ a , a_ndim , order , axis , slobj , n , newshape , exec_q , usm_type = parameters
291
271
292
272
is_array = isinstance (obj , (dpnp_array , numpy .ndarray , dpt .usm_ndarray ))
293
273
if indices .size == 0 and not is_array :
@@ -302,19 +282,13 @@ def _insert_array_indices(parameters, indices, values, obj):
302
282
numnew , dtype = indices .dtype , sycl_queue = exec_q , usm_type = usm_type
303
283
)
304
284
305
- a_shape [axis ] += numnew
285
+ newshape [axis ] += numnew
306
286
old_mask = dpnp .ones (
307
- a_shape [axis ], dtype = dpnp .bool , sycl_queue = exec_q , usm_type = usm_type
287
+ newshape [axis ], dtype = dpnp .bool , sycl_queue = exec_q , usm_type = usm_type
308
288
)
309
289
old_mask [indices ] = False
310
290
311
- new = dpnp .empty (
312
- a_shape ,
313
- dtype = a .dtype ,
314
- order = order ,
315
- sycl_queue = exec_q ,
316
- usm_type = usm_type ,
317
- )
291
+ new = dpnp .empty_like (a , order = order , shape = newshape )
318
292
slobj2 = [slice (None )] * a_ndim
319
293
slobj [axis ] = indices
320
294
slobj2 [axis ] = old_mask
@@ -331,24 +305,27 @@ def _insert_singleton_index(parameters, indices, values, obj):
331
305
332
306
"""
333
307
334
- a , a_ndim , order , axis , slobj , n , a_shape , exec_q , usm_type = parameters
308
+ a , a_ndim , order , axis , slobj , n , newshape , exec_q , usm_type = parameters
335
309
336
310
# In dpnp, `.item()` calls `.wait()`, so it is preferred to avoid it
337
311
# When possible (i.e. for numpy arrays, lists, etc), it is preferred
338
312
# to use `.item()` on a NumPy array
339
- if isinstance (obj , (slice , dpnp_array , dpt .usm_ndarray )):
313
+ if isinstance (obj , (dpnp_array , dpt .usm_ndarray )):
340
314
index = indices .item ()
341
315
else :
316
+ if isinstance (obj , slice ):
317
+ obj = numpy .arange (* obj .indices (n ), dtype = dpnp .intp )
342
318
index = numpy .asarray (obj ).item ()
343
319
344
320
if index < - n or index > n :
345
321
raise IndexError (
346
- f"index { index } is out of bounds for axis { axis } " f" with size { n } "
322
+ f"index { index } is out of bounds for axis { axis } with size { n } "
347
323
)
348
324
if index < 0 :
349
325
index += n
350
326
351
- # There are some object array corner cases here, that cannot be avoided
327
+ # Need to change the dtype of values to input array dtype and update
328
+ # its shape to make ``input_arr[..., index, ...] = values`` legal
352
329
values = dpnp .array (
353
330
values ,
354
331
copy = None ,
@@ -361,15 +338,11 @@ def _insert_singleton_index(parameters, indices, values, obj):
361
338
# numpy.insert behave differently if obj is an scalar or an array
362
339
# with one element, so, this change is needed to align with NumPy
363
340
values = dpnp .moveaxis (values , 0 , axis )
341
+
364
342
numnew = values .shape [axis ]
365
- a_shape [axis ] += numnew
366
- new = dpnp .empty (
367
- a_shape ,
368
- dtype = a .dtype ,
369
- order = order ,
370
- sycl_queue = exec_q ,
371
- usm_type = usm_type ,
372
- )
343
+ newshape [axis ] += numnew
344
+ new = dpnp .empty_like (a , order = order , shape = newshape )
345
+
373
346
slobj [axis ] = slice (None , index )
374
347
new [tuple (slobj )] = a [tuple (slobj )]
375
348
slobj [axis ] = slice (index , index + numnew )
@@ -2229,8 +2202,8 @@ def insert(arr, obj, values, axis=None):
2229
2202
)
2230
2203
else :
2231
2204
# need to copy obj, because indices will be changed in-place
2232
- indices = dpnp .array (
2233
- obj , copy = True , sycl_queue = params .exec_q , usm_type = params .usm_type
2205
+ indices = dpnp .copy (
2206
+ obj , sycl_queue = params .exec_q , usm_type = params .usm_type
2234
2207
)
2235
2208
if indices .dtype == dpnp .bool :
2236
2209
warnings .warn (
0 commit comments