Skip to content

Commit da44614

Browse files
committed
Handle boolean arrays in dpnp.insert as mask
1 parent eaa79fa commit da44614

File tree

2 files changed

+24
-19
lines changed

2 files changed

+24
-19
lines changed

dpnp/dpnp_iface_manipulation.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2170,11 +2170,12 @@ def insert(arr, obj, values, axis=None):
21702170
----------
21712171
arr : array_like
21722172
Input array.
2173-
obj : {slice, int, array-like of ints}
2173+
obj : {slice, int, array-like of ints or bools}
21742174
Object that defines the index or indices before which `values` is
21752175
inserted. It supports multiple insertions when `obj` is a single
21762176
scalar or a sequence with one element (similar to calling insert
21772177
multiple times).
2178+
Boolean indices are treated as a mask of elements to insert.
21782179
values : array_like
21792180
Values to insert into `arr`. If the type of `values` is different
21802181
from that of `arr`, `values` is converted to the type of `arr`.
@@ -2266,20 +2267,12 @@ def insert(arr, obj, values, axis=None):
22662267
obj, sycl_queue=params.exec_q, usm_type=params.usm_type
22672268
)
22682269
if indices.dtype == dpnp.bool:
2269-
warnings.warn(
2270-
"In the future insert will treat boolean arrays and array-likes"
2271-
" as a boolean index instead of casting it to integers",
2272-
FutureWarning,
2273-
stacklevel=2,
2274-
)
2275-
indices = indices.astype(dpnp.intp)
2276-
# TODO: Code after warning period:
2277-
# if indices.ndim != 1:
2278-
# raise ValueError(
2279-
# "boolean array argument `obj` to insert must be "
2280-
# "one-dimensional"
2281-
# )
2282-
# indices = dpnp.nonzero(indices)[0]
2270+
if indices.ndim != 1:
2271+
raise ValueError(
2272+
"boolean array argument obj to insert "
2273+
"must be one dimensional"
2274+
)
2275+
indices = dpnp.flatnonzero(indices)
22832276
elif indices.ndim > 1:
22842277
raise ValueError(
22852278
"index array argument `obj` to insert must be one-dimensional "

dpnp/tests/test_manipulation.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
get_float_dtypes,
2121
get_integer_dtypes,
2222
has_support_aspect64,
23-
numpy_version,
2423
)
2524
from .third_party.cupy import testing
2625

@@ -578,13 +577,13 @@ def test_ndarray_obj_values(self, obj, values):
578577
result = dpnp.insert(ia, obj, values)
579578
assert_equal(result, expected)
580579

581-
@pytest.mark.filterwarnings("ignore::FutureWarning")
580+
@testing.with_requires("numpy>=2.2")
582581
@pytest.mark.parametrize(
583582
"obj",
584-
[True, [False], numpy.array([True] * 4), [True, False, True, False]],
583+
[[False], numpy.array([True] * 4), [True, False, True, False]],
585584
)
586585
def test_boolean_obj(self, obj):
587-
if numpy_version() >= "2.2.0" and not isinstance(obj, numpy.ndarray):
586+
if not isinstance(obj, numpy.ndarray):
588587
# numpy.insert raises exception
589588
# TODO: remove once NumPy resolves that
590589
obj = numpy.array(obj)
@@ -593,6 +592,19 @@ def test_boolean_obj(self, obj):
593592
ia = dpnp.array(a)
594593
assert_equal(dpnp.insert(ia, obj, 9), numpy.insert(a, obj, 9))
595594

595+
@testing.with_requires("numpy>=2.2")
596+
@pytest.mark.parametrize("xp", [dpnp, numpy])
597+
@pytest.mark.parametrize(
598+
"obj_data",
599+
[True, [[True, False], [True, False]]],
600+
ids=["0d", "2d"],
601+
)
602+
def test_boolean_obj_error(self, xp, obj_data):
603+
a = xp.array([1, 2, 3])
604+
obj = xp.array(obj_data)
605+
with pytest.raises(ValueError):
606+
xp.insert(a, obj, 9)
607+
596608
def test_1D_array(self):
597609
a = numpy.array([1, 2, 3])
598610
ia = dpnp.array(a)

0 commit comments

Comments
 (0)