|
49 | 49 | from .dpnp_algo import (
|
50 | 50 | dpnp_choose,
|
51 | 51 | dpnp_putmask,
|
52 |
| - dpnp_select, |
53 | 52 | )
|
54 | 53 | from .dpnp_array import dpnp_array
|
55 | 54 | from .dpnp_utils import (
|
56 | 55 | call_origin,
|
57 |
| - use_origin_backend, |
| 56 | + get_usm_allocations, |
58 | 57 | )
|
59 | 58 |
|
60 | 59 | __all__ = [
|
@@ -524,7 +523,7 @@ def extract(condition, a):
|
524 | 523 | :obj:`dpnp.put` : Replaces specified elements of an array with given values.
|
525 | 524 | :obj:`dpnp.copyto` : Copies values from one array to another, broadcasting
|
526 | 525 | as necessary.
|
527 |
| - :obj:`dpnp.compress` : eturn selected slices of an array along given axis. |
| 526 | + :obj:`dpnp.compress` : Return selected slices of an array along given axis. |
528 | 527 | :obj:`dpnp.place` : Change elements of an array based on conditional and
|
529 | 528 | input values.
|
530 | 529 |
|
@@ -1344,31 +1343,125 @@ def select(condlist, choicelist, default=0):
|
1344 | 1343 |
|
1345 | 1344 | For full documentation refer to :obj:`numpy.select`.
|
1346 | 1345 |
|
1347 |
| - Limitations |
1348 |
| - ----------- |
1349 |
| - Arrays of input lists are supported as :obj:`dpnp.ndarray`. |
1350 |
| - Parameter `default` is supported only with default values. |
| 1346 | + Parameters |
| 1347 | + ---------- |
| 1348 | + condlist : list of bool dpnp.ndarray or usm_ndarray |
| 1349 | + The list of conditions which determine from which array in `choicelist` |
| 1350 | + the output elements are taken. When multiple conditions are satisfied, |
| 1351 | + the first one encountered in `condlist` is used. |
| 1352 | + choicelist : list of dpnp.ndarray or usm_ndarray |
| 1353 | + The list of arrays from which the output elements are taken. It has |
| 1354 | + to be of the same length as `condlist`. |
| 1355 | + default : {scalar, dpnp.ndarray, usm_ndarray}, optional |
| 1356 | + The element inserted in `output` when all conditions evaluate to |
| 1357 | + ``False``. Default: ``0``. |
| 1358 | +
|
| 1359 | + Returns |
| 1360 | + ------- |
| 1361 | + out : dpnp.ndarray |
| 1362 | + The output at position m is the m-th element of the array in |
| 1363 | + `choicelist` where the m-th element of the corresponding array in |
| 1364 | + `condlist` is ``True``. |
| 1365 | +
|
| 1366 | + See Also |
| 1367 | + -------- |
| 1368 | + :obj:`dpnp.where : Return elements from one of two arrays depending on |
| 1369 | + condition. |
| 1370 | + :obj:`dpnp.take` : Take elements from an array along an axis. |
| 1371 | + :obj:`dpnp.choose` : Construct an array from an index array and a set of |
| 1372 | + arrays to choose from. |
| 1373 | + :obj:`dpnp.compress` : Return selected slices of an array along given axis. |
| 1374 | + :obj:`dpnp.diag` : Extract a diagonal or construct a diagonal array. |
| 1375 | + :obj:`dpnp.diagonal` : Return specified diagonals. |
| 1376 | +
|
| 1377 | + Examples |
| 1378 | + -------- |
| 1379 | + >>> import dpnp as np |
| 1380 | +
|
| 1381 | + Beginning with an array of integers from 0 to 5 (inclusive), |
| 1382 | + elements less than ``3`` are negated, elements greater than ``3`` |
| 1383 | + are squared, and elements not meeting either of these conditions |
| 1384 | + (exactly ``3``) are replaced with a `default` value of ``42``. |
| 1385 | +
|
| 1386 | + >>> x = np.arange(6) |
| 1387 | + >>> condlist = [x<3, x>3] |
| 1388 | + >>> choicelist = [x, x**2] |
| 1389 | + >>> np.select(condlist, choicelist, 42) |
| 1390 | + array([ 0, 1, 2, 42, 16, 25]) |
| 1391 | +
|
| 1392 | + When multiple conditions are satisfied, the first one encountered in |
| 1393 | + `condlist` is used. |
| 1394 | +
|
| 1395 | + >>> condlist = [x<=4, x>3] |
| 1396 | + >>> choicelist = [x, x**2] |
| 1397 | + >>> np.select(condlist, choicelist, 55) |
| 1398 | + array([ 0, 1, 2, 3, 4, 25]) |
| 1399 | +
|
1351 | 1400 | """
|
1352 | 1401 |
|
1353 |
| - if not use_origin_backend(): |
1354 |
| - if not isinstance(condlist, list): |
1355 |
| - pass |
1356 |
| - elif not isinstance(choicelist, list): |
1357 |
| - pass |
1358 |
| - elif len(condlist) != len(choicelist): |
1359 |
| - pass |
1360 |
| - else: |
1361 |
| - val = True |
1362 |
| - size_ = condlist[0].size |
1363 |
| - for cond, choice in zip(condlist, choicelist): |
1364 |
| - if cond.size != size_ or choice.size != size_: |
1365 |
| - val = False |
1366 |
| - if not val: |
1367 |
| - pass |
1368 |
| - else: |
1369 |
| - return dpnp_select(condlist, choicelist, default).get_pyobj() |
| 1402 | + if len(condlist) != len(choicelist): |
| 1403 | + raise ValueError( |
| 1404 | + "list of cases must be same length as list of conditions" |
| 1405 | + ) |
| 1406 | + |
| 1407 | + if len(condlist) == 0: |
| 1408 | + raise ValueError("select with an empty condition list is not possible") |
| 1409 | + |
| 1410 | + dpnp.check_supported_arrays_type(*condlist) |
| 1411 | + dpnp.check_supported_arrays_type(*choicelist) |
| 1412 | + dpnp.check_supported_arrays_type( |
| 1413 | + default, scalar_type=True, all_scalars=True |
| 1414 | + ) |
| 1415 | + |
| 1416 | + if dpnp.isscalar(default): |
| 1417 | + usm_type_alloc, sycl_queue_alloc = get_usm_allocations( |
| 1418 | + condlist + choicelist |
| 1419 | + ) |
| 1420 | + dtype = dpnp.result_type(*choicelist) |
| 1421 | + default = dpnp.asarray( |
| 1422 | + default, |
| 1423 | + dtype=dtype, |
| 1424 | + usm_type=usm_type_alloc, |
| 1425 | + sycl_queue=sycl_queue_alloc, |
| 1426 | + ) |
| 1427 | + choicelist.append(default) |
| 1428 | + else: |
| 1429 | + choicelist.append(default) |
| 1430 | + usm_type_alloc, sycl_queue_alloc = get_usm_allocations( |
| 1431 | + condlist + choicelist |
| 1432 | + ) |
| 1433 | + dtype = dpnp.result_type(*choicelist) |
| 1434 | + |
| 1435 | + for i, cond in enumerate(condlist): |
| 1436 | + if cond.dtype.type is not dpnp.bool: |
| 1437 | + raise TypeError( |
| 1438 | + f"invalid entry {i} in condlist: should be boolean ndarray" |
| 1439 | + ) |
| 1440 | + |
| 1441 | + # Convert conditions to arrays and broadcast conditions and choices |
| 1442 | + # as the shape is needed for the result |
| 1443 | + condlist = dpnp.broadcast_arrays(*condlist) |
| 1444 | + choicelist = dpnp.broadcast_arrays(*choicelist) |
| 1445 | + |
| 1446 | + result_shape = dpnp.broadcast_arrays(condlist[0], choicelist[0])[0].shape |
| 1447 | + |
| 1448 | + result = dpnp.full( |
| 1449 | + result_shape, |
| 1450 | + choicelist[-1], |
| 1451 | + dtype=dtype, |
| 1452 | + usm_type=usm_type_alloc, |
| 1453 | + sycl_queue=sycl_queue_alloc, |
| 1454 | + ) |
| 1455 | + |
| 1456 | + # Use np.copyto to burn each choicelist array onto result, using the |
| 1457 | + # corresponding condlist as a boolean mask. This is done in reverse |
| 1458 | + # order since the first choice should take precedence. |
| 1459 | + choicelist = choicelist[-2::-1] |
| 1460 | + condlist = condlist[::-1] |
| 1461 | + for choice, cond in zip(choicelist, condlist): |
| 1462 | + dpnp.copyto(result, choice, where=cond) |
1370 | 1463 |
|
1371 |
| - return call_origin(numpy.select, condlist, choicelist, default) |
| 1464 | + return result |
1372 | 1465 |
|
1373 | 1466 |
|
1374 | 1467 | # pylint: disable=redefined-outer-name
|
|
0 commit comments