|
21 | 21 | import dpctl.tensor._tensor_impl as ti
|
22 | 22 | import dpctl.utils
|
23 | 23 |
|
24 |
| -from ._copy_utils import _extract_impl, _nonzero_impl |
| 24 | +from ._copy_utils import _extract_impl, _nonzero_impl, _take_multi_index |
25 | 25 | from ._numpy_helper import normalize_axis_index
|
26 | 26 |
|
27 | 27 |
|
@@ -423,3 +423,77 @@ def nonzero(arr):
|
423 | 423 | if arr.ndim == 0:
|
424 | 424 | raise ValueError("Array of positive rank is expected")
|
425 | 425 | return _nonzero_impl(arr)
|
| 426 | + |
| 427 | + |
| 428 | +def _range(sh_i, i, nd, q, usm_t, dt): |
| 429 | + ind = dpt.arange(sh_i, dtype=dt, usm_type=usm_t, sycl_queue=q) |
| 430 | + ind.shape = tuple(sh_i if i == j else 1 for j in range(nd)) |
| 431 | + return ind |
| 432 | + |
| 433 | + |
| 434 | +def take_along_axis(x, indices, /, *, axis=-1, mode="wrap"): |
| 435 | + """ |
| 436 | + Returns elements from an array at the one-dimensional indices specified |
| 437 | + by ``indices`` along a provided ``axis``. |
| 438 | +
|
| 439 | + Args: |
| 440 | + x (usm_ndarray): |
| 441 | + input array. Must be compatible with ``indices``, except for the |
| 442 | + axis (dimension) specified by ``axis``. |
| 443 | + indices (usm_ndarray): |
| 444 | + array indices. Must have the same rank (i.e., number of dimensions) |
| 445 | + as ``x``. |
| 446 | + axis: int |
| 447 | + axis along which to select values. If ``axis`` is negative, the |
| 448 | + function determines the axis along which to select values by |
| 449 | + counting from the last dimension. Default: ``-1``. |
| 450 | + mode (str, optional): |
| 451 | + How out-of-bounds indices will be handled. Possible values |
| 452 | + are: |
| 453 | +
|
| 454 | + - ``"wrap"``: clamps indices to (``-n <= i < n``), then wraps |
| 455 | + negative indices. |
| 456 | + - ``"clip"``: clips indices to (``0 <= i < n``). |
| 457 | +
|
| 458 | + Default: ``"wrap"``. |
| 459 | +
|
| 460 | + Returns: |
| 461 | + usm_ndarray: |
| 462 | + an array having the same data type as ``x``. The returned array has |
| 463 | + the same rank (i.e., number of dimensions) as ``x`` and a shape |
| 464 | + determined according to :ref:`broadcasting`, except for the axis |
| 465 | + (dimension) specified by ``axis`` whose size must equal the size |
| 466 | + of the corresponding axis (dimension) in ``indices``. |
| 467 | +
|
| 468 | + Note: |
| 469 | + Treatment of the out-of-bound indices in ``indices`` array is controlled |
| 470 | + by the value of ``mode`` keyword. |
| 471 | + """ |
| 472 | + if not isinstance(x, dpt.usm_ndarray): |
| 473 | + raise TypeError |
| 474 | + if not isinstance(indices, dpt.usm_ndarray): |
| 475 | + raise TypeError |
| 476 | + x_nd = x.ndim |
| 477 | + if x_nd != indices.ndim: |
| 478 | + raise ValueError |
| 479 | + pp = normalize_axis_index(operator.index(axis), x_nd) |
| 480 | + out_usm_type = dpctl.utils.get_coerced_usm_type( |
| 481 | + (x.usm_type, indices.usm_type) |
| 482 | + ) |
| 483 | + exec_q = dpctl.utils.get_execution_queue((x.sycl_queue, indices.sycl_queue)) |
| 484 | + if exec_q is None: |
| 485 | + raise dpctl.utils.ExecutionPlacementError( |
| 486 | + "Execution placement can not be unambiguously inferred " |
| 487 | + "from input arguments. " |
| 488 | + ) |
| 489 | + mode_i = _get_indexing_mode(mode) |
| 490 | + indexes_dt = ti.default_device_index_type(exec_q.sycl_device) |
| 491 | + _ind = tuple( |
| 492 | + ( |
| 493 | + indices |
| 494 | + if i == pp |
| 495 | + else _range(x.shape[i], i, x_nd, exec_q, out_usm_type, indexes_dt) |
| 496 | + ) |
| 497 | + for i in range(x_nd) |
| 498 | + ) |
| 499 | + return _take_multi_index(x, _ind, 0, mode=mode_i) |
0 commit comments