Skip to content

Commit c6d3a16

Browse files
authored
Merge 4e966c1 into 4550d18
2 parents 4550d18 + 4e966c1 commit c6d3a16

File tree

7 files changed

+141
-73
lines changed

7 files changed

+141
-73
lines changed

dpnp/dpnp_algo/dpnp_algo_indexing.pxi

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ and the rest of the library
3838
__all__ += [
3939
"dpnp_choose",
4040
"dpnp_putmask",
41-
"dpnp_select",
4241
]
4342

4443
ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_dpnp_choose_t)(c_dpctl.DPCTLSyclQueueRef,
@@ -102,20 +101,3 @@ cpdef dpnp_putmask(utils.dpnp_descriptor arr, utils.dpnp_descriptor mask, utils.
102101
for i in range(arr.size):
103102
if mask_flatiter[i]:
104103
arr_flatiter[i] = values_flatiter[i % values_size]
105-
106-
107-
cpdef utils.dpnp_descriptor dpnp_select(list condlist, list choicelist, default):
108-
cdef size_t size_ = condlist[0].size
109-
cdef utils.dpnp_descriptor res_array = utils_py.create_output_descriptor_py(condlist[0].shape, choicelist[0].dtype, None)
110-
111-
pass_val = {a: default for a in range(size_)}
112-
for i in range(len(condlist)):
113-
for j in range(size_):
114-
if (condlist[i])[j]:
115-
res_array.get_pyobj()[j] = (choicelist[i])[j]
116-
pass_val.pop(j)
117-
118-
for ind, val in pass_val.items():
119-
res_array.get_pyobj()[ind] = val
120-
121-
return res_array

dpnp/dpnp_iface_indexing.py

Lines changed: 118 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,11 @@
4949
from .dpnp_algo import (
5050
dpnp_choose,
5151
dpnp_putmask,
52-
dpnp_select,
5352
)
5453
from .dpnp_array import dpnp_array
5554
from .dpnp_utils import (
5655
call_origin,
57-
use_origin_backend,
56+
get_usm_allocations,
5857
)
5958

6059
__all__ = [
@@ -524,7 +523,7 @@ def extract(condition, a):
524523
:obj:`dpnp.put` : Replaces specified elements of an array with given values.
525524
:obj:`dpnp.copyto` : Copies values from one array to another, broadcasting
526525
as necessary.
527-
:obj:`dpnp.compress` : eturn selected slices of an array along given axis.
526+
:obj:`dpnp.compress` : Return selected slices of an array along given axis.
528527
:obj:`dpnp.place` : Change elements of an array based on conditional and
529528
input values.
530529
@@ -1344,31 +1343,125 @@ def select(condlist, choicelist, default=0):
13441343
13451344
For full documentation refer to :obj:`numpy.select`.
13461345
1347-
Limitations
1348-
-----------
1349-
Arrays of input lists are supported as :obj:`dpnp.ndarray`.
1350-
Parameter `default` is supported only with default values.
1346+
Parameters
1347+
----------
1348+
condlist : list of bool dpnp.ndarray or usm_ndarray
1349+
The list of conditions which determine from which array in `choicelist`
1350+
the output elements are taken. When multiple conditions are satisfied,
1351+
the first one encountered in `condlist` is used.
1352+
choicelist : list of dpnp.ndarray or usm_ndarray
1353+
The list of arrays from which the output elements are taken. It has
1354+
to be of the same length as `condlist`.
1355+
default : {scalar, dpnp.ndarray, usm_ndarray}, optional
1356+
The element inserted in `output` when all conditions evaluate to
1357+
``False``. Default: ``0``.
1358+
1359+
Returns
1360+
-------
1361+
out : dpnp.ndarray
1362+
The output at position m is the m-th element of the array in
1363+
`choicelist` where the m-th element of the corresponding array in
1364+
`condlist` is ``True``.
1365+
1366+
See Also
1367+
--------
1368+
:obj:`dpnp.where : Return elements from one of two arrays depending on
1369+
condition.
1370+
:obj:`dpnp.take` : Take elements from an array along an axis.
1371+
:obj:`dpnp.choose` : Construct an array from an index array and a set of
1372+
arrays to choose from.
1373+
:obj:`dpnp.compress` : Return selected slices of an array along given axis.
1374+
:obj:`dpnp.diag` : Extract a diagonal or construct a diagonal array.
1375+
:obj:`dpnp.diagonal` : Return specified diagonals.
1376+
1377+
Examples
1378+
--------
1379+
>>> import dpnp as np
1380+
1381+
Beginning with an array of integers from 0 to 5 (inclusive),
1382+
elements less than ``3`` are negated, elements greater than ``3``
1383+
are squared, and elements not meeting either of these conditions
1384+
(exactly ``3``) are replaced with a `default` value of ``42``.
1385+
1386+
>>> x = np.arange(6)
1387+
>>> condlist = [x<3, x>3]
1388+
>>> choicelist = [x, x**2]
1389+
>>> np.select(condlist, choicelist, 42)
1390+
array([ 0, 1, 2, 42, 16, 25])
1391+
1392+
When multiple conditions are satisfied, the first one encountered in
1393+
`condlist` is used.
1394+
1395+
>>> condlist = [x<=4, x>3]
1396+
>>> choicelist = [x, x**2]
1397+
>>> np.select(condlist, choicelist, 55)
1398+
array([ 0, 1, 2, 3, 4, 25])
1399+
13511400
"""
13521401

1353-
if not use_origin_backend():
1354-
if not isinstance(condlist, list):
1355-
pass
1356-
elif not isinstance(choicelist, list):
1357-
pass
1358-
elif len(condlist) != len(choicelist):
1359-
pass
1360-
else:
1361-
val = True
1362-
size_ = condlist[0].size
1363-
for cond, choice in zip(condlist, choicelist):
1364-
if cond.size != size_ or choice.size != size_:
1365-
val = False
1366-
if not val:
1367-
pass
1368-
else:
1369-
return dpnp_select(condlist, choicelist, default).get_pyobj()
1402+
if len(condlist) != len(choicelist):
1403+
raise ValueError(
1404+
"list of cases must be same length as list of conditions"
1405+
)
1406+
1407+
if len(condlist) == 0:
1408+
raise ValueError("select with an empty condition list is not possible")
1409+
1410+
dpnp.check_supported_arrays_type(*condlist)
1411+
dpnp.check_supported_arrays_type(*choicelist)
1412+
dpnp.check_supported_arrays_type(
1413+
default, scalar_type=True, all_scalars=True
1414+
)
1415+
1416+
if dpnp.isscalar(default):
1417+
usm_type_alloc, sycl_queue_alloc = get_usm_allocations(
1418+
condlist + choicelist
1419+
)
1420+
dtype = dpnp.result_type(*choicelist)
1421+
default = dpnp.asarray(
1422+
default,
1423+
dtype=dtype,
1424+
usm_type=usm_type_alloc,
1425+
sycl_queue=sycl_queue_alloc,
1426+
)
1427+
choicelist.append(default)
1428+
else:
1429+
choicelist.append(default)
1430+
usm_type_alloc, sycl_queue_alloc = get_usm_allocations(
1431+
condlist + choicelist
1432+
)
1433+
dtype = dpnp.result_type(*choicelist)
1434+
1435+
for i, cond in enumerate(condlist):
1436+
if cond.dtype.type is not dpnp.bool:
1437+
raise TypeError(
1438+
f"invalid entry {i} in condlist: should be boolean ndarray"
1439+
)
1440+
1441+
# Convert conditions to arrays and broadcast conditions and choices
1442+
# as the shape is needed for the result
1443+
condlist = dpnp.broadcast_arrays(*condlist)
1444+
choicelist = dpnp.broadcast_arrays(*choicelist)
1445+
1446+
result_shape = dpnp.broadcast_arrays(condlist[0], choicelist[0])[0].shape
1447+
1448+
result = dpnp.full(
1449+
result_shape,
1450+
choicelist[-1],
1451+
dtype=dtype,
1452+
usm_type=usm_type_alloc,
1453+
sycl_queue=sycl_queue_alloc,
1454+
)
1455+
1456+
# Use np.copyto to burn each choicelist array onto result, using the
1457+
# corresponding condlist as a boolean mask. This is done in reverse
1458+
# order since the first choice should take precedence.
1459+
choicelist = choicelist[-2::-1]
1460+
condlist = condlist[::-1]
1461+
for choice, cond in zip(choicelist, condlist):
1462+
dpnp.copyto(result, choice, where=cond)
13701463

1371-
return call_origin(numpy.select, condlist, choicelist, default)
1464+
return result
13721465

13731466

13741467
# pylint: disable=redefined-outer-name

tests/skipped_tests.tbl

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -116,19 +116,6 @@ tests/third_party/cupy/indexing_tests/test_generate.py::TestUnravelIndex::test_i
116116
tests/third_party/cupy/indexing_tests/test_generate.py::TestUnravelIndex::test_invalid_index
117117
tests/third_party/cupy/indexing_tests/test_generate.py::TestUnravelIndex::test_invalid_order
118118

119-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select
120-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_1D_choicelist
121-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_choicelist_condlist_broadcast
122-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_complex
123-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_default
124-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_default_complex
125-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_default_scalar
126-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_empty_lists
127-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_length_error
128-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_odd_shaped_broadcastable
129-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_odd_shaped_broadcastable_complex
130-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_odd_shaped_non_broadcastable
131-
132119
tests/third_party/cupy/indexing_tests/test_insert.py::TestPutmaskDifferentDtypes::test_putmask_differnt_dtypes_raises
133120
tests/third_party/cupy/indexing_tests/test_insert.py::TestPutmask::test_putmask_non_equal_shape_raises
134121

tests/skipped_tests_gpu.tbl

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -169,20 +169,6 @@ tests/third_party/cupy/indexing_tests/test_generate.py::TestUnravelIndex::test_i
169169
tests/third_party/cupy/indexing_tests/test_generate.py::TestUnravelIndex::test_invalid_index
170170
tests/third_party/cupy/indexing_tests/test_generate.py::TestUnravelIndex::test_invalid_order
171171

172-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select
173-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_1D_choicelist
174-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_choicelist_condlist_broadcast
175-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_complex
176-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_default
177-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_default_complex
178-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_default_scalar
179-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_empty_lists
180-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_length_error
181-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_odd_shaped_broadcastable
182-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_odd_shaped_broadcastable_complex
183-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_odd_shaped_non_broadcastable
184-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_type_error_condlist
185-
186172
tests/third_party/cupy/indexing_tests/test_insert.py::TestPutmaskDifferentDtypes::test_putmask_differnt_dtypes_raises
187173
tests/third_party/cupy/indexing_tests/test_insert.py::TestPutmask::test_putmask_non_equal_shape_raises
188174

tests/test_sycl_queue.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2338,6 +2338,19 @@ def test_astype(device_x, device_y):
23382338
assert_sycl_queue_equal(y.sycl_queue, sycl_queue)
23392339

23402340

2341+
@pytest.mark.parametrize(
2342+
"device",
2343+
valid_devices,
2344+
ids=[device.filter_string for device in valid_devices],
2345+
)
2346+
def test_select(device):
2347+
sycl_queue = dpctl.SyclQueue(device)
2348+
condlist = [dpnp.array([True, False], sycl_queue=sycl_queue)]
2349+
choicelist = [dpnp.array([1, 2], sycl_queue=sycl_queue)]
2350+
res = dpnp.select(condlist, choicelist)
2351+
assert_sycl_queue_equal(res.sycl_queue, sycl_queue)
2352+
2353+
23412354
@pytest.mark.parametrize("copy", [True, False], ids=["True", "False"])
23422355
@pytest.mark.parametrize(
23432356
"device",

tests/test_usm_type.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1363,6 +1363,14 @@ def test_histogram_bin_edges(usm_type_v, usm_type_w):
13631363
assert edges.usm_type == du.get_coerced_usm_type([usm_type_v, usm_type_w])
13641364

13651365

1366+
@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)
1367+
def test_select(usm_type):
1368+
condlist = [dp.array([True, False], usm_type=usm_type)]
1369+
choicelist = [dp.array([1, 2], usm_type=usm_type)]
1370+
res = dp.select(condlist, choicelist)
1371+
assert res.usm_type == usm_type
1372+
1373+
13661374
@pytest.mark.parametrize("copy", [True, False], ids=["True", "False"])
13671375
@pytest.mark.parametrize("usm_type_a", list_of_usm_types, ids=list_of_usm_types)
13681376
def test_nan_to_num(copy, usm_type_a):

tests/third_party/cupy/indexing_tests/test_indexing.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -355,10 +355,9 @@ def test_select_type_error_condlist(self, dtype):
355355
a = cupy.arange(10, dtype=dtype)
356356
condlist = [[3] * 10, [2] * 10]
357357
choicelist = [a, a**2]
358-
with pytest.raises(AttributeError):
358+
with pytest.raises(TypeError):
359359
cupy.select(condlist, choicelist)
360360

361-
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
362361
@testing.for_all_dtypes(no_bool=True)
363362
def test_select_type_error_choicelist(self, dtype):
364363
a, b = list(range(10)), list(range(-10, 0))
@@ -388,7 +387,7 @@ def test_select_default_scalar(self, dtype):
388387
b = cupy.arange(20)
389388
condlist = [a < 3, b > 8]
390389
choicelist = [a, b]
391-
with pytest.raises(TypeError):
390+
with pytest.raises(ValueError):
392391
cupy.select(condlist, choicelist, [dtype(2)])
393392

394393
@pytest.mark.skip("as_strided() is not implemented yet")

0 commit comments

Comments
 (0)